| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376 |
- import importlib
- import inspect
- import json
- import logging
- import os
- import sys
- import warnings
- from dataclasses import dataclass
- from typing import Any, Dict, Optional, Tuple
- import ray.util.client_connect
- from ray._private.ray_constants import (
- RAY_ADDRESS_ENVIRONMENT_VARIABLE,
- RAY_NAMESPACE_ENVIRONMENT_VARIABLE,
- RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE,
- )
- from ray._private.utils import get_ray_client_dependency_error, split_address
- from ray._private.worker import BaseContext, init as ray_driver_init
- from ray.job_config import JobConfig
- from ray.util.annotations import Deprecated, PublicAPI
logger = logging.getLogger(__name__)

# Ray Client documentation page; linked from the ray.client() deprecation
# warning emitted by ClientBuilder._client_deprecation_warn.
CLIENT_DOCS_URL = (
    "https://docs.ray.io/en/latest/cluster/"
    "running-applications/job-submission/ray-client.html"
)
@dataclass
@PublicAPI
class ClientContext(BaseContext):
    """
    Basic context manager for a ClientBuilder connection.

    Entering the context swaps ``_context_to_restore`` in as the active
    Ray Client context; leaving it disconnects (unless forbidden, see
    ``_disconnect_with_context``) and swaps the previous context back.
    """

    dashboard_url: Optional[str]
    python_version: str
    ray_version: str
    ray_commit: str
    _num_clients: int
    # Client context to swap in/out on enter/exit; None means no swapping.
    _context_to_restore: Optional[ray.util.client.RayAPIStub]

    def __enter__(self) -> "ClientContext":
        self._swap_context()
        return self

    def __exit__(self, *exc) -> None:
        # Non-forced: a non-default client connection is left open.
        self._disconnect_with_context(force_disconnect=False)
        self._swap_context()

    def disconnect(self) -> None:
        """Force-disconnect this connection, leaving the caller's context active."""
        # Activate this connection's context, tear it down, then swap the
        # caller's context back in.
        self._swap_context()
        self._disconnect_with_context(force_disconnect=True)
        self._swap_context()

    def _swap_context(self):
        """Exchange the stored context with the currently active one."""
        stored = self._context_to_restore
        if stored is None:
            return
        # set_context returns the context that was active before the call,
        # which becomes the one to restore next time.
        self._context_to_restore = ray.util.client.ray.set_context(stored)

    def _disconnect_with_context(self, force_disconnect: bool) -> None:
        """
        Disconnect Ray.

        * Connected through Ray Client: disconnect only when this is the
          default client context or ``force_disconnect`` is True; otherwise
          do nothing (e.g. a client created with ``allow_multiple``).
        * No driver node attached: already disconnected, do nothing.
        * Driver node is the head node: only log at debug level, since
          disconnecting would shut down the cluster scoped to this process.
        * Otherwise: shut the driver down with ``ray.shutdown()``.
        """
        if ray.util.client.ray.is_connected():
            if ray.util.client.ray.is_default() or force_disconnect:
                # This is the only client connection
                ray.util.client_connect.disconnect()
            return

        driver_node = ray._private.worker.global_worker.node
        if driver_node is None:
            # Already disconnected.
            return

        if driver_node.is_head():
            logger.debug(
                "The current Ray Cluster is scoped to this process. "
                "Disconnecting is not possible as it will shutdown the "
                "cluster."
            )
            return

        # This is only a driver connected to an existing cluster.
        ray.shutdown()
