client_builder.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. import importlib
  2. import inspect
  3. import json
  4. import logging
  5. import os
  6. import sys
  7. import warnings
  8. from dataclasses import dataclass
  9. from typing import Any, Dict, Optional, Tuple
  10. import ray.util.client_connect
  11. from ray._private.ray_constants import (
  12. RAY_ADDRESS_ENVIRONMENT_VARIABLE,
  13. RAY_NAMESPACE_ENVIRONMENT_VARIABLE,
  14. RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE,
  15. )
  16. from ray._private.utils import get_ray_client_dependency_error, split_address
  17. from ray._private.worker import BaseContext, init as ray_driver_init
  18. from ray.job_config import JobConfig
  19. from ray.util.annotations import Deprecated, PublicAPI
  20. logger = logging.getLogger(__name__)
  21. CLIENT_DOCS_URL = (
  22. "https://docs.ray.io/en/latest/cluster/running-applications/"
  23. "job-submission/ray-client.html"
  24. )
  25. @dataclass
  26. @PublicAPI
  27. class ClientContext(BaseContext):
  28. """
  29. Basic context manager for a ClientBuilder connection.
  30. """
  31. dashboard_url: Optional[str]
  32. python_version: str
  33. ray_version: str
  34. ray_commit: str
  35. _num_clients: int
  36. _context_to_restore: Optional[ray.util.client.RayAPIStub]
  37. def __enter__(self) -> "ClientContext":
  38. self._swap_context()
  39. return self
  40. def __exit__(self, *exc) -> None:
  41. self._disconnect_with_context(False)
  42. self._swap_context()
  43. def disconnect(self) -> None:
  44. self._swap_context()
  45. self._disconnect_with_context(True)
  46. self._swap_context()
  47. def _swap_context(self):
  48. if self._context_to_restore is not None:
  49. self._context_to_restore = ray.util.client.ray.set_context(
  50. self._context_to_restore
  51. )
  52. def _disconnect_with_context(self, force_disconnect: bool) -> None:
  53. """
  54. Disconnect Ray. If it's a ray client and created with `allow_multiple`,
  55. it will do nothing. For other cases this either disconnects from the
  56. remote Client Server or shuts the current driver down.
  57. """
  58. if ray.util.client.ray.is_connected():
  59. if ray.util.client.ray.is_default() or force_disconnect:
  60. # This is the only client connection
  61. ray.util.client_connect.disconnect()
  62. elif ray._private.worker.global_worker.node is None:
  63. # Already disconnected.
  64. return
  65. elif ray._private.worker.global_worker.node.is_head():
  66. logger.debug(
  67. "The current Ray Cluster is scoped to this process. "
  68. "Disconnecting is not possible as it will shutdown the "
  69. "cluster."
  70. )
  71. else:
  72. # This is only a driver connected to an existing cluster.
  73. ray.shutdown()
  74. @Deprecated
  75. class ClientBuilder:
  76. """
  77. Builder for a Ray Client connection. This class can be subclassed by
  78. custom builder classes to modify connection behavior to include additional
  79. features or altered semantics. One example is the ``_LocalClientBuilder``.
  80. """
  81. def __init__(self, address: Optional[str]) -> None:
  82. if get_ray_client_dependency_error() is not None:
  83. raise ValueError(
  84. "Ray Client requires pip package `ray[client]`. "
  85. "If you installed the minimal Ray (e.g. `pip install ray`), "
  86. "please reinstall by executing `pip install ray[client]`."
  87. )
  88. self.address = address
  89. self._job_config = JobConfig()
  90. self._remote_init_kwargs = {}
  91. # Whether to allow connections to multiple clusters"
  92. # " (allow_multiple=True).
  93. self._allow_multiple_connections = False
  94. self._credentials = None
  95. self._metadata = None
  96. # Set to False if ClientBuilder is being constructed by internal
  97. # methods
  98. self._deprecation_warn_enabled = True
  99. def env(self, env: Dict[str, Any]) -> "ClientBuilder":
  100. """
  101. Set an environment for the session.
  102. Args:
  103. env (Dict[st, Any]): A runtime environment to use for this
  104. connection. See :ref:`runtime-environments` for what values are
  105. accepted in this dict.
  106. """
  107. self._job_config.set_runtime_env(env)
  108. return self
  109. def namespace(self, namespace: str) -> "ClientBuilder":
  110. """
  111. Sets the namespace for the session.
  112. Args:
  113. namespace: Namespace to use.
  114. """
  115. self._job_config.set_ray_namespace(namespace)
  116. return self
  117. def connect(self) -> ClientContext:
  118. """
  119. Begin a connection to the address passed in via ray.client(...).
  120. Returns:
  121. ClientInfo: Dataclass with information about the setting. This
  122. includes the server's version of Python & Ray as well as the
  123. dashboard_url.
  124. """
  125. if self._deprecation_warn_enabled:
  126. self._client_deprecation_warn()
  127. # Fill runtime env/namespace from environment if not already set.
  128. # Should be done *after* the deprecation warning, since warning will
  129. # check if those values are already set.
  130. self._fill_defaults_from_env()
  131. # If it has already connected to the cluster with allow_multiple=True,
  132. # connect to the default one is not allowed.
  133. # But if it has connected to the default one, connect to other clients
  134. # with allow_multiple=True is allowed
  135. default_cli_connected = ray.util.client.ray.is_connected()
  136. has_cli_connected = ray.util.client.num_connected_contexts() > 0
  137. if (
  138. not self._allow_multiple_connections
  139. and not default_cli_connected
  140. and has_cli_connected
  141. ):
  142. raise ValueError(
  143. "The client has already connected to the cluster "
  144. "with allow_multiple=True. Please set allow_multiple=True"
  145. " to proceed"
  146. )
  147. old_ray_cxt = None
  148. if self._allow_multiple_connections:
  149. old_ray_cxt = ray.util.client.ray.set_context(None)
  150. client_info_dict = ray.util.client_connect.connect(
  151. self.address,
  152. job_config=self._job_config,
  153. _credentials=self._credentials,
  154. ray_init_kwargs=self._remote_init_kwargs,
  155. metadata=self._metadata,
  156. )
  157. dashboard_url = ray.util.client.ray._get_dashboard_url()
  158. cxt = ClientContext(
  159. dashboard_url=dashboard_url,
  160. python_version=client_info_dict["python_version"],
  161. ray_version=client_info_dict["ray_version"],
  162. ray_commit=client_info_dict["ray_commit"],
  163. _num_clients=client_info_dict["num_clients"],
  164. _context_to_restore=ray.util.client.ray.get_context(),
  165. )
  166. if self._allow_multiple_connections:
  167. ray.util.client.ray.set_context(old_ray_cxt)
  168. return cxt
  169. def _fill_defaults_from_env(self):
  170. # Check environment variables for default values
  171. namespace_env_var = os.environ.get(RAY_NAMESPACE_ENVIRONMENT_VARIABLE)
  172. if namespace_env_var and self._job_config.ray_namespace is None:
  173. self.namespace(namespace_env_var)
  174. runtime_env_var = os.environ.get(RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE)
  175. if runtime_env_var and self._job_config.runtime_env is None:
  176. self.env(json.loads(runtime_env_var))
  177. def _init_args(self, **kwargs) -> "ClientBuilder":
  178. """
  179. When a client builder is constructed through ray.init, for example
  180. `ray.init(ray://..., namespace=...)`, all of the
  181. arguments passed into ray.init with non-default values are passed
  182. again into this method. Custom client builders can override this method
  183. to do their own handling/validation of arguments.
  184. """
  185. # Use namespace and runtime_env from ray.init call
  186. if kwargs.get("namespace") is not None:
  187. self.namespace(kwargs["namespace"])
  188. del kwargs["namespace"]
  189. if kwargs.get("runtime_env") is not None:
  190. self.env(kwargs["runtime_env"])
  191. del kwargs["runtime_env"]
  192. if kwargs.get("allow_multiple") is True:
  193. self._allow_multiple_connections = True
  194. del kwargs["allow_multiple"]
  195. if "_credentials" in kwargs.keys():
  196. self._credentials = kwargs["_credentials"]
  197. del kwargs["_credentials"]
  198. if "_metadata" in kwargs.keys():
  199. self._metadata = kwargs["_metadata"]
  200. del kwargs["_metadata"]
  201. if kwargs:
  202. expected_sig = inspect.signature(ray_driver_init)
  203. extra_args = set(kwargs.keys()).difference(expected_sig.parameters.keys())
  204. if len(extra_args) > 0:
  205. raise RuntimeError(
  206. "Got unexpected kwargs: {}".format(", ".join(extra_args))
  207. )
  208. self._remote_init_kwargs = kwargs
  209. unknown = ", ".join(kwargs)
  210. logger.info(
  211. "Passing the following kwargs to ray.init() "
  212. f"on the server: {unknown}"
  213. )
  214. return self
  215. def _client_deprecation_warn(self) -> None:
  216. """
  217. Generates a warning for user's if this ClientBuilder instance was
  218. created directly or through ray.client, instead of relying on
  219. internal methods (ray.init, or auto init)
  220. """
  221. namespace = self._job_config.ray_namespace
  222. runtime_env = self._job_config.runtime_env
  223. replacement_args = []
  224. if self.address:
  225. if isinstance(self, _LocalClientBuilder):
  226. # Address might be set for LocalClientBuilder if ray.client()
  227. # is called while ray_current_cluster is set
  228. # (see _get_builder_from_address). In this case,
  229. # leave off the ray:// so the user attaches the driver directly
  230. replacement_args.append(f'"{self.address}"')
  231. else:
  232. replacement_args.append(f'"ray://{self.address}"')
  233. if namespace:
  234. replacement_args.append(f'namespace="{namespace}"')
  235. if runtime_env:
  236. # Use a placeholder here, since the real runtime_env would be
  237. # difficult to read if formatted in directly
  238. replacement_args.append("runtime_env=<your_runtime_env>")
  239. args_str = ", ".join(replacement_args)
  240. replacement_call = f"ray.init({args_str})"
  241. # Note: stack level is set to 3 since we want the warning to reach the
  242. # call to ray.client(...).connect(). The intervening frames are
  243. # connect() -> client_deprecation_warn() -> warnings.warn()
  244. # https://docs.python.org/3/library/warnings.html#available-functions
  245. warnings.warn(
  246. "Starting a connection through `ray.client` will be deprecated "
  247. "in future ray versions in favor of `ray.init`. See the docs for "
  248. f"more details: {CLIENT_DOCS_URL}. You can replace your call to "
  249. "`ray.client().connect()` with the following:\n"
  250. f" {replacement_call}\n",
  251. DeprecationWarning,
  252. stacklevel=3,
  253. )
  254. class _LocalClientBuilder(ClientBuilder):
  255. def connect(self) -> ClientContext:
  256. """
  257. Begin a connection to the address passed in via ray.client(...)
  258. """
  259. if self._deprecation_warn_enabled:
  260. self._client_deprecation_warn()
  261. # Fill runtime env/namespace from environment if not already set.
  262. # Should be done *after* the deprecation warning, since warning will
  263. # check if those values are already set.
  264. self._fill_defaults_from_env()
  265. connection_dict = ray.init(address=self.address, job_config=self._job_config)
  266. return ClientContext(
  267. dashboard_url=connection_dict["webui_url"],
  268. python_version="{}.{}.{}".format(
  269. sys.version_info[0], sys.version_info[1], sys.version_info[2]
  270. ),
  271. ray_version=ray.__version__,
  272. ray_commit=ray.__commit__,
  273. _num_clients=1,
  274. _context_to_restore=None,
  275. )
  276. def _split_address(address: str) -> Tuple[str, str]:
  277. """
  278. Splits address into a module string (scheme) and an inner_address.
  279. If the scheme is not present, then "ray://" is prepended to the address.
  280. """
  281. if "://" not in address:
  282. address = "ray://" + address
  283. return split_address(address)
  284. def _get_builder_from_address(address: Optional[str]) -> ClientBuilder:
  285. if address == "local":
  286. return _LocalClientBuilder("local")
  287. if address is None:
  288. # NOTE: This is not placed in `Node::get_temp_dir_path`, because
  289. # this file is accessed before the `Node` object is created.
  290. address = ray._private.services.canonicalize_bootstrap_address(address)
  291. return _LocalClientBuilder(address)
  292. module_string, inner_address = _split_address(address)
  293. try:
  294. module = importlib.import_module(module_string)
  295. except Exception as e:
  296. raise RuntimeError(
  297. f"Module: {module_string} does not exist.\n"
  298. f"This module was parsed from Address: {address}"
  299. ) from e
  300. assert "ClientBuilder" in dir(
  301. module
  302. ), f"Module: {module_string} does not have ClientBuilder."
  303. return module.ClientBuilder(inner_address)
  304. @Deprecated
  305. def client(
  306. address: Optional[str] = None, _deprecation_warn_enabled: bool = True
  307. ) -> ClientBuilder:
  308. """
  309. Creates a ClientBuilder based on the provided address. The address can be
  310. of the following forms:
  311. * None: Connects to or creates a local cluster and connects to it.
  312. * ``"local"``: Creates a new cluster locally and connects to it.
  313. * ``"IP:Port"``: Connects to a Ray Client Server at the given address.
  314. * ``"module://inner_address"``: load module.ClientBuilder & pass
  315. inner_address
  316. The _deprecation_warn_enabled flag enables deprecation warnings, and is
  317. for internal use only. Set it to False to suppress client deprecation
  318. warnings.
  319. """
  320. env_address = os.environ.get(RAY_ADDRESS_ENVIRONMENT_VARIABLE)
  321. if env_address and address is None:
  322. logger.debug(
  323. f"Using address ({env_address}) instead of auto-detection "
  324. f"because {RAY_ADDRESS_ENVIRONMENT_VARIABLE} is set."
  325. )
  326. address = env_address
  327. builder = _get_builder_from_address(address)
  328. # Disable client deprecation warn when ray.client is used internally
  329. builder._deprecation_warn_enabled = _deprecation_warn_enabled
  330. return builder