| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- import logging
- import os
- import threading
- import time
- from .start_hook_base import RayOnSparkStartHook
- from .utils import get_spark_session
# Module-level logger used for the auto-shutdown status and warning messages.
_logger = logging.getLogger(__name__)

# Names of the environment variables inspected for Databricks authentication.
# `DATABRICKS_HOST` is required in both supported modes; it is combined with
# either `DATABRICKS_TOKEN` (PAT auth) or with `DATABRICKS_CLIENT_ID` +
# `DATABRICKS_CLIENT_SECRET` (OAuth).
DATABRICKS_HOST = "DATABRICKS_HOST"
DATABRICKS_TOKEN = "DATABRICKS_TOKEN"
DATABRICKS_CLIENT_ID = "DATABRICKS_CLIENT_ID"
DATABRICKS_CLIENT_SECRET = "DATABRICKS_CLIENT_SECRET"
def verify_databricks_auth_env():
    """
    Tell whether Databricks authentication is configured through
    environment variables.

    Returns True when `DATABRICKS_HOST` is set together with either
    `DATABRICKS_TOKEN`, or with both `DATABRICKS_CLIENT_ID` and
    `DATABRICKS_CLIENT_SECRET`; returns False otherwise.
    """
    env = os.environ

    # The host is mandatory for both supported credential combinations.
    if DATABRICKS_HOST not in env:
        return False

    # A token alone is sufficient once the host is known.
    if DATABRICKS_TOKEN in env:
        return True

    # Otherwise both halves of the client credential pair must be present.
    return DATABRICKS_CLIENT_ID in env and DATABRICKS_CLIENT_SECRET in env
def get_databricks_function(func_name):
    """
    Look up `func_name` in the "user_global" namespace table of the current
    IPython shell and return the object stored under that name.

    Raises RuntimeError if there is no active IPython shell.
    """
    # Imported lazily so that importing this module does not require IPython.
    import IPython

    shell = IPython.get_ipython()
    if shell is not None:
        user_globals = shell.ns_table["user_global"]
        return user_globals[func_name]

    raise RuntimeError("No IPython environment.")
def get_databricks_display_html_function():
    """
    Return the object registered as `displayHTML` in the notebook's user
    namespace, as resolved by `get_databricks_function`.
    """
    display_html = get_databricks_function("displayHTML")
    return display_html
def get_db_entry_point():
    """
    Return the databricks entry_point instance, which is used for calling
    some internal APIs in databricks runtime.
    """
    # `dbruntime` is imported lazily, at call time rather than module import.
    from dbruntime import UserNamespaceInitializer

    return UserNamespaceInitializer.getOrCreate().get_spark_entry_point()
def display_databricks_driver_proxy_url(spark_context, port, title):
    """
    Build and display the proxy URL that forwards to a web service running
    on the databricks driver.

    In databricks runtime, user does not have permission to directly access
    a web service bound to a port on the driver machine, but the service can
    be visited through a proxy URL of the following format:
    "/driver-proxy/o/{orgId}/{clusterId}/{port}/".

    The absolute URL is printed, and a link labelled with `title` is rendered
    through the notebook's `displayHTML` function.
    """
    jvm = spark_context._jvm
    driver_local = jvm.com.databricks.backend.daemon.driver.DriverLocal

    # The "tags" entry of the command context carries the org and cluster ids.
    context_tags = driver_local.commandContext().get().toStringMap().apply("tags")
    org_id = context_tags.apply("orgId")
    cluster_id = context_tags.apply("clusterId")

    proxy_link = f"/driver-proxy/o/{org_id}/{cluster_id}/{port}/"
    proxy_url = f"https://dbc-dp-{org_id}.cloud.databricks.com{proxy_link}"

    print("To monitor and debug Ray from Databricks, view the dashboard at ")
    print(f" {proxy_url}")

    # The rendered link uses the relative proxy path, not the absolute URL.
    html = (
        "\n"
        '<div style="margin-top: 16px;margin-bottom: 16px">\n'
        f'<a href="{proxy_link}">\n'
        f"Open {title} in a new tab\n"
        "</a>\n"
        "</div>\n"
    )
    display_html = get_databricks_display_html_function()
    display_html(html)
