| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968 |
- """This file includes the Worker class which sits on the client side.
- It implements the Ray API functions that are forwarded through grpc calls
- to the server.
- """
- import base64
- import json
- import logging
- import os
- import queue
- import tempfile
- import threading
- import time
- import uuid
- import warnings
- from collections import defaultdict
- from concurrent.futures import Future
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
- import grpc
- import ray.cloudpickle as cloudpickle
- import ray.core.generated.ray_client_pb2 as ray_client_pb2
- import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
- from ray._private.ray_constants import (
- DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD,
- env_float,
- env_integer,
- )
- from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
- from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
- # Use cloudpickle's version of pickle for UnpicklingError
- from ray.cloudpickle.compat import pickle
- from ray.exceptions import GetTimeoutError
- from ray.job_config import JobConfig
- from ray.util.client.client_pickler import dumps_from_client, loads_from_server
- from ray.util.client.common import (
- GRPC_OPTIONS,
- GRPC_UNRECOVERABLE_ERRORS,
- INT32_MAX,
- OBJECT_TRANSFER_WARNING_SIZE,
- ClientActorClass,
- ClientActorHandle,
- ClientActorRef,
- ClientObjectRef,
- ClientRemoteFunc,
- ClientStub,
- )
- from ray.util.client.dataclient import DataClient
- from ray.util.client.logsclient import LogstreamClient
- from ray.util.debug import log_once
- if TYPE_CHECKING:
- from ray.actor import ActorClass
- from ray.remote_function import RemoteFunction
# Logger shared by everything in this module.
logger = logging.getLogger(__name__)

# Timeout (seconds) used for the first connection attempt, and the ceiling
# that `backoff` never lets the timeout grow past.
INITIAL_TIMEOUT_SEC = env_integer(
    "RAY_CLIENT_INITIAL_CONNECTION_TIMEOUT_S",
    5,
)
MAX_TIMEOUT_SEC = env_integer(
    "RAY_CLIENT_MAX_CONNECTION_TIMEOUT_S",
    30,
)

# Longest time a single operation may block inside the server. Keeping it
# short lets Ctrl-C on the client take effect without explicitly cancelling
# the server-side operation.
MAX_BLOCKING_OPERATION_TIME_S: float = env_float(
    "RAY_CLIENT_MAX_BLOCKING_OPERATION_TIME_S",
    2.0,
)

# Once the cumulative size (bytes) of outbound task-scheduling messages on
# this connection passes this value, a one-time warning is raised.
MESSAGE_SIZE_THRESHOLD = 10 * 1024 * 1024  # 10 MB

# Ray Design Pattern pages referenced from the task overhead warning.
DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = "https://docs.ray.io/en/latest/ray-core/patterns/too-fine-grained-tasks.html"  # noqa E501
DESIGN_PATTERN_LARGE_OBJECTS_LINK = "https://docs.ray.io/en/latest/ray-core/patterns/closure-capture-large-objects.html"  # noqa E501
def backoff(timeout: int) -> int:
    """Return the next connection timeout.

    The timeout grows by 5 seconds per call and is capped at
    ``MAX_TIMEOUT_SEC``.

    Args:
        timeout: The timeout (seconds) used for the previous attempt.

    Returns:
        The timeout to use for the next attempt.
    """
    return min(timeout + 5, MAX_TIMEOUT_SEC)
- class Worker:
def __init__(
    self,
    conn_str: str = "",
    secure: bool = False,
    metadata: List[Tuple[str, str]] = None,
    connection_retries: int = 3,
    _credentials: Optional[grpc.ChannelCredentials] = None,
):
    """Set up the client-side gRPC worker and connect to the server.

    Args:
        conn_str: ``host:port`` address of the Ray Client server.
        secure: Request an SSL channel. A secure channel is also used when
            the ``RAY_USE_TLS`` environment variable is "1" or "true", or
            when ``_credentials`` is supplied.
        metadata: Extra ``(key, value)`` pairs placed in the gRPC request
            metadata after the generated ``client_id`` entry.
        connection_retries: Number of attempts for the initial connection.
            Values below 1 still result in one attempt.
        _credentials: gRPC channel credentials. Supplying them forces a
            secure channel.

    Raises:
        ConnectionError: If the initial connection cannot be established.
    """
    self._client_id = make_client_id()
    extra_metadata = metadata or []
    self.metadata = [("client_id", self._client_id)] + extra_metadata

    self.channel = None
    self.server = None
    self._conn_state = grpc.ChannelConnectivity.IDLE
    self._converted: Dict[str, ClientStub] = {}

    tls_requested_by_env = os.environ.get("RAY_USE_TLS", "0").lower() in (
        "1",
        "true",
    )
    self._secure = secure or tls_requested_by_env
    self._conn_str = conn_str
    self._connection_retries = connection_retries

    # Explicit credentials always imply a secure channel.
    self._credentials = _credentials
    if _credentials is not None:
        self._secure = True

    # The environment variable, when present, overrides the default grace
    # period.
    grace_period_env = os.environ.get("RAY_CLIENT_RECONNECT_GRACE_PERIOD")
    if grace_period_env is None:
        self._reconnect_grace_period = DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
    else:
        self._reconnect_grace_period = int(grace_period_env)

    # A grace period of 0 turns retries off entirely.
    self._reconnect_enabled = self._reconnect_grace_period != 0

    # Becomes True once the connection is unrecoverable and reconnect
    # attempts must stop.
    self._in_shutdown = False

    # Only flipped to True after the first connection succeeds.
    self._has_connected = False
    self._connect_channel()
    self._has_connected = True

    # Whether Ray has been initialized on the server side.
    self._serverside_ray_initialized = False

    # Starting the streams completes protocol negotiation.
    self.data_client = DataClient(self, self._client_id, self.metadata)
    self.reference_count: Dict[bytes, int] = defaultdict(int)

    self.log_client = LogstreamClient(self, self.metadata)
    self.log_client.set_logstream_level(logging.INFO)

    self.closed = False

    # Running total used to warn when a lot of data has been sent.
    self.total_outbound_message_size_bytes = 0

    # Source of unique IDs for RPCs to the RayletServicer.
    self._req_id_lock = threading.Lock()
    self._req_id = 0

    # ReleaseObject takes a lock, so releases triggered from __del__ (which
    # can run at any point on the main thread) are handed to a background
    # thread through this queue instead of being issued directly.
    self._release_queue = queue.SimpleQueue()
    self._release_thread = threading.Thread(
        target=self._release_server_worker, daemon=True
    )
    self._release_thread.start()