@Deprecated
class ClientBuilder:
    """
    Builder for a Ray Client connection. This class can be subclassed by
    custom builder classes to modify connection behavior to include additional
    features or altered semantics. One example is the ``_LocalClientBuilder``.
    """

    def __init__(self, address: Optional[str]) -> None:
        """
        Args:
            address: Address to connect to. ``ray.client`` passes the part
                after ``module://`` here.

        Raises:
            ValueError: If the ``ray[client]`` dependencies are missing.
        """
        if get_ray_client_dependency_error() is not None:
            raise ValueError(
                "Ray Client requires pip package `ray[client]`. "
                "If you installed the minimal Ray (e.g. `pip install ray`), "
                "please reinstall by executing `pip install ray[client]`."
            )
        self.address = address
        self._job_config = JobConfig()
        # Extra keyword arguments forwarded to ray.init() on the server;
        # populated by _init_args().
        self._remote_init_kwargs = {}
        # Whether to allow connections to multiple clusters
        # (allow_multiple=True).
        self._allow_multiple_connections = False
        self._credentials = None
        self._metadata = None
        # Set to False if ClientBuilder is being constructed by internal
        # methods
        self._deprecation_warn_enabled = True

    def env(self, env: Dict[str, Any]) -> "ClientBuilder":
        """
        Set an environment for the session.

        Args:
            env (Dict[str, Any]): A runtime environment to use for this
                connection. See :ref:`runtime-environments` for what values are
                accepted in this dict.

        Returns:
            This builder, to allow call chaining.
        """
        self._job_config.set_runtime_env(env)
        return self

    def namespace(self, namespace: str) -> "ClientBuilder":
        """
        Sets the namespace for the session.

        Args:
            namespace: Namespace to use.

        Returns:
            This builder, to allow call chaining.
        """
        self._job_config.set_ray_namespace(namespace)
        return self

    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...).

        Returns:
            ClientContext: Dataclass with information about the setting. This
                includes the server's version of Python & Ray as well as the
                dashboard_url.

        Raises:
            ValueError: If a connection was already made with
                allow_multiple=True, there is no default client connection,
                and this builder does not set allow_multiple=True.
        """
        if self._deprecation_warn_enabled:
            self._client_deprecation_warn()
        # Fill runtime env/namespace from environment if not already set.
        # Should be done *after* the deprecation warning, since warning will
        # check if those values are already set.
        self._fill_defaults_from_env()

        # If it has already connected to the cluster with allow_multiple=True,
        # connect to the default one is not allowed.
        # But if it has connected to the default one, connect to other clients
        # with allow_multiple=True is allowed
        default_cli_connected = ray.util.client.ray.is_connected()
        has_cli_connected = ray.util.client.num_connected_contexts() > 0
        if (
            not self._allow_multiple_connections
            and not default_cli_connected
            and has_cli_connected
        ):
            raise ValueError(
                "The client has already connected to the cluster "
                "with allow_multiple=True. Please set allow_multiple=True"
                " to proceed"
            )

        old_ray_cxt = None
        if self._allow_multiple_connections:
            # Connect in a fresh context so the caller's active context is
            # not replaced by this connection.
            old_ray_cxt = ray.util.client.ray.set_context(None)
        try:
            client_info_dict = ray.util.client_connect.connect(
                self.address,
                job_config=self._job_config,
                _credentials=self._credentials,
                ray_init_kwargs=self._remote_init_kwargs,
                metadata=self._metadata,
            )
            dashboard_url = ray.util.client.ray._get_dashboard_url()
            cxt = ClientContext(
                dashboard_url=dashboard_url,
                python_version=client_info_dict["python_version"],
                ray_version=client_info_dict["ray_version"],
                ray_commit=client_info_dict["ray_commit"],
                _num_clients=client_info_dict["num_clients"],
                _context_to_restore=ray.util.client.ray.get_context(),
            )
        finally:
            # Restore the caller's context even if connecting failed, so a
            # failed attempt does not leave the process on an empty context.
            if self._allow_multiple_connections:
                ray.util.client.ray.set_context(old_ray_cxt)
        return cxt

    def _fill_defaults_from_env(self):
        """
        Fill namespace and runtime_env from the environment variables named by
        RAY_NAMESPACE_ENVIRONMENT_VARIABLE and
        RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE, unless already set on the job
        config. The runtime env variable is parsed as JSON.
        """
        namespace_env_var = os.environ.get(RAY_NAMESPACE_ENVIRONMENT_VARIABLE)
        if namespace_env_var and self._job_config.ray_namespace is None:
            self.namespace(namespace_env_var)

        runtime_env_var = os.environ.get(RAY_RUNTIME_ENV_ENVIRONMENT_VARIABLE)
        if runtime_env_var and self._job_config.runtime_env is None:
            self.env(json.loads(runtime_env_var))

    def _init_args(self, **kwargs) -> "ClientBuilder":
        """
        When a client builder is constructed through ray.init, for example
        `ray.init(ray://..., namespace=...)`, all of the
        arguments passed into ray.init with non-default values are passed
        again into this method. Custom client builders can override this method
        to do their own handling/validation of arguments.

        Keyword arguments not consumed here must be parameters of ``ray.init``;
        they are forwarded to ``ray.init()`` on the server.

        Returns:
            This builder, to allow call chaining.

        Raises:
            RuntimeError: If a leftover keyword argument is not a parameter of
                ``ray.init``.
        """
        # Use namespace and runtime_env from ray.init call
        if kwargs.get("namespace") is not None:
            self.namespace(kwargs.pop("namespace"))
        if kwargs.get("runtime_env") is not None:
            self.env(kwargs.pop("runtime_env"))

        # ``allow_multiple`` only configures this client-side builder, so it is
        # always consumed and never forwarded to the server. Previously an
        # explicit ``allow_multiple=False`` stayed in kwargs and was then
        # rejected below as an unexpected kwarg.
        if kwargs.pop("allow_multiple", False) is True:
            self._allow_multiple_connections = True

        if "_credentials" in kwargs:
            self._credentials = kwargs.pop("_credentials")

        if "_metadata" in kwargs:
            self._metadata = kwargs.pop("_metadata")

        if kwargs:
            expected_sig = inspect.signature(ray_driver_init)
            extra_args = set(kwargs).difference(expected_sig.parameters)
            if extra_args:
                # Sorted so the error message is deterministic.
                raise RuntimeError(
                    "Got unexpected kwargs: {}".format(", ".join(sorted(extra_args)))
                )
            self._remote_init_kwargs = kwargs
            unknown = ", ".join(kwargs)
            logger.info(
                "Passing the following kwargs to ray.init() "
                f"on the server: {unknown}"
            )

        return self

    def _client_deprecation_warn(self) -> None:
        """
        Generates a warning for users if this ClientBuilder instance was
        created directly or through ray.client, instead of relying on
        internal methods (ray.init, or auto init)
        """
        namespace = self._job_config.ray_namespace
        runtime_env = self._job_config.runtime_env
        replacement_args = []
        if self.address:
            if isinstance(self, _LocalClientBuilder):
                # Address might be set for LocalClientBuilder if ray.client()
                # is called while ray_current_cluster is set
                # (see _get_builder_from_address). In this case,
                # leave off the ray:// so the user attaches the driver directly
                replacement_args.append(f'"{self.address}"')
            else:
                replacement_args.append(f'"ray://{self.address}"')
        if namespace:
            replacement_args.append(f'namespace="{namespace}"')
        if runtime_env:
            # Use a placeholder here, since the real runtime_env would be
            # difficult to read if formatted in directly
            replacement_args.append("runtime_env=<your_runtime_env>")
        args_str = ", ".join(replacement_args)
        replacement_call = f"ray.init({args_str})"

        # Note: stack level is set to 3 since we want the warning to reach the
        # call to ray.client(...).connect(). The intervening frames are
        # connect() -> client_deprecation_warn() -> warnings.warn()
        # https://docs.python.org/3/library/warnings.html#available-functions
        warnings.warn(
            "Starting a connection through `ray.client` will be deprecated "
            "in future ray versions in favor of `ray.init`. See the docs for "
            f"more details: {CLIENT_DOCS_URL}. You can replace your call to "
            "`ray.client().connect()` with the following:\n"
            f"      {replacement_call}\n",
            DeprecationWarning,
            stacklevel=3,
        )
class _LocalClientBuilder(ClientBuilder):
    """
    ClientBuilder that attaches a regular driver through ``ray.init`` instead
    of opening a Ray Client connection.
    """

    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...)
        """
        if self._deprecation_warn_enabled:
            self._client_deprecation_warn()
        # Environment defaults are applied only after the deprecation warning,
        # because the warning inspects whether namespace/runtime_env were set.
        self._fill_defaults_from_env()

        ray_context = ray.init(address=self.address, job_config=self._job_config)
        major, minor, micro = sys.version_info[:3]
        return ClientContext(
            dashboard_url=ray_context["webui_url"],
            python_version=f"{major}.{minor}.{micro}",
            ray_version=ray.__version__,
            ray_commit=ray.__commit__,
            # The driver runs in this process: a single "client", and no
            # Ray Client context to swap in or out.
            _num_clients=1,
            _context_to_restore=None,
        )