# Seconds the auto-shutdown watcher sleeps between two idle-time checks.
DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS = 3

# Name of the environment variable holding the notebook idle time, in
# minutes, after which the Ray cluster is shut down automatically.
DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES = "DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES"

# Default temporary root directory reported by the Databricks start hook.
_DATABRICKS_DEFAULT_TMP_ROOT_DIR = "/local_disk0/tmp"
class DefaultDatabricksRayOnSparkStartHook(RayOnSparkStartHook):
    """
    Start hook for Ray on Spark clusters that are launched from a Databricks
    notebook.

    It displays the Ray dashboard through the Databricks driver proxy, shuts
    the Ray cluster down automatically when the notebook has been idle for a
    configurable number of minutes, registers the background Spark job group
    with the Databricks runtime, and adds Databricks specific environment
    variables to the Ray cluster.
    """

    def get_default_temp_root_dir(self):
        """
        Return the default temporary root directory used on Databricks.
        """
        return _DATABRICKS_DEFAULT_TMP_ROOT_DIR

    def on_ray_dashboard_created(self, port):
        """
        Display the driver proxy link of the Ray dashboard listening on
        `port`.
        """
        display_databricks_driver_proxy_url(
            get_spark_session().sparkContext, port, "Ray Cluster Dashboard"
        )

    def on_cluster_created(self, ray_cluster_handler):
        """
        Start a background thread that shuts the Ray cluster down once the
        Databricks notebook has been idle for longer than the number of
        minutes configured in the
        `DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES` environment variable
        (default 30).

        No watcher is started when the cluster is in global mode, when the
        configured value is 0, or when the notebook idle time cannot be
        retrieved from the Databricks runtime.

        Raises ValueError if the environment variable is not a number, or is
        negative or NaN.
        """
        db_api_entry = get_db_entry_point()

        if self.is_global:
            # Disable auto shutdown for global mode cluster.
            # Because global mode cluster is designed to keep running until
            # user request to shut down it, and global mode cluster is shared
            # by other users, the code here cannot track usage from other users
            # so that we don't know whether it is safe to shut down the global
            # cluster automatically.
            # NOTE(review): the previous comment also listed "autoscaling
            # enabled" as a reason to disable auto shutdown, but only
            # `is_global` is checked here -- confirm that this is intended.
            auto_shutdown_minutes = 0
        else:
            raw_minutes = os.environ.get(
                DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES, "30"
            )
            try:
                auto_shutdown_minutes = float(raw_minutes)
            except ValueError:
                # Re-raise with a message that names the offending variable,
                # the bare `float()` error does not mention it.
                raise ValueError(
                    f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' must be "
                    f"set to a number of minutes, but got {raw_minutes!r}."
                ) from None

        if auto_shutdown_minutes == 0:
            _logger.info(
                "The Ray cluster will keep running until you manually detach the "
                "Databricks notebook or call "
                "`ray.util.spark.shutdown_ray_cluster()`."
            )
            return

        # `not (x > 0)` instead of `x < 0` so that NaN is rejected as well:
        # NaN compares False against everything, and would otherwise result
        # in a watcher that never triggers the shutdown.
        if not auto_shutdown_minutes > 0:
            raise ValueError(
                "You must set "
                f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' "
                "to a value >= 0."
            )

        # Probe the idle time API once up front, so that an unsupported
        # runtime is reported here instead of inside the watcher thread.
        try:
            db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()
        except Exception:
            _logger.warning(
                "Failed to retrieve idle time since last notebook execution, "
                "so that we cannot automatically shut down Ray cluster when "
                "Databricks notebook is inactive for the specified minutes. "
                "You need to manually detach Databricks notebook "
                "or call `ray.util.spark.shutdown_ray_cluster()` to shut down "
                "Ray cluster on spark."
            )
            return

        _logger.info(
            "The Ray cluster will be shut down automatically if you don't run "
            "commands on the Databricks notebook for "
            f"{auto_shutdown_minutes} minutes. You can change the "
            "auto-shutdown minutes by setting "
            f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' environment "
            "variable, setting it to 0 means that the Ray cluster keeps running "
            "until you manually call `ray.util.spark.shutdown_ray_cluster()` or "
            "detach Databricks notebook."
        )

        def auto_shutdown_watcher():
            auto_shutdown_millis = auto_shutdown_minutes * 60 * 1000
            while True:
                if ray_cluster_handler.is_shutdown:
                    # The cluster is shut down. The watcher thread exits.
                    return

                try:
                    idle_time = (
                        db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()
                    )
                except Exception:
                    # Without this, the daemon thread would die with only a
                    # raw traceback and the user would not learn that auto
                    # shutdown stopped working.
                    _logger.warning(
                        "Failed to retrieve idle time since last notebook "
                        "execution, automatic shutdown of the Ray cluster is "
                        "disabled. You need to manually detach Databricks "
                        "notebook or call "
                        "`ray.util.spark.shutdown_ray_cluster()` to shut down "
                        "Ray cluster on spark.",
                        exc_info=True,
                    )
                    return

                if idle_time > auto_shutdown_millis:
                    from ray.util.spark import cluster_init

                    with cluster_init._active_ray_cluster_rwlock:
                        # Only shut down if this handler is still the active
                        # cluster.
                        if ray_cluster_handler is cluster_init._active_ray_cluster:
                            cluster_init.shutdown_ray_cluster()
                    # The watcher exits in both cases once the idle limit is
                    # exceeded.
                    return

                time.sleep(DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS)

        threading.Thread(target=auto_shutdown_watcher, daemon=True).start()

    def on_spark_job_created(self, job_group_id):
        """
        Register the Spark job group `job_group_id` as a background Spark job
        group through the Databricks entry point.
        """
        db_api_entry = get_db_entry_point()
        db_api_entry.registerBackgroundSparkJobGroup(job_group_id)

    def custom_environment_variables(self):
        """
        Return the environment variables to set for the Ray cluster.

        The result extends the variables of the parent hook with Databricks
        specific settings. The Databricks authentication variables are
        forwarded when they are set in the current environment, otherwise a
        warning is displayed in the notebook.
        """
        conf = {
            **super().custom_environment_variables(),
            # Hardcode `GLOO_SOCKET_IFNAME` to `eth0` for Databricks runtime.
            # Torch on DBR does not reliably detect the correct interface to use,
            # and ends up selecting the loopback interface, breaking cross-node
            # communication.
            "GLOO_SOCKET_IFNAME": "eth0",
            # 'DISABLE_MLFLOW_INTEGRATION' is the environmental variable to disable
            # huggingface transformers MLflow integration,
            # it doesn't work well in Databricks runtime,
            # So disable it by default.
            "DISABLE_MLFLOW_INTEGRATION": "TRUE",
        }

        if verify_databricks_auth_env():
            conf[DATABRICKS_HOST] = os.environ[DATABRICKS_HOST]
            if DATABRICKS_TOKEN in os.environ:
                # PAT auth
                conf[DATABRICKS_TOKEN] = os.environ[DATABRICKS_TOKEN]
            else:
                # OAuth
                conf[DATABRICKS_CLIENT_ID] = os.environ[DATABRICKS_CLIENT_ID]
                conf[DATABRICKS_CLIENT_SECRET] = os.environ[DATABRICKS_CLIENT_SECRET]
        else:
            # The first fragment ends with a space, otherwise the first two
            # sentences are rendered as "tasks.To enable".
            warn_msg = (
                "MLflow support is not correctly configured within Ray tasks. "
                "To enable MLflow integration, "
                "you need to set environmental variables DATABRICKS_HOST + "
                "DATABRICKS_TOKEN, or set environmental variables "
                "DATABRICKS_HOST + DATABRICKS_CLIENT_ID + DATABRICKS_CLIENT_SECRET "
                "before calling `ray.util.spark.setup_ray_cluster`, these variables "
                "are used to set up authentication with Databricks MLflow "
                "service. For details, you can refer to Databricks documentation at "
                "<a href='https://docs.databricks.com/en/dev-tools/auth/pat.html'>"
                "Databricks PAT auth</a> or "
                "<a href='https://docs.databricks.com/en/dev-tools/auth/"
                "oauth-m2m.html'>Databricks OAuth</a>."
            )
            get_databricks_display_html_function()(
                f"<b style='color:red;'>{warn_msg}<br></b>"
            )

        return conf
|