def _connect_channel(self, reconnecting=False) -> None:
    """Open a channel to ``self._conn_str`` and wait for the server.

    Any existing channel is closed first. During the initial connection the
    number of attempts is bounded by ``self._connection_retries`` and the
    per-attempt timeout grows via ``backoff``. When ``reconnecting`` is
    True the attempts continue, without backoff, until the reconnect grace
    period has elapsed.

    Args:
        reconnecting: True when re-establishing a connection after an RPC
            error rather than connecting for the first time.

    Raises:
        ConnectionError: If the grace period expires while reconnecting, or
            if the server never reports ready.
    """
    previous_channel = self.channel
    if previous_channel is not None:
        previous_channel.unsubscribe(self._on_channel_state_change)
        previous_channel.close()

    from ray._private.grpc_utils import init_grpc_channel

    # Work out which credentials (if any) to hand to the channel helper.
    channel_credentials = None
    if self._secure:
        if self._credentials is not None:
            channel_credentials = self._credentials
        elif os.environ.get("RAY_USE_TLS", "0").lower() not in ("1", "true"):
            # No explicit credentials and no TLS env setting: fall back to
            # default SSL credentials (no specific certs).
            channel_credentials = grpc.ssl_channel_credentials()
        # Otherwise leave None: init_grpc_channel handles the TLS env case
        # via load_certs_from_env().

    # The helper also attaches auth interceptors when token auth is enabled.
    self.channel = init_grpc_channel(
        self._conn_str,
        options=GRPC_OPTIONS,
        asynchronous=False,
        credentials=channel_credentials,
    )
    self.channel.subscribe(self._on_channel_state_change)

    # Keep trying until the channel answers with something that looks like
    # a gRPC connection (which may still be a proxy).
    began_at = time.time()
    attempts = 0
    timeout = INITIAL_TIMEOUT_SEC
    service_ready = False
    attempt_limit = max(self._connection_retries, 1)
    while attempts < attempt_limit or reconnecting:
        attempts += 1
        if self._in_shutdown:
            # The user closed the worker before the connection finished.
            break
        if reconnecting and time.time() - began_at > self._reconnect_grace_period:
            self._in_shutdown = True
            raise ConnectionError(
                "Failed to reconnect within the reconnection grace period "
                f"({self._reconnect_grace_period}s)"
            )
        try:
            # Block until gRPC reports the channel ready; a timeout here
            # means we could not connect.
            grpc.channel_ready_future(self.channel).result(timeout=timeout)
            # The HTTP2 channel is up. Wrap it in the RayletDriverStub so
            # unary requests can be issued.
            self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
            service_ready = bool(self.ping_server())
        except grpc.FutureTimeoutError:
            # channel_ready_future already waited for `timeout`, so there is
            # no extra sleep on this path.
            logger.debug(f"Couldn't connect channel in {timeout} seconds, retrying")
        except grpc.RpcError as e:
            logger.debug(
                f"Ray client server unavailable, retrying in {timeout}s..."
            )
            logger.debug(f"Received when checking init: {e.details()}")
            # Ray is not ready yet, wait a timeout.
            time.sleep(timeout)
        else:
            if service_ready:
                break
            # Ray is not ready yet, wait a timeout.
            time.sleep(timeout)

        # Fall through: back off and retry from the top of the loop.
        logger.debug(
            f"Waiting for Ray to become ready on the server, retry in {timeout}s..."
        )
        if not reconnecting:
            # When reconnecting the server is known to exist, so the
            # timeout is left alone to reconnect as quickly as possible.
            timeout = backoff(timeout)

    if service_ready:
        return

    # Leaving the loop without a ready service means the retries are
    # exhausted; surface that to the user.
    self._in_shutdown = True
    if log_once("ray_client_security_groups"):
        warnings.warn(
            "Ray Client connection timed out. Ensure that "
            "the Ray Client port on the head node is reachable "
            "from your local machine. See https://docs.ray.io/en"
            "/latest/cluster/ray-client.html#step-2-check-ports for "
            "more information."
        )
    raise ConnectionError("ray client connection timeout")
def _can_reconnect(self, e: grpc.RpcError) -> bool:
    """Decide whether an RPC failure is worth retrying.

    Args:
        e: The gRPC error raised by the failed call.

    Returns:
        True when a retry is appropriate, False otherwise.
    """
    # Retries are disabled, or the channel is already being shut down.
    if not self._reconnect_enabled or self._in_shutdown:
        return False

    status = e.code()
    if status in GRPC_UNRECOVERABLE_ERRORS:
        # These codes are raised deliberately by the server's application
        # logic, so retrying cannot help.
        return False

    # The client itself built a bad request (for example, passing "None"
    # where a gRPC message was expected); resending it would fail again.
    bad_request = (
        status == grpc.StatusCode.INTERNAL
        and e.details() == "Exception serializing request!"
    )
    if bad_request:
        return False

    # Anything else is treated as recoverable.
    return True
