databricks_hook.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. import logging
  2. import os
  3. import threading
  4. import time
  5. from .start_hook_base import RayOnSparkStartHook
  6. from .utils import get_spark_session
  7. _logger = logging.getLogger(__name__)
  8. DATABRICKS_HOST = "DATABRICKS_HOST"
  9. DATABRICKS_TOKEN = "DATABRICKS_TOKEN"
  10. DATABRICKS_CLIENT_ID = "DATABRICKS_CLIENT_ID"
  11. DATABRICKS_CLIENT_SECRET = "DATABRICKS_CLIENT_SECRET"
  12. def verify_databricks_auth_env():
  13. return (DATABRICKS_HOST in os.environ and DATABRICKS_TOKEN in os.environ) or (
  14. DATABRICKS_HOST in os.environ
  15. and DATABRICKS_CLIENT_ID in os.environ
  16. and DATABRICKS_CLIENT_SECRET in os.environ
  17. )
  18. def get_databricks_function(func_name):
  19. import IPython
  20. ip_shell = IPython.get_ipython()
  21. if ip_shell is None:
  22. raise RuntimeError("No IPython environment.")
  23. return ip_shell.ns_table["user_global"][func_name]
  24. def get_databricks_display_html_function():
  25. return get_databricks_function("displayHTML")
  26. def get_db_entry_point():
  27. """
  28. Return databricks entry_point instance, it is for calling some
  29. internal API in databricks runtime
  30. """
  31. from dbruntime import UserNamespaceInitializer
  32. user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
  33. return user_namespace_initializer.get_spark_entry_point()
  34. def display_databricks_driver_proxy_url(spark_context, port, title):
  35. """
  36. This helper function create a proxy URL for databricks driver webapp forwarding.
  37. In databricks runtime, user does not have permission to directly access web
  38. service binding on driver machine port, but user can visit it by a proxy URL with
  39. following format: "/driver-proxy/o/{orgId}/{clusterId}/{port}/".
  40. """
  41. driverLocal = spark_context._jvm.com.databricks.backend.daemon.driver.DriverLocal
  42. commandContextTags = driverLocal.commandContext().get().toStringMap().apply("tags")
  43. orgId = commandContextTags.apply("orgId")
  44. clusterId = commandContextTags.apply("clusterId")
  45. proxy_link = f"/driver-proxy/o/{orgId}/{clusterId}/{port}/"
  46. proxy_url = f"https://dbc-dp-{orgId}.cloud.databricks.com{proxy_link}"
  47. print("To monitor and debug Ray from Databricks, view the dashboard at ")
  48. print(f" {proxy_url}")
  49. get_databricks_display_html_function()(
  50. f"""
  51. <div style="margin-top: 16px;margin-bottom: 16px">
  52. <a href="{proxy_link}">
  53. Open {title} in a new tab
  54. </a>
  55. </div>
  56. """
  57. )