def _split_address(address: str) -> Tuple[str, str]:
    """
    Split ``address`` into ``(module_string, inner_address)``.

    An address without a ``://`` scheme is treated as a Ray Client address,
    i.e. ``"ray://"`` is prepended before splitting.
    """
    qualified = address if "://" in address else f"ray://{address}"
    return split_address(qualified)
def _get_builder_from_address(address: Optional[str]) -> ClientBuilder:
    """
    Choose the ClientBuilder implementation for ``address``.

    * ``"local"``: ``_LocalClientBuilder("local")``.
    * ``None``: ``_LocalClientBuilder`` for the address returned by
      ``ray._private.services.canonicalize_bootstrap_address(None)``.
    * Anything else: split into ``module://inner_address`` (``ray://`` is
      assumed when no scheme is present), import ``module`` and return
      ``module.ClientBuilder(inner_address)``.

    Raises:
        RuntimeError: If the module parsed from the address cannot be imported.
        AssertionError: If the imported module has no ``ClientBuilder``.
    """
    if address == "local":
        return _LocalClientBuilder("local")
    if address is None:
        # No explicit address: let Ray resolve the bootstrap address
        # (presumably auto-detecting an existing cluster) and attach a driver
        # directly rather than going through Ray Client.
        address = ray._private.services.canonicalize_bootstrap_address(address)
        return _LocalClientBuilder(address)
    module_string, inner_address = _split_address(address)
    try:
        module = importlib.import_module(module_string)
    except Exception as e:
        # Broad on purpose: importing runs arbitrary module code, and an empty
        # module name (e.g. "://host") raises ValueError rather than ImportError.
        raise RuntimeError(
            f"Module: {module_string} does not exist.\n"
            f"This module was parsed from Address: {address}"
        ) from e
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    # AssertionError is kept so existing callers that catch it keep working.
    # ``dir()`` is used instead of ``hasattr`` so objects that synthesize
    # attributes on access are not mistaken for having a ClientBuilder.
    if "ClientBuilder" not in dir(module):
        raise AssertionError(f"Module: {module_string} does not have ClientBuilder.")
    return module.ClientBuilder(inner_address)