- def _call_stub(self, stub_name: str, *args, **kwargs) -> Any:
- """
- Calls the stub specified by stub_name (Schedule, WaitObject, etc...).
- If a recoverable error occurrs while calling the stub, attempts to
- retry the RPC.
- """
- while not self._in_shutdown:
- try:
- return getattr(self.server, stub_name)(*args, **kwargs)
- except grpc.RpcError as e:
- if self._can_reconnect(e):
- time.sleep(0.5)
- continue
- raise
- except ValueError:
- # Trying to use the stub on a cancelled channel will raise
- # ValueError. This should only happen when the data client
- # is attempting to reset the connection -- sleep and try
- # again.
- time.sleep(0.5)
- continue
- raise ConnectionError("Client is shutting down.")
def _get_object_iterator(
    self, req: ray_client_pb2.GetRequest, *args, **kwargs
) -> Any:
    """Stream the chunks of a GetObject response, resuming after failures.

    This is a generator. If a recoverable error interrupts the stream, the
    request is reissued starting at the first chunk that has not been
    yielded yet, so callers see each chunk exactly once and in order.

    Args:
        req: The get request. Its ``start_chunk_id`` is overwritten before
            every (re)issue.
        *args: Positional arguments forwarded to the ``GetObject`` stub.
        **kwargs: Keyword arguments forwarded to the ``GetObject`` stub.

    Yields:
        The response chunks, in increasing ``chunk_id`` order.

    Raises:
        RuntimeError: If the server skips ahead and sends a chunk other
            than the next expected one.
        grpc.RpcError: If the stream fails with a non-recoverable error.
        ConnectionError: If the client is shutting down.
    """
    last_seen_chunk = -1
    while not self._in_shutdown:
        # If we disconnect partway through, restart the get request
        # at the first chunk we haven't seen
        req.start_chunk_id = last_seen_chunk + 1
        try:
            for chunk in self.server.GetObject(req, *args, **kwargs):
                if chunk.chunk_id <= last_seen_chunk:
                    # Ignore repeat chunks
                    logger.debug(
                        f"Received a repeated chunk {chunk.chunk_id} "
                        f"from request {req.req_id}."
                    )
                    continue
                if last_seen_chunk + 1 != chunk.chunk_id:
                    # `last_seen_chunk` is a local, not an attribute; the
                    # previous `self.last_seen_chunk` raised AttributeError
                    # here and hid the real problem.
                    raise RuntimeError(
                        f"Received chunk {chunk.chunk_id} when we expected "
                        f"{last_seen_chunk + 1}"
                    )
                last_seen_chunk = chunk.chunk_id
                yield chunk
                if last_seen_chunk == chunk.total_chunks - 1:
                    # We've yielded the last chunk, exit early
                    return
            return
        except grpc.RpcError as e:
            if self._can_reconnect(e):
                time.sleep(0.5)
                continue
            raise
        except ValueError:
            # Trying to use the stub on a cancelled channel will raise
            # ValueError. This should only happen when the data client
            # is attempting to reset the connection -- sleep and try
            # again.
            time.sleep(0.5)
            continue
    raise ConnectionError("Client is shutting down.")
- def _add_ids_to_metadata(self, metadata: Any):
- """
- Adds a unique req_id and the current thread's identifier to the
- metadata. These values are useful for preventing mutating operations
- from being replayed on the server side in the event that the client
- must retry a requsest.
- Args:
- metadata: the gRPC metadata to append the IDs to
- """
- if not self._reconnect_enabled:
- # IDs not needed if the reconnects are disabled
- return metadata
- thread_id = str(threading.get_ident())
- with self._req_id_lock:
- self._req_id += 1
- if self._req_id > INT32_MAX:
- self._req_id = 1
- req_id = str(self._req_id)
- return metadata + [("thread_id", thread_id), ("req_id", req_id)]
def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
    """Record the latest channel connectivity state.

    Subscribed to the channel in ``_connect_channel``.

    Args:
        conn_state: The connectivity state reported by the channel.
    """
    self._conn_state = conn_state
    logger.debug(f"client gRPC channel state change: {conn_state}")
def connection_info(self):
    """Fetch connection details from the server through the data client.

    Returns:
        A dict with the keys ``num_clients``, ``python_version``,
        ``ray_version`` and ``ray_commit``.

    Raises:
        Exception: The result of ``decode_exception`` when the request
            fails with a gRPC error.
    """
    try:
        info = self.data_client.ConnectionInfo()
    except grpc.RpcError as e:
        raise decode_exception(e)
    reported_fields = ("num_clients", "python_version", "ray_version", "ray_commit")
    return {field: getattr(info, field) for field in reported_fields}
