| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- import logging
- import os
- import threading
- from typing import Any, Dict, List, Optional, Tuple
- import ray._private.ray_constants as ray_constants
- from ray._private.client_mode_hook import (
- _explicitly_disable_client_mode,
- _explicitly_enable_client_mode,
- )
- from ray._private.ray_logging import setup_logger
- from ray._private.utils import check_version_info
- from ray.job_config import JobConfig
- from ray.util.annotations import DeveloperAPI
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
- class _ClientContext:
- def __init__(self):
- from ray.util.client.api import _ClientAPI
- self.api = _ClientAPI()
- self.client_worker = None
- self._server = None
- self._connected_with_init = False
- self._inside_client_test = False
- def connect(
- self,
- conn_str: str,
- job_config: JobConfig = None,
- secure: bool = False,
- metadata: List[Tuple[str, str]] = None,
- connection_retries: int = 3,
- namespace: str = None,
- *,
- ignore_version: bool = False,
- _credentials: Optional["grpc.ChannelCredentials"] = None, # noqa: F821
- ray_init_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Any]:
- """Connect the Ray Client to a server.
- Args:
- conn_str: Connection string, in the form "[host]:port"
- job_config: The job config of the server.
- secure: Whether to use a TLS secured gRPC channel
- metadata: gRPC metadata to send on connect
- connection_retries: number of connection attempts to make
- ignore_version: whether to ignore Python or Ray version mismatches.
- This should only be used for debugging purposes.
- Returns:
- Dictionary of connection info, e.g., {"num_clients": 1}.
- """
- # Delay imports until connect to avoid circular imports.
- from ray.util.client.worker import Worker
- if self.client_worker is not None:
- if self._connected_with_init:
- return
- raise Exception("ray.init() called, but ray client is already connected")
- if not self._inside_client_test:
- # If we're calling a client connect specifically and we're not
- # currently in client mode, ensure we are.
- _explicitly_enable_client_mode()
- if namespace is not None:
- job_config = job_config or JobConfig()
- job_config.set_ray_namespace(namespace)
- logging_level = ray_constants.LOGGER_LEVEL
- logging_format = ray_constants.LOGGER_FORMAT
- if ray_init_kwargs is None:
- ray_init_kwargs = {}
- # NOTE(architkulkarni): env_hook is not supported with Ray Client.
- ray_init_kwargs["_skip_env_hook"] = True
- if ray_init_kwargs.get("logging_level") is not None:
- logging_level = ray_init_kwargs["logging_level"]
- if ray_init_kwargs.get("logging_format") is not None:
- logging_format = ray_init_kwargs["logging_format"]
- setup_logger(logging_level, logging_format)
- try:
- self.client_worker = Worker(
- conn_str,
- secure=secure,
- _credentials=_credentials,
- metadata=metadata,
- connection_retries=connection_retries,
- )
- self.api.worker = self.client_worker
- self.client_worker._server_init(job_config, ray_init_kwargs)
- conn_info = self.client_worker.connection_info()
- self._check_versions(conn_info, ignore_version)
- self._register_serializers()
- return conn_info
- except Exception:
- self.disconnect()
- raise
- def _register_serializers(self):
- """Register the custom serializer addons at the client side.
- The server side should have already registered the serializers via
- regular worker's serialization_context mechanism.
- """
- import ray.util.serialization_addons
- from ray.util.serialization import StandaloneSerializationContext
- ctx = StandaloneSerializationContext()
- ray.util.serialization_addons.apply(ctx)
- def _check_versions(self, conn_info: Dict[str, Any], ignore_version: bool) -> None:
- # conn_info has "python_version" and "ray_version" so it can be used to compare.
- ignore_version = ignore_version or ("RAY_IGNORE_VERSION_MISMATCH" in os.environ)
- check_version_info(
- conn_info,
- "Ray Client",
- raise_on_mismatch=not ignore_version,
- python_version_match_level="minor",
- )
- def disconnect(self):
- """Disconnect the Ray Client."""
- from ray.util.client.api import _ClientAPI
- if self.client_worker is not None:
- self.client_worker.close()
- self.api = _ClientAPI()
- self.client_worker = None
- # remote can be called outside of a connection, which is why it
- # exists on the same API layer as connect() itself.
- def remote(self, *args, **kwargs):
- """remote is the hook stub passed on to replace `ray.remote`.
- This sets up remote functions or actors, as the decorator,
- but does not execute them.
- Args:
- args: opaque arguments
- kwargs: opaque keyword arguments
- """
- return self.api.remote(*args, **kwargs)
- def __getattr__(self, key: str):
- if self.is_connected():
- return getattr(self.api, key)
- elif key in ["is_initialized", "_internal_kv_initialized"]:
- # Client is not connected, thus Ray is not considered initialized.
- return lambda: False
- else:
- raise Exception(
- "Ray Client is not connected. Please connect by calling `ray.init`."
- )
- def is_connected(self) -> bool:
- if self.client_worker is None:
- return False
- return self.client_worker.is_connected()
- def init(self, *args, **kwargs):
- if self._server is not None:
- raise Exception("Trying to start two instances of ray via client")
- import ray.util.client.server.server as ray_client_server
- server_handle, address_info = ray_client_server.init_and_serve(
- "127.0.0.1", 50051, *args, **kwargs
- )
- self._server = server_handle.grpc_server
- self.connect("127.0.0.1:50051")
- self._connected_with_init = True
- return address_info
- def shutdown(self, _exiting_interpreter=False):
- self.disconnect()
- import ray.util.client.server.server as ray_client_server
- if self._server is None:
- return
- ray_client_server.shutdown_with_server(self._server, _exiting_interpreter)
- self._server = None
# Lock that guards `_all_contexts` for thread safety.
_lock = threading.Lock()
# Every connected context is recorded in this set.
_all_contexts = set()
# The default context, used when allow_multiple is not True.
_default_context = _ClientContext()
@DeveloperAPI
class RayAPIStub:
    """Replacement API object that stands in for the `import ray` module.

    Like the ray module, it mostly hands the work to the client context
    selected for the calling thread. Parts of the ray API that are covered
    are routed through here or through the client worker API.
    """

    def __init__(self):
        self._inside_client_test = False
        # The selected context is stored per thread.
        self._cxt = threading.local()
        self._cxt.handler = _default_context

    def get_context(self):
        """Return this thread's context, falling back to the default one."""
        if not hasattr(self._cxt, "handler"):
            # First access from this thread: start on the default context.
            self._cxt.handler = _default_context
        return self._cxt.handler

    def set_context(self, cxt):
        """Select `cxt` for this thread and return the previous context.

        Passing None selects a brand-new `_ClientContext`.
        """
        previous = self.get_context()
        self._cxt.handler = _ClientContext() if cxt is None else cxt
        return previous

    def is_default(self):
        """Return whether this thread's context is the default context."""
        current = self.get_context()
        return current == _default_context

    def connect(self, *args, **kw_args):
        """Connect this thread's context and record it as connected."""
        context = self.get_context()
        context._inside_client_test = self._inside_client_test
        conn_info = self.get_context().connect(*args, **kw_args)
        with _lock:
            _all_contexts.add(self._cxt.handler)
        return conn_info

    def disconnect(self, *args, **kw_args):
        """Disconnect this thread's context.

        From the default context every recorded context is disconnected.
        Client mode is explicitly disabled once no context is recorded.
        """
        global _all_contexts
        with _lock:
            current = self.get_context()
            if _default_context == current:
                for context in _all_contexts:
                    context.disconnect(*args, **kw_args)
                _all_contexts = set()
            else:
                current.disconnect(*args, **kw_args)
            _all_contexts.discard(current)
            if not _all_contexts:
                _explicitly_disable_client_mode()

    def remote(self, *args, **kwargs):
        """Forward `remote` to this thread's context."""
        context = self.get_context()
        return context.remote(*args, **kwargs)

    def __getattr__(self, name):
        # Deliberately calls the context's __getattr__ hook directly rather
        # than going through regular attribute lookup on the context.
        context = self.get_context()
        return context.__getattr__(name)

    def is_connected(self, *args, **kwargs):
        """Forward `is_connected` to this thread's context."""
        context = self.get_context()
        return context.is_connected(*args, **kwargs)

    def init(self, *args, **kwargs):
        """Run `init` on this thread's context and record it as connected."""
        result = self.get_context().init(*args, **kwargs)
        with _lock:
            _all_contexts.add(self._cxt.handler)
        return result

    def shutdown(self, *args, **kwargs):
        """Shut down this thread's context.

        From the default context every recorded context is shut down.
        Client mode is explicitly disabled once no context is recorded.
        """
        global _all_contexts
        with _lock:
            current = self.get_context()
            if _default_context == current:
                for context in _all_contexts:
                    context.shutdown(*args, **kwargs)
                _all_contexts = set()
            else:
                current.shutdown(*args, **kwargs)
            _all_contexts.discard(current)
            if not _all_contexts:
                _explicitly_disable_client_mode()
# The API stub instance this module exposes under the name `ray`.
ray = RayAPIStub()
@DeveloperAPI
def num_connected_contexts():
    """Return the number of client connections active."""
    # Read the size under the lock that guards the set of contexts.
    with _lock:
        active = len(_all_contexts)
    return active
# Someday we might add functions to this module so that someone who runs
# `import ray_client as ray` (using the module itself) instead of
# `from ray_client import ray` (using the API stub) still gets the
# expected functionality. This is the way the ray package worked in the
# past.
#
# This really calls for PEP 562: https://www.python.org/dev/peps/pep-0562/
# But until Python 3.6 is EOL, here we are.
|