@Deprecated
def client(
    address: Optional[str] = None, _deprecation_warn_enabled: bool = True
) -> ClientBuilder:
    """
    Creates a ClientBuilder based on the provided address. The address can be
    of the following forms:

        * None: Connects to or creates a local cluster and connects to it.
        * ``"local"``: Creates a new cluster locally and connects to it.
        * ``"IP:Port"``: Connects to a Ray Client Server at the given address.
        * ``"module://inner_address"``: load module.ClientBuilder & pass
            inner_address

    When ``address`` is None and the RAY_ADDRESS_ENVIRONMENT_VARIABLE
    environment variable is set to a non-empty value, that value is used
    as the address.

    The _deprecation_warn_enabled flag enables deprecation warnings, and is
    for internal use only. Set it to False to suppress client deprecation
    warnings.
    """
    if address is None:
        address_from_env = os.environ.get(RAY_ADDRESS_ENVIRONMENT_VARIABLE)
        if address_from_env:
            logger.debug(
                f"Using address ({address_from_env}) instead of auto-detection "
                f"because {RAY_ADDRESS_ENVIRONMENT_VARIABLE} is set."
            )
            address = address_from_env

    new_builder = _get_builder_from_address(address)
    # Internal callers (e.g. ray.init) pass False to silence the warning.
    new_builder._deprecation_warn_enabled = _deprecation_warn_enabled
    return new_builder
|