def register_callback(
    self,
    ref: ClientObjectRef,
    callback: Callable[[ray_client_pb2.DataResponse], None],
) -> None:
    """Register ``callback`` for an asynchronous get of ``ref``.

    Args:
        ref: The object reference to fetch.
        callback: Passed to the data client together with the get request.
    """
    self.data_client.RegisterGetCallback(
        ray_client_pb2.GetRequest(ids=[ref.id], asynchronous=True),
        callback,
    )
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
    """Fetch the value(s) behind one object ref or a list of refs.

    The wait is split into server calls of at most
    ``MAX_BLOCKING_OPERATION_TIME_S`` each, retried until the data arrives
    or the overall timeout has passed.

    Args:
        vals: A ``ClientObjectRef`` or a list of them.
        timeout: Overall time limit in seconds; None waits indefinitely.

    Returns:
        A single value when ``vals`` is one ref, otherwise a list of values
        (an empty list for an empty input list).

    Raises:
        GetTimeoutError: If ``timeout`` elapses before the data is ready.
        Exception: If ``vals`` has an unsupported type, or the response
            size does not match the request.
    """
    unwrap_single = False
    if isinstance(vals, list):
        if not vals:
            return []
        to_get = vals
    elif isinstance(vals, ClientObjectRef):
        to_get = [vals]
        unwrap_single = True
    else:
        raise Exception(
            "Can't get something that's not a "
            "list of IDs or just an ID: %s" % type(vals)
        )

    deadline = None if timeout is None else time.monotonic() + timeout

    while True:
        op_timeout = MAX_BLOCKING_OPERATION_TIME_S
        if deadline:
            # Never ask the server for less than a millisecond.
            time_left = max(deadline - time.monotonic(), 0.001)
            op_timeout = min(MAX_BLOCKING_OPERATION_TIME_S, time_left)
        try:
            res = self._get(to_get, op_timeout)
        except GetTimeoutError:
            if deadline and time.monotonic() > deadline:
                raise
            logger.debug("Internal retry for get {}".format(to_get))
        else:
            break

    if len(res) != len(to_get):
        raise Exception(
            "Mismatched number of items in request ({}) and response ({})".format(
                len(to_get), len(res)
            )
        )
    return res[0] if unwrap_single else res
def _get(self, ref: List[ClientObjectRef], timeout: float):
    """Issue one get request and reassemble the streamed response.

    Args:
        ref: The object refs to fetch.
        timeout: Timeout placed in the get request.

    Returns:
        The result of ``loads_from_server`` on the concatenated chunk data.

    Raises:
        Exception: The error carried by an invalid chunk, or the result of
            ``decode_exception`` when the stream fails with a gRPC error.
    """
    request = ray_client_pb2.GetRequest(
        ids=[object_ref.id for object_ref in ref], timeout=timeout
    )
    payload = bytearray()
    try:
        for chunk in self._get_object_iterator(request, metadata=self.metadata):
            if not chunk.valid:
                try:
                    server_error = cloudpickle.loads(chunk.error)
                except (pickle.UnpicklingError, TypeError):
                    logger.exception("Failed to deserialize {}".format(chunk.error))
                    raise
                raise server_error

            too_large = chunk.total_size > OBJECT_TRANSFER_WARNING_SIZE
            if too_large and log_once("client_object_transfer_size_warning"):
                gib = chunk.total_size / 2**30
                warnings.warn(
                    "Ray Client is attempting to retrieve a "
                    f"{gib:.2f} GiB object over the network, which may "
                    "be slow. Consider serializing the object to a file "
                    "and using S3 or rsync instead.",
                    UserWarning,
                    stacklevel=5,
                )
            payload.extend(chunk.data)
    except grpc.RpcError as e:
        raise decode_exception(e)
    return loads_from_server(payload)
def put(
    self,
    val,
    *,
    client_ref_id: bytes = None,
    _owner: Optional[ClientActorHandle] = None,
):
    """Serialize ``val`` and store it on the server.

    Args:
        val: The value to store. Must not itself be a ``ClientObjectRef``.
        client_ref_id: Forwarded to ``_put_pickled``.
        _owner: Forwarded to ``_put_pickled`` as the owner.

    Returns:
        The value returned by ``_put_pickled``.

    Raises:
        TypeError: If ``val`` is a ``ClientObjectRef``.
    """
    if isinstance(val, ClientObjectRef):
        raise TypeError(
            "Calling 'put' on an ObjectRef is not allowed "
            "(similarly, returning an ObjectRef from a remote "
            "function is not allowed). If you really want to "
            "do this, you can wrap the ObjectRef in a list and "
            "call 'put' on it (or return it)."
        )
    pickled = dumps_from_client(val, self._client_id)
    return self._put_pickled(pickled, client_ref_id, _owner)
def _put_pickled(
    self, data, client_ref_id: bytes, owner: Optional[ClientActorHandle] = None
):
    """Send already-serialized data to the server as a put request.

    Args:
        data: The serialized payload.
        client_ref_id: Copied into the request when not None.
        owner: When given, its actor ref ID is set as the request's owner.

    Returns:
        A ``ClientObjectRef`` wrapping the ID from the response.

    Raises:
        Exception: The error carried by the response when it is marked
            invalid.
        pickle.UnpicklingError, TypeError: If that error cannot be
            deserialized.
    """
    req = ray_client_pb2.PutRequest(data=data)
    if client_ref_id is not None:
        req.client_ref_id = client_ref_id
    if owner is not None:
        req.owner_id = owner.actor_ref.id
    resp = self.data_client.PutObject(req)
    if not resp.valid:
        # Only the deserialization is guarded. Raising inside the `try`
        # would report a server-side TypeError/UnpicklingError as a
        # deserialization failure in the log.
        try:
            server_error = cloudpickle.loads(resp.error)
        except (pickle.UnpicklingError, TypeError):
            logger.exception("Failed to deserialize {}".format(resp.error))
            raise
        raise server_error
    return ClientObjectRef(resp.id)