# Seconds the auto-shutdown watcher thread sleeps between idle-time polls.
DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS = 3
# Name of the environment variable that holds the notebook idle time, in
# minutes, after which the Ray cluster is shut down automatically. The hook
# falls back to "30" when it is unset and treats 0 as "never shut down".
DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES = (
    "DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES"
)
# Temp root directory reported by the Databricks start hook.
_DATABRICKS_DEFAULT_TMP_ROOT_DIR = "/local_disk0/tmp"
  63. class DefaultDatabricksRayOnSparkStartHook(RayOnSparkStartHook):
  64. def get_default_temp_root_dir(self):
  65. return _DATABRICKS_DEFAULT_TMP_ROOT_DIR
  66. def on_ray_dashboard_created(self, port):
  67. display_databricks_driver_proxy_url(
  68. get_spark_session().sparkContext, port, "Ray Cluster Dashboard"
  69. )
  70. def on_cluster_created(self, ray_cluster_handler):
  71. db_api_entry = get_db_entry_point()
  72. if self.is_global:
  73. # Disable auto shutdown if
  74. # 1) autoscaling enabled
  75. # because in autoscaling mode, background spark job will be killed
  76. # automatically when ray cluster is idle.
  77. # 2) global mode cluster
  78. # Because global mode cluster is designed to keep running until
  79. # user request to shut down it, and global mode cluster is shared
  80. # by other users, the code here cannot track usage from other users
  81. # so that we don't know whether it is safe to shut down the global
  82. # cluster automatically.
  83. auto_shutdown_minutes = 0
  84. else:
  85. auto_shutdown_minutes = float(
  86. os.environ.get(DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES, "30")
  87. )
  88. if auto_shutdown_minutes == 0:
  89. _logger.info(
  90. "The Ray cluster will keep running until you manually detach the "
  91. "Databricks notebook or call "
  92. "`ray.util.spark.shutdown_ray_cluster()`."
  93. )
  94. return
  95. if auto_shutdown_minutes < 0:
  96. raise ValueError(
  97. "You must set "
  98. f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' "
  99. "to a value >= 0."
  100. )
  101. try:
  102. db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()
  103. except Exception:
  104. _logger.warning(
  105. "Failed to retrieve idle time since last notebook execution, "
  106. "so that we cannot automatically shut down Ray cluster when "
  107. "Databricks notebook is inactive for the specified minutes. "
  108. "You need to manually detach Databricks notebook "
  109. "or call `ray.util.spark.shutdown_ray_cluster()` to shut down "
  110. "Ray cluster on spark."
  111. )
  112. return
  113. _logger.info(
  114. "The Ray cluster will be shut down automatically if you don't run "
  115. "commands on the Databricks notebook for "
  116. f"{auto_shutdown_minutes} minutes. You can change the "
  117. "auto-shutdown minutes by setting "
  118. f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' environment "
  119. "variable, setting it to 0 means that the Ray cluster keeps running "
  120. "until you manually call `ray.util.spark.shutdown_ray_cluster()` or "
  121. "detach Databricks notebook."
  122. )
  123. def auto_shutdown_watcher():
  124. auto_shutdown_millis = auto_shutdown_minutes * 60 * 1000
  125. while True:
  126. if ray_cluster_handler.is_shutdown:
  127. # The cluster is shut down. The watcher thread exits.
  128. return
  129. idle_time = db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()
  130. if idle_time > auto_shutdown_millis:
  131. from ray.util.spark import cluster_init
  132. with cluster_init._active_ray_cluster_rwlock:
  133. if ray_cluster_handler is cluster_init._active_ray_cluster:
  134. cluster_init.shutdown_ray_cluster()
  135. return
  136. time.sleep(DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS)
  137. threading.Thread(target=auto_shutdown_watcher, daemon=True).start()
  138. def on_spark_job_created(self, job_group_id):
  139. db_api_entry = get_db_entry_point()
  140. db_api_entry.registerBackgroundSparkJobGroup(job_group_id)
  141. def custom_environment_variables(self):
  142. conf = {
  143. **super().custom_environment_variables(),
  144. # Hardcode `GLOO_SOCKET_IFNAME` to `eth0` for Databricks runtime.
  145. # Torch on DBR does not reliably detect the correct interface to use,
  146. # and ends up selecting the loopback interface, breaking cross-node
  147. # commnication.
  148. "GLOO_SOCKET_IFNAME": "eth0",
  149. # 'DISABLE_MLFLOW_INTEGRATION' is the environmental variable to disable
  150. # huggingface transformers MLflow integration,
  151. # it doesn't work well in Databricks runtime,
  152. # So disable it by default.
  153. "DISABLE_MLFLOW_INTEGRATION": "TRUE",
  154. }
  155. if verify_databricks_auth_env():
  156. conf[DATABRICKS_HOST] = os.environ[DATABRICKS_HOST]
  157. if DATABRICKS_TOKEN in os.environ:
  158. # PAT auth
  159. conf[DATABRICKS_TOKEN] = os.environ[DATABRICKS_TOKEN]
  160. else:
  161. # OAuth
  162. conf[DATABRICKS_CLIENT_ID] = os.environ[DATABRICKS_CLIENT_ID]
  163. conf[DATABRICKS_CLIENT_SECRET] = os.environ[DATABRICKS_CLIENT_SECRET]
  164. else:
  165. warn_msg = (
  166. "MLflow support is not correctly configured within Ray tasks."
  167. "To enable MLflow integration, "
  168. "you need to set environmental variables DATABRICKS_HOST + "
  169. "DATABRICKS_TOKEN, or set environmental variables "
  170. "DATABRICKS_HOST + DATABRICKS_CLIENT_ID + DATABRICKS_CLIENT_SECRET "
  171. "before calling `ray.util.spark.setup_ray_cluster`, these variables "
  172. "are used to set up authentication with Databricks MLflow "
  173. "service. For details, you can refer to Databricks documentation at "
  174. "<a href='https://docs.databricks.com/en/dev-tools/auth/pat.html'>"
  175. "Databricks PAT auth</a> or "
  176. "<a href='https://docs.databricks.com/en/dev-tools/auth/"
  177. "oauth-m2m.html'>Databricks OAuth</a>."
  178. )
  179. get_databricks_display_html_function()(
  180. f"<b style='color:red;'>{warn_msg}<br></b>"
  181. )
  182. return conf