- # TODO(ekl) respect MAX_BLOCKING_OPERATION_TIME_S for wait too
def wait(
    self,
    object_refs: List[ClientObjectRef],
    *,
    num_returns: int = 1,
    timeout: float = None,
    fetch_local: bool = True,
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
    """Ask the server which of ``object_refs`` are ready.

    Args:
        object_refs: The refs to wait on; must be a list of
            ``ClientObjectRef``.
        num_returns: Placed in the wait request.
        timeout: Placed in the wait request; None is sent as -1.
        fetch_local: Accepted but not used by this method.

    Returns:
        A ``(ready, remaining)`` tuple of ``ClientObjectRef`` lists.

    Raises:
        TypeError: If ``object_refs`` is not a list of ``ClientObjectRef``.
        Exception: If the server marks the response invalid.
    """
    if not isinstance(object_refs, list):
        raise TypeError(
            f"wait() expected a list of ClientObjectRef, got {type(object_refs)}"
        )
    for ref in object_refs:
        if isinstance(ref, ClientObjectRef):
            continue
        raise TypeError(
            "wait() expected a list of ClientObjectRef, "
            f"got list containing {type(ref)}"
        )

    req = ray_client_pb2.WaitRequest(
        object_ids=[object_ref.id for object_ref in object_refs],
        num_returns=num_returns,
        timeout=-1 if timeout is None else timeout,
        client_id=self._client_id,
    )
    resp = self._call_stub("WaitObject", req, metadata=self.metadata)
    if not resp.valid:
        # TODO(ameer): improve error/exceptions messages.
        raise Exception("Client Wait request failed. Reference invalid?")

    ready = [ClientObjectRef(raw_id) for raw_id in resp.ready_object_ids]
    remaining = [ClientObjectRef(raw_id) for raw_id in resp.remaining_object_ids]
    return (ready, remaining)
def call_remote(self, instance, *args, **kwargs) -> List[Future]:
    """Schedule a remote call for ``instance`` with the given arguments.

    Args:
        instance: Object providing ``_prepare_client_task()`` and
            ``_num_returns()``.
        *args: Positional arguments for the remote call.
        **kwargs: Keyword arguments for the remote call.

    Returns:
        The futures produced by ``_call_schedule_for_task``.

    Raises:
        RuntimeError: If the call is configured with
            ``num_returns="streaming"``.
    """
    task = instance._prepare_client_task()
    # The payload is the serialized (args, kwargs) tuple.
    task.data = dumps_from_client((args, kwargs), self._client_id)

    num_returns = instance._num_returns()
    if num_returns == "streaming":
        raise RuntimeError(
            'Streaming actor methods (num_returns="streaming") '
            "are not currently supported when using Ray Client."
        )
    if num_returns == "dynamic":
        # Dynamic returns are passed on as -1.
        num_returns = -1
    return self._call_schedule_for_task(task, num_returns)
def _call_schedule_for_task(
    self, task: ray_client_pb2.ClientTask, num_returns: Optional[int]
) -> List[Future]:
    """Submit ``task`` through the data client and return ID futures.

    Every returned future is resolved by the schedule callback: with a raw
    return ID on success, or with an exception on any failure.

    Args:
        task: The task to schedule. Its ``client_id`` is set here.
        num_returns: Expected number of return values. None means 1, and
            -1 produces a single future.

    Returns:
        One future per expected return ID.
    """
    logger.debug(f"Scheduling task {task.name} {task.type} {task.payload_id}")
    task.client_id = self._client_id
    if num_returns is None:
        num_returns = 1

    # -1 still yields exactly one reference.
    num_return_refs = 1 if num_returns == -1 else num_returns
    id_futures = [Future() for _ in range(num_return_refs)]

    def fail_all(exc: Exception) -> None:
        """Resolve every pending ID future with ``exc``."""
        for future in id_futures:
            future.set_exception(exc)

    def populate_ids(resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
        """Callback run with the schedule response or a failure."""
        if isinstance(resp, Exception):
            if isinstance(resp, grpc.RpcError):
                resp = decode_exception(resp)
            fail_all(resp)
            return

        ticket = resp.task_ticket
        if not ticket.valid:
            try:
                ex = cloudpickle.loads(ticket.error)
            except (pickle.UnpicklingError, TypeError) as e_new:
                ex = e_new
            fail_all(ex)
            return

        if len(ticket.return_ids) != num_return_refs:
            # Fail every future. Pairing futures with return_ids via zip()
            # would stop at the shorter of the two and leave the remaining
            # futures unresolved, blocking their waiters forever.
            fail_all(
                ValueError(
                    f"Expected {num_return_refs} returns but received "
                    f"{len(ticket.return_ids)}"
                )
            )
            return

        for future, raw_id in zip(id_futures, ticket.return_ids):
            future.set_result(raw_id)

    self.data_client.Schedule(task, populate_ids)

    self.total_outbound_message_size_bytes += task.ByteSize()
    if (
        self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD
        and log_once("client_communication_overhead_warning")
    ):
        warnings.warn(
            "More than 10MB of messages have been created to schedule "
            "tasks on the server. This can be slow on Ray Client due to "
            "communication overhead over the network. If you're running "
            "many fine-grained tasks, consider running them inside a "
            'single remote function. See the section on "Too '
            'fine-grained tasks" in the Ray Design Patterns document for '
            f"more details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}. If "
            "your functions frequently use large objects, consider "
            "storing the objects remotely with ray.put. An example of "
            'this is shown in the "Closure capture of large / '
            'unserializable object" section of the Ray Design Patterns '
            "document, available here: "
            f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}",
            UserWarning,
        )
    return id_futures
def call_release(self, id: bytes) -> None:
    """Drop one client-side reference to ``id``.

    When the count reaches zero the object is queued for release on the
    server and its counter entry is removed. Does nothing once the worker
    is closed.

    Args:
        id: The object ID whose reference count is decremented.
    """
    if self.closed:
        return
    remaining = self.reference_count[id] - 1
    self.reference_count[id] = remaining
    if remaining == 0:
        self._release_server(id)
        del self.reference_count[id]
- def _release_server(self, id: bytes) -> None:
- if self.data_client is not None:
- logger.debug(f"Put {id.hex()} to release queue")
- self._release_queue.put(id)
def _release_server_worker(self):
    """Body of the background thread that releases objects on the server.

    Loops until the worker is closed or a ``None`` sentinel is taken from
    the release queue.
    """
    while not self.closed:
        try:
            object_id = self._release_queue.get(timeout=1)
        except queue.Empty:
            # Nothing queued; loop round to re-check `closed`.
            continue

        if object_id is None:  # Sentinel value for shutdown
            logger.debug("Received sentinel, will stop release thread.")
            break
        if self.data_client is None:
            continue

        logger.debug(f"Releasing {object_id.hex()}")
        try:
            self.data_client.ReleaseObject(
                ray_client_pb2.ReleaseRequest(ids=[object_id])
            )
        except Exception as e:
            # Log and carry on so that one failed release does not crash
            # the release thread.
            logger.warning(
                f"Failed to release object {object_id.hex()}: {e}. "
                "This is expected if the connection is closed."
            )
    logger.debug("Release thread finished.")
def call_retain(self, id: bytes) -> None:
    """Add one client-side reference to ``id``.

    Args:
        id: The object ID whose reference count is incremented.
    """
    logger.debug(f"Retaining {id.hex()}")
    self.reference_count[id] = self.reference_count[id] + 1
def close(self):
    """Shut the worker down.

    Stops the release thread (waiting up to 5 seconds for it), closes the
    data and log clients, and closes the channel.
    """
    self._in_shutdown = True

    # The None sentinel tells the release thread to exit.
    self._release_queue.put(None)
    join_timeout_s = 5
    self._release_thread.join(timeout=join_timeout_s)
    if self._release_thread.is_alive():
        logger.warning(f"The release thread failed to join in {join_timeout_s}s.")

    self.closed = True
    self.data_client.close()
    self.log_client.close()
    self.server = None

    channel = self.channel
    if channel:
        channel.close()
        self.channel = None
def get_actor(
    self, name: str, namespace: Optional[str] = None
) -> ClientActorHandle:
    """Look up a named actor and return a handle to it.

    Args:
        name: The actor's name.
        namespace: The namespace to search; None is sent as an empty
            string.

    Returns:
        A handle backed by a weak actor reference.

    Raises:
        ValueError: If the resolved actor ID is nil.
    """
    lookup = ray_client_pb2.ClientTask()
    lookup.type = ray_client_pb2.ClientTask.NAMED_ACTOR
    lookup.name = name
    lookup.namespace = namespace or ""
    # The task payload is an empty (args, kwargs) pair.
    lookup.data = dumps_from_client(([], {}), self._client_id)

    futures = self._call_schedule_for_task(lookup, 1)
    assert len(futures) == 1
    actor_ref = ClientActorRef(futures[0], weak_ref=True)
    handle = ClientActorHandle(actor_ref)

    # `is_nil()` blocks until the underlying ID is resolved. That matters
    # because `get_actor` is often used just to check whether an actor
    # exists.
    if handle.actor_ref.is_nil():
        raise ValueError(f"ActorID for {name} is empty")
    return handle
def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None:
    """Send a terminate request for ``actor``.

    Args:
        actor: Handle of the actor to terminate.
        no_restart: Copied into the terminate request.

    Raises:
        ValueError: If ``actor`` is not a ``ClientActorHandle``.
        Exception: The result of ``decode_exception`` when the request
            fails with a gRPC error.
    """
    if not isinstance(actor, ClientActorHandle):
        raise ValueError(
            "ray.kill() only supported for actors. Got: {}.".format(type(actor))
        )
    actor_request = ray_client_pb2.TerminateRequest.ActorTerminate()
    actor_request.id = actor.actor_ref.id
    actor_request.no_restart = no_restart

    request = ray_client_pb2.TerminateRequest(actor=actor_request)
    request.client_id = self._client_id
    try:
        self.data_client.Terminate(request)
    except grpc.RpcError as e:
        raise decode_exception(e)
def terminate_task(
    self, obj: ClientObjectRef, force: bool, recursive: bool
) -> None:
    """Send a terminate request for the task that produces ``obj``.

    Args:
        obj: Object ref identifying the task.
        force: Copied into the terminate request.
        recursive: Copied into the terminate request.

    Raises:
        TypeError: If ``obj`` is not a ``ClientObjectRef``.
        Exception: The result of ``decode_exception`` when the request
            fails with a gRPC error.
    """
    if not isinstance(obj, ClientObjectRef):
        raise TypeError(
            "ray.cancel() only supported for non-actor object refs. "
            f"Got: {type(obj)}."
        )
    task_request = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
    task_request.id = obj.id
    task_request.force = force
    task_request.recursive = recursive

    request = ray_client_pb2.TerminateRequest(task_object=task_request)
    request.client_id = self._client_id
    try:
        self.data_client.Terminate(request)
    except grpc.RpcError as e:
        raise decode_exception(e)
def get_cluster_info(
    self,
    req_type: ray_client_pb2.ClusterInfoType.TypeEnum,
    timeout: Optional[float] = None,
):
    """Send a ClusterInfo request and unpack the response.

    Args:
        req_type: The kind of cluster information to request.
        timeout: Timeout passed to the RPC.

    Returns:
        A dict for a resource table, the runtime context message for a
        runtime-context response, otherwise the decoded JSON payload.
    """
    req = ray_client_pb2.ClusterInfoRequest()
    req.type = req_type
    resp = self.server.ClusterInfo(req, timeout=timeout, metadata=self.metadata)

    response_kind = resp.WhichOneof("response_type")
    if response_kind == "resource_table":
        # Convert the proto map into a plain Python dict.
        return dict(resp.resource_table.table)
    if response_kind == "runtime_context":
        return resp.runtime_context
    return json.loads(resp.json)
def internal_kv_get(self, key: bytes, namespace: Optional[bytes]) -> bytes:
    """Read ``key`` from the internal KV store.

    Args:
        key: The key to read.
        namespace: The namespace placed in the request.

    Returns:
        The stored value, or None when the response carries no value (the
        key does not exist in the KV).

    Raises:
        Exception: The result of ``decode_exception`` on a gRPC error.
    """
    request = ray_client_pb2.KVGetRequest(key=key, namespace=namespace)
    try:
        response = self._call_stub("KVGet", request, metadata=self.metadata)
    except grpc.RpcError as e:
        raise decode_exception(e)
    return response.value if response.HasField("value") else None
def internal_kv_exists(self, key: bytes, namespace: Optional[bytes]) -> bool:
    """Report whether ``key`` exists in the internal KV store.

    Args:
        key: The key to check.
        namespace: The namespace placed in the request.

    Returns:
        The ``exists`` field of the server's response.

    Raises:
        Exception: The result of ``decode_exception`` on a gRPC error.
    """
    request = ray_client_pb2.KVExistsRequest(key=key, namespace=namespace)
    try:
        response = self._call_stub("KVExists", request, metadata=self.metadata)
    except grpc.RpcError as e:
        raise decode_exception(e)
    return response.exists
def internal_kv_put(
    self, key: bytes, value: bytes, overwrite: bool, namespace: Optional[bytes]
) -> bool:
    """Write ``value`` under ``key`` in the internal KV store.

    The call carries request/thread IDs in its metadata (see
    ``_add_ids_to_metadata``).

    Args:
        key: The key to write.
        value: The value to store.
        overwrite: Placed in the request.
        namespace: The namespace placed in the request.

    Returns:
        The ``already_exists`` field of the server's response.

    Raises:
        Exception: The result of ``decode_exception`` on a gRPC error.
    """
    request = ray_client_pb2.KVPutRequest(
        key=key, value=value, overwrite=overwrite, namespace=namespace
    )
    tagged_metadata = self._add_ids_to_metadata(self.metadata)
    try:
        response = self._call_stub("KVPut", request, metadata=tagged_metadata)
    except grpc.RpcError as e:
        raise decode_exception(e)
    return response.already_exists
def internal_kv_del(
    self, key: bytes, del_by_prefix: bool, namespace: Optional[bytes]
) -> int:
    """Delete ``key`` from the internal KV store.

    The call carries request/thread IDs in its metadata (see
    ``_add_ids_to_metadata``).

    Args:
        key: The key to delete.
        del_by_prefix: Placed in the request.
        namespace: The namespace placed in the request.

    Returns:
        The ``deleted_num`` field of the server's response.

    Raises:
        Exception: The result of ``decode_exception`` on a gRPC error.
    """
    request = ray_client_pb2.KVDelRequest(
        key=key, del_by_prefix=del_by_prefix, namespace=namespace
    )
    tagged_metadata = self._add_ids_to_metadata(self.metadata)
    try:
        response = self._call_stub("KVDel", request, metadata=tagged_metadata)
    except grpc.RpcError as e:
        raise decode_exception(e)
    return response.deleted_num
def internal_kv_list(
    self, prefix: bytes, namespace: Optional[bytes]
) -> List[bytes]:
    """List the internal KV keys for ``prefix``.

    Args:
        prefix: The prefix placed in the request.
        namespace: The namespace placed in the request.

    Returns:
        The ``keys`` field of the server's response.

    Raises:
        Exception: The result of ``decode_exception`` on a gRPC error.
    """
    request = ray_client_pb2.KVListRequest(prefix=prefix, namespace=namespace)
    try:
        response = self._call_stub("KVList", request, metadata=self.metadata)
        return response.keys
    except grpc.RpcError as e:
        raise decode_exception(e)
def pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None:
    """Send a PinRuntimeEnvURI request to the server.

    Args:
        uri: The runtime environment URI placed in the request.
        expiration_s: The expiration (seconds) placed in the request.
    """
    request = ray_client_pb2.ClientPinRuntimeEnvURIRequest(
        uri=uri,
        expiration_s=expiration_s,
    )
    self._call_stub("PinRuntimeEnvURI", request, metadata=self.metadata)
def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
    """Fetch the named actors from the server.

    Args:
        all_namespaces: Placed in the request.

    Returns:
        The decoded ``actors_json`` payload of the response.
    """
    request = ray_client_pb2.ClientListNamedActorsRequest(
        all_namespaces=all_namespaces
    )
    response = self.data_client.ListNamedActors(request)
    return json.loads(response.actors_json)
def is_initialized(self) -> bool:
    """Report whether Ray is initialized on the server.

    Returns:
        False when the worker is not connected or has no server stub,
        otherwise the (cached) answer from the server.
    """
    reachable = self.is_connected() and self.server is not None
    if not reachable:
        return False

    if not self._serverside_ray_initialized:
        # The server is only asked until it answers positively, to avoid
        # an RPC on every call. Caching is safe because Ray only
        # 'un-initializes' on the server when the Client connection is
        # torn down.
        self._serverside_ray_initialized = self.get_cluster_info(
            ray_client_pb2.ClusterInfoType.IS_INITIALIZED
        )
    return self._serverside_ray_initialized
def ping_server(self, timeout=None) -> bool:
    """Simple health check.

    Sends a PING cluster-info request and checks that the server provides
    an actual response.

    Args:
        timeout: Timeout forwarded to ``get_cluster_info``.

    Returns:
        True when the server returned a non-None result, False when there
        is no server stub or the result was None.
    """
    if self.server is None:
        return False
    logger.debug("Pinging server.")
    reply = self.get_cluster_info(
        ray_client_pb2.ClusterInfoType.PING, timeout=timeout
    )
    return reply is not None
def is_connected(self) -> bool:
    """Report whether the initial connection succeeded and is not shut down."""
    if self._in_shutdown:
        return False
    return self._has_connected
def _server_init(
    self, job_config: JobConfig, ray_init_kwargs: Optional[Dict[str, Any]] = None
):
    """Initialize the server.

    When a job config is given, its runtime env is passed through the
    py_modules and working_dir upload helpers before the config is
    pickled and sent in the Init request.

    Args:
        job_config: The job configuration, or None to send none.
        ray_init_kwargs: Keyword arguments sent JSON-encoded in the Init
            request; None is treated as an empty dict.

    Raises:
        ConnectionAbortedError: If the server reports that initialization
            failed.
        Exception: The result of ``decode_exception`` on a gRPC error.
    """
    init_kwargs = {} if ray_init_kwargs is None else ray_init_kwargs
    try:
        serialized_job_config = None
        if job_config is not None:
            with tempfile.TemporaryDirectory() as scratch_dir:
                from ray._private.ray_constants import (
                    RAY_RUNTIME_ENV_IGNORE_GITIGNORE,
                )

                # .gitignore files are respected by default; setting the
                # environment variable to "1" turns that off.
                respect_gitignore = (
                    os.environ.get(RAY_RUNTIME_ENV_IGNORE_GITIGNORE, "0") != "1"
                )

                runtime_env = job_config.runtime_env or {}
                for upload_if_needed in (
                    upload_py_modules_if_needed,
                    upload_working_dir_if_needed,
                ):
                    runtime_env = upload_if_needed(
                        runtime_env,
                        scratch_dir=scratch_dir,
                        include_gitignore=respect_gitignore,
                        logger=logger,
                    )

                # "excludes" is not relevant once the uploads are done.
                runtime_env.pop("excludes", None)
                job_config.set_runtime_env(runtime_env, validate=True)

            serialized_job_config = pickle.dumps(job_config)

        init_request = ray_client_pb2.InitRequest(
            job_config=serialized_job_config,
            ray_init_kwargs=json.dumps(init_kwargs),
            reconnect_grace_period=self._reconnect_grace_period,
        )
        response = self.data_client.Init(init_request)
        if not response.ok:
            raise ConnectionAbortedError(
                f"Initialization failure from server:\n{response.msg}"
            )
    except grpc.RpcError as e:
        raise decode_exception(e)
def _convert_actor(self, actor: "ActorClass") -> str:
    """Register a ClientActorClass for ``actor`` and return its key.

    Args:
        actor: The actor class to convert.

    Returns:
        The hex UUID under which the converted class is stored.
    """
    key = uuid.uuid4().hex
    self._converted[key] = ClientActorClass(
        actor.__ray_metadata__.modified_class,
        options=actor._default_options,
    )
    return key
def _convert_function(self, func: "RemoteFunction") -> str:
    """Register a ClientRemoteFunc for ``func`` and return its key.

    Args:
        func: The remote function to convert.

    Returns:
        The hex UUID under which the converted function is stored.
    """
    key = uuid.uuid4().hex
    client_func = ClientRemoteFunc(func._function, options=func._default_options)
    self._converted[key] = client_func
    return key
- def _get_converted(self, key: str) -> "ClientStub":
- """Given a UUID, return the converted object"""
- return self._converted[key]
- def _converted_key_exists(self, key: str) -> bool:
- """Check if a key UUID is present in the store of converted objects."""
- return key in self._converted
def _dumps_from_client(self, val) -> bytes:
    """Serialize ``val`` with ``dumps_from_client`` using this client's ID."""
    client_id = self._client_id
    return dumps_from_client(val, client_id)
def make_client_id() -> str:
    """Return a new random client ID: the 32-character hex form of a UUID4."""
    return uuid.uuid4().hex
def decode_exception(e: grpc.RpcError) -> Exception:
    """Turn a gRPC error into the exception to raise on the client.

    Args:
        e: The gRPC error received from the server.

    Returns:
        The application error decoded from the error details when the
        status code is ABORTED, otherwise a ``ConnectionError`` describing
        ``e``.
    """
    if e.code() == grpc.StatusCode.ABORTED:
        # The server uses ABORTED when it has serialized an application
        # error into the exception details.
        # See server.py::return_exception_in_context for details
        serialized_error = base64.standard_b64decode(e.details())
        return loads_from_server(serialized_error)
    # Any other code carries no serialized error to decode.
    return ConnectionError(f"GRPC connection failed: {e}")
|