| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599 |
- """This file implements a threaded stream controller that abstracts the data
- stream to and from the Ray client server.
- """
- import logging
- import math
- import queue
- import threading
- import warnings
- from collections import OrderedDict
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
- import grpc
- import ray.core.generated.ray_client_pb2 as ray_client_pb2
- import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
- from ray.util.client.common import (
- INT32_MAX,
- OBJECT_TRANSFER_CHUNK_SIZE,
- OBJECT_TRANSFER_WARNING_SIZE,
- )
- from ray.util.debug import log_once
- if TYPE_CHECKING:
- from ray.util.client.worker import Worker
logger = logging.getLogger(__name__)

# Declared signature of the callbacks registered for asynchronous requests:
# each one receives either the server's DataResponse or an Exception.
ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]], None]

# How many responses are received between acknowledgements sent back over the
# datapath: one acknowledgement goes out on every 32nd response.
ACKNOWLEDGE_BATCH_SIZE = 32
def chunk_put(req: ray_client_pb2.DataRequest):
    """
    Chunks a put request. Doing this lazily is important for large objects,
    since taking slices of bytes objects does a copy. This means if we
    immediately materialized every chunk of a large object and inserted them
    into the result_queue, we would effectively double the memory needed
    on the client to handle the put.

    Args:
        req: A DataRequest whose ``put`` field holds the whole object payload.

    Yields:
        DataRequest messages sharing ``req.req_id``, each carrying a
        PutRequest with one slice of at most OBJECT_TRANSFER_CHUNK_SIZE bytes
        plus the chunk bookkeeping (chunk_id, total_chunks, total_size).

    Raises:
        ValueError: if the put request carries no data.
    """
    put = req.put
    # When accessing a protobuf field, deserialization is performed, which will
    # generate a copy. So every field the loop needs is read exactly once here
    # instead of once per chunk.
    request_data = put.data
    client_ref_id = put.client_ref_id
    owner_id = put.owner_id
    req_id = req.req_id

    total_size = len(request_data)
    if not total_size:
        # An explicit check (not an assert) so that it still fires under
        # `python -O`; otherwise an empty put would yield no chunks at all.
        raise ValueError("Cannot chunk object with missing data")

    if total_size >= OBJECT_TRANSFER_WARNING_SIZE and log_once(
        "client_object_put_size_warning"
    ):
        size_gb = total_size / 2**30
        warnings.warn(
            "Ray Client is attempting to send a "
            f"{size_gb:.2f} GiB object over the network, which may "
            "be slow. Consider serializing the object and using a remote "
            "URI to transfer via S3 or Google Cloud Storage instead. "
            "Documentation for doing this can be found here: "
            "https://docs.ray.io/en/latest/handling-dependencies.html#remote-uris",
            UserWarning,
        )

    total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
    for chunk_id in range(total_chunks):
        start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
        end = min(total_size, start + OBJECT_TRANSFER_CHUNK_SIZE)
        chunk = ray_client_pb2.PutRequest(
            client_ref_id=client_ref_id,
            data=request_data[start:end],
            chunk_id=chunk_id,
            total_chunks=total_chunks,
            total_size=total_size,
            owner_id=owner_id,
        )
        yield ray_client_pb2.DataRequest(req_id=req_id, put=chunk)
def chunk_task(req: ray_client_pb2.DataRequest):
    """
    Chunks a client task. Doing this lazily is important with large arguments,
    since taking slices of bytes objects does a copy. This means if we
    immediately materialized every chunk of a large argument and inserted them
    into the result_queue, we would effectively double the memory needed
    on the client to handle the task.

    Args:
        req: A DataRequest whose ``task`` field holds the full serialized
            payload in ``task.data``.

    Yields:
        DataRequest messages sharing ``req.req_id``, each carrying a
        ClientTask that repeats the task's metadata (type, name, payload_id,
        client_id, options, baseline_options, namespace) and holds one slice
        of at most OBJECT_TRANSFER_CHUNK_SIZE bytes of ``data``.

    Raises:
        ValueError: if the task carries no data.
    """
    task = req.task
    # When accessing a protobuf field, deserialization is performed, which will
    # generate a copy. So every field the loop needs is read exactly once here
    # instead of once per chunk.
    request_data = task.data
    total_size = len(request_data)
    if not total_size:
        # An explicit check (not an assert) so that it still fires under
        # `python -O`; otherwise an empty task would yield no chunks at all.
        raise ValueError("Cannot chunk object with missing data")

    req_id = req.req_id
    # Metadata copied verbatim into every chunk.
    task_fields = dict(
        type=task.type,
        name=task.name,
        payload_id=task.payload_id,
        client_id=task.client_id,
        options=task.options,
        baseline_options=task.baseline_options,
        namespace=task.namespace,
    )

    total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
    for chunk_id in range(total_chunks):
        start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
        end = min(total_size, start + OBJECT_TRANSFER_CHUNK_SIZE)
        chunk = ray_client_pb2.ClientTask(
            **task_fields,
            data=request_data[start:end],
            chunk_id=chunk_id,
            total_chunks=total_chunks,
        )
        yield ray_client_pb2.DataRequest(req_id=req_id, task=chunk)
class ChunkCollector:
    """
    Reassembles an object that arrives in chunks over the datapath in reply
    to an asynchronous get request.

    Every response (or exception) for the request is passed to __call__. The
    wrapped callback is invoked exactly once: with the assembled bytearray
    after the final chunk, with the response itself when the server marks it
    invalid, or with an exception if the transfer fails.

    Synchronous gets never use this class; they talk to the raylet servicer
    directly instead of going through the datapath.

    __call__ returns True once the wrapped callback has been invoked.
    """

    def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
        # Invoked once, with the finished payload or with the failure.
        self.callback = callback
        # The DataRequest that started the transfer. Its get.start_chunk_id
        # is advanced as chunks arrive, so that a resend of this request does
        # not ask again for chunks that were already received.
        self.request = request
        # Payload of every in-order chunk received so far.
        self.data = bytearray()
        # Id of the newest in-order chunk received; -1 before the first one.
        self.last_seen_chunk = -1

    def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
        """Consume one response; return True once the callback has fired."""
        failed = isinstance(response, Exception)
        get_resp = None if failed else response.get
        if failed or not get_resp.valid:
            # Errors and invalid responses are forwarded untouched.
            self.callback(response)
            return True

        if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
            "client_object_transfer_size_warning"
        ):
            size_gb = get_resp.total_size / 2**30
            warnings.warn(
                "Ray Client is attempting to retrieve a "
                f"{size_gb:.2f} GiB object over the network, which may "
                "be slow. Consider serializing the object to a file and "
                "using rsync or S3 instead.",
                UserWarning,
            )

        chunk_id = get_resp.chunk_id
        expected_chunk = self.last_seen_chunk + 1

        if chunk_id > expected_chunk:
            # A chunk was skipped. gRPC guarantees in-order delivery, so this
            # should not happen in practice; fail the whole transfer.
            msg = (
                f"Received chunk {chunk_id} when we expected "
                f"{expected_chunk} for request {response.req_id}"
            )
            logger.warning(msg)
            self.callback(RuntimeError(msg))
            return True

        if chunk_id < expected_chunk:
            # Duplicate of a chunk we have already appended to self.data.
            logger.debug(
                f"Received a repeated chunk {chunk_id} "
                f"from request {response.req_id}."
            )
        else:
            self.data.extend(get_resp.data)
            self.last_seen_chunk = chunk_id
            # If we disconnect partway through, restart the get request at
            # the first chunk we have not seen yet.
            self.request.get.start_chunk_id = chunk_id + 1

        finished = chunk_id == get_resp.total_chunks - 1
        if finished:
            # NOTE: the callback receives the raw bytearray here, not a
            # DataResponse as the ResponseCallable alias declares.
            self.callback(self.data)
        return finished
class DataClient:
    """Client side of the Ray Client datapath, driven by a background thread.

    Requests are placed on ``request_queue`` and streamed to the server over a
    ``Datapath`` RPC by ``data_thread``, which also consumes the response
    stream. Blocking calls wait on ``cv`` until their response is stored in
    ``ready_data``. Asynchronous calls register a callback in
    ``asyncio_waiting_data``. Requests stay in ``outstanding_requests`` until
    answered, so they can be resent after a reconnect.
    """

    def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.

        Args:
            client_worker: The Ray Client worker that manages this client
            client_id: the generated ID representing this client
            metadata: metadata to pass to gRPC requests
        """
        self.client_worker = client_worker
        self._client_id = client_id
        self._metadata = metadata
        self.data_thread = self._start_datathread()

        # Track outstanding requests to resend in case of disconnection
        self.outstanding_requests: Dict[int, Any] = OrderedDict()

        # Serialize access to all mutable internal states: self.request_queue,
        # self.ready_data, self.asyncio_waiting_data,
        # self._in_shutdown, self._req_id, self.outstanding_requests and
        # calling self._next_id()
        self.lock = threading.Lock()

        # Waiting for response or shutdown.
        self.cv = threading.Condition(lock=self.lock)

        self.request_queue = self._create_queue()
        self.ready_data: Dict[int, Any] = {}
        # NOTE: Dictionary insertion is guaranteed to complete before lookup
        # and/or removal because of synchronization via the request_queue.
        self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
        self._in_shutdown = False
        self._req_id = 0
        # Error that broke the stream, if any; reported to later callers.
        self._last_exception = None
        # Count of completed responses, used to batch acknowledgements.
        self._acknowledge_counter = 0

        self.data_thread.start()

    # Must hold self.lock when calling this function.
    def _next_id(self) -> int:
        """Return the next request id, wrapping to 1 after INT32_MAX."""
        assert self.lock.locked()
        self._req_id += 1
        if self._req_id > INT32_MAX:
            self._req_id = 1
        # Responses that aren't tracked (like opportunistic releases)
        # have req_id=0, so make sure we never mint such an id.
        assert self._req_id != 0
        return self._req_id

    def _start_datathread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread running _data_main."""
        return threading.Thread(
            target=self._data_main,
            name="ray_client_streaming_rpc",
            args=(),
            daemon=True,
        )

    # A helper that takes requests from queue. If the request wraps a PutRequest,
    # lazily chunks and yields the request. Otherwise, yields the request directly.
    def _requests(self):
        """Yield requests from self.request_queue until a None sentinel.

        Put and task requests are split into chunks lazily by chunk_put() and
        chunk_task(); every other request is yielded unchanged.
        """
        while True:
            req = self.request_queue.get()
            if req is None:
                # Stop when client signals shutdown.
                return
            req_type = req.WhichOneof("type")
            if req_type == "put":
                yield from chunk_put(req)
            elif req_type == "task":
                yield from chunk_task(req)
            else:
                yield req

    def _data_main(self) -> None:
        """Body of the data thread: drive the Datapath stream until it ends.

        Every response is dispatched to _process_response(). On a recoverable
        gRPC error the channel is reconnected and the stream is reopened with
        ``reconnecting=True`` in its metadata. An unrecoverable error is
        recorded in self._last_exception. The client is always shut down when
        this method exits.
        """
        reconnecting = False
        try:
            while not self.client_worker._in_shutdown:
                stub = ray_client_pb2_grpc.RayletDataStreamerStub(
                    self.client_worker.channel
                )
                metadata = self._metadata + [("reconnecting", str(reconnecting))]
                resp_stream = stub.Datapath(
                    self._requests(),
                    metadata=metadata,
                    wait_for_ready=True,
                )
                try:
                    for response in resp_stream:
                        self._process_response(response)
                    return
                except grpc.RpcError as e:
                    reconnecting = self._can_reconnect(e)
                    if not reconnecting:
                        self._last_exception = e
                        return
                    self._reconnect_channel()
        except Exception as e:
            self._last_exception = e
        finally:
            logger.debug("Shutting down data channel.")
            self._shutdown()

    def _process_response(self, response: Any) -> None:
        """
        Process responses from the data servicer.

        A response with req_id 0 is not being waited for and is only logged.
        A response to an async request is passed to its registered callback.
        Any other response is stored in self.ready_data for the blocking
        caller waiting on self.cv.
        """
        if response.req_id == 0:
            # This is not being waited for.
            logger.debug(f"Got unawaited response {response}")
            return
        if response.req_id in self.asyncio_waiting_data:
            # A ChunkCollector reports whether the transfer has finished; any
            # other callback completes its request with a single response.
            can_remove = True
            try:
                callback = self.asyncio_waiting_data[response.req_id]
                if isinstance(callback, ChunkCollector):
                    can_remove = callback(response)
                elif callback:
                    callback(response)
                if can_remove:
                    # NOTE: calling del self.asyncio_waiting_data results
                    # in the destructor of ClientObjectRef running, which
                    # calls ReleaseObject(). So self.asyncio_waiting_data
                    # is accessed without holding self.lock. Holding the
                    # lock shouldn't be necessary either.
                    del self.asyncio_waiting_data[response.req_id]
            except Exception:
                logger.exception("Callback error:")
            with self.lock:
                # Update outstanding requests
                if response.req_id in self.outstanding_requests and can_remove:
                    del self.outstanding_requests[response.req_id]
                    # Acknowledge response, but only once the request is
                    # fully finished (every chunk received).
                    self._acknowledge(response.req_id)
        else:
            with self.lock:
                self.ready_data[response.req_id] = response
                self.cv.notify_all()

    def _can_reconnect(self, e: grpc.RpcError) -> bool:
        """
        Processes RPC errors that occur while reading from data stream.
        Returns True if the error can be recovered from, False otherwise.
        """
        if not self.client_worker._can_reconnect(e):
            logger.error("Unrecoverable error in data channel.")
            logger.debug(e)
            return False
        logger.debug("Recoverable error in data channel.")
        logger.debug(e)
        return True

    def _shutdown(self) -> None:
        """
        Shutdown the data channel.

        Marks the client as shut down, wakes every blocking waiter, and fails
        every pending async request with a ConnectionError. A callback that
        raises is logged and does not stop the remaining callbacks from being
        notified.
        """
        with self.lock:
            self._in_shutdown = True
            self.cv.notify_all()

            # `callbacks` keeps the old dict alive, so its entries are only
            # released after the lock is dropped. NOTE(review): presumably this
            # avoids ClientObjectRef destructors (which call ReleaseObject())
            # re-entering self.lock; confirm.
            callbacks = self.asyncio_waiting_data.values()
            self.asyncio_waiting_data = {}

        if self._last_exception is not None:
            # Abort async requests with the error.
            err = ConnectionError(
                "Failed during this or a previous request. Exception that "
                f"broke the connection: {self._last_exception}"
            )
        else:
            err = ConnectionError(
                "Request cannot be fulfilled because the data client has "
                "disconnected."
            )
        for callback in callbacks:
            if not callback:
                continue
            try:
                callback(err)
            except Exception:
                # One failing callback must not leave the remaining pending
                # requests without a result.
                logger.exception("Callback error:")
        # Since self._in_shutdown is set to True, _check_shutdown rejects new
        # requests from every thread other than the data thread.

    def _acknowledge(self, req_id: int) -> None:
        """
        Puts an acknowledge request on the request queue periodically.
        Lock should be held before calling this. Used when an async or
        blocking response is received.
        """
        if not self.client_worker._reconnect_enabled:
            # Skip ACKs if reconnect isn't enabled
            return
        assert self.lock.locked()
        self._acknowledge_counter += 1
        if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0:
            self.request_queue.put(
                ray_client_pb2.DataRequest(
                    acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=req_id)
                )
            )

    def _reconnect_channel(self) -> None:
        """
        Attempts to reconnect the gRPC channel and resend outstanding
        requests. First, the server is pinged to see if the current channel
        still works. If the ping fails, then the current channel is closed
        and replaced with a new one.

        Once a working channel is available, a new request queue is made
        and filled with any outstanding requests to be resent to the server.

        Raises:
            ConnectionError: if the channel cannot be re-established.
        """
        try:
            # Ping the server to see if the current channel is reuseable, for
            # example if gRPC reconnected the channel on its own or if the
            # RPC error was transient and the channel is still open
            ping_succeeded = self.client_worker.ping_server(timeout=5)
        except grpc.RpcError:
            ping_succeeded = False

        if not ping_succeeded:
            # Ping failed, try refreshing the data channel
            logger.warning(
                "Encountered connection issues in the data channel. "
                "Attempting to reconnect."
            )
            try:
                self.client_worker._connect_channel(reconnecting=True)
            except ConnectionError:
                logger.warning("Failed to reconnect the data channel")
                raise
            logger.debug("Reconnection succeeded!")

        # Recreate the request queue, and resend outstanding requests
        with self.lock:
            self.request_queue = self._create_queue()
            for request in self.outstanding_requests.values():
                # Resend outstanding requests
                self.request_queue.put(request)

    # Use SimpleQueue to avoid deadlocks when appending to queue from __del__()
    @staticmethod
    def _create_queue():
        return queue.SimpleQueue()

    def close(self) -> None:
        """Shut down the datapath and wait for the data thread to exit."""
        thread = None
        with self.lock:
            self._in_shutdown = True
            # Notify blocking operations to fail.
            self.cv.notify_all()
            # Add sentinel to terminate streaming RPC.
            if self.request_queue is not None:
                # Intentional shutdown, tell server it can clean up the
                # connection immediately and ignore the reconnect grace period.
                cleanup_request = ray_client_pb2.DataRequest(
                    connection_cleanup=ray_client_pb2.ConnectionCleanupRequest()
                )
                self.request_queue.put(cleanup_request)
                self.request_queue.put(None)
            if self.data_thread is not None:
                thread = self.data_thread
        # Wait until streaming RPCs are done.
        if thread is not None:
            thread.join()

    def _blocking_send(
        self, req: ray_client_pb2.DataRequest
    ) -> ray_client_pb2.DataResponse:
        """Send `req` and block until its response arrives.

        Raises:
            ConnectionError: if the client is (or becomes) shut down.
        """
        with self.lock:
            self._check_shutdown()
            req_id = self._next_id()
            req.req_id = req_id
            self.request_queue.put(req)
            self.outstanding_requests[req_id] = req

            self.cv.wait_for(lambda: req_id in self.ready_data or self._in_shutdown)
            self._check_shutdown()

            data = self.ready_data[req_id]
            del self.ready_data[req_id]
            del self.outstanding_requests[req_id]
            self._acknowledge(req_id)

        return data

    def _async_send(
        self,
        req: ray_client_pb2.DataRequest,
        callback: Optional[ResponseCallable] = None,
    ) -> None:
        """Send `req` without waiting; `callback` receives the response.

        Raises:
            ConnectionError: if the client is shut down (except when called
                from the data thread, see _check_shutdown).
        """
        with self.lock:
            self._check_shutdown()
            req_id = self._next_id()
            req.req_id = req_id
            self.asyncio_waiting_data[req_id] = callback
            self.outstanding_requests[req_id] = req
            self.request_queue.put(req)

    # Must hold self.lock when calling this function.
    def _check_shutdown(self):
        """Raise ConnectionError if the data client has been shut down.

        Must be called with self.lock held, and the lock is held again
        whenever this returns or raises. This matters because every caller
        releases the lock when its ``with self.lock`` block exits.

        On the data thread this returns silently. Elsewhere, the lock is
        dropped while ray.util.disconnect() runs, and then ConnectionError is
        raised.
        """
        assert self.lock.locked()
        if not self._in_shutdown:
            return

        # Do not try disconnect() or throw exceptions in self.data_thread.
        # Otherwise deadlock can occur. Return with the lock still held:
        # releasing it here would make the caller touch guarded state
        # unlocked and then release an unlocked lock on `with` exit.
        if threading.current_thread().ident == self.data_thread.ident:
            return

        self.lock.release()
        try:
            from ray.util import disconnect

            disconnect()
        finally:
            # Re-acquire even if disconnect() fails, so the caller's
            # `with self.lock` exit does not mask the real error.
            self.lock.acquire()

        if self._last_exception is not None:
            msg = (
                "Request can't be sent because the Ray client has already "
                "been disconnected due to an error. Last exception: "
                f"{self._last_exception}"
            )
        else:
            msg = (
                "Request can't be sent because the Ray client has already "
                "been disconnected."
            )

        raise ConnectionError(msg)

    def Init(
        self, request: ray_client_pb2.InitRequest, context=None
    ) -> ray_client_pb2.InitResponse:
        datareq = ray_client_pb2.DataRequest(
            init=request,
        )
        resp = self._blocking_send(datareq)
        return resp.init

    def PrepRuntimeEnv(
        self, request: ray_client_pb2.PrepRuntimeEnvRequest, context=None
    ) -> ray_client_pb2.PrepRuntimeEnvResponse:
        datareq = ray_client_pb2.DataRequest(
            prep_runtime_env=request,
        )
        resp = self._blocking_send(datareq)
        return resp.prep_runtime_env

    def ConnectionInfo(self, context=None) -> ray_client_pb2.ConnectionInfoResponse:
        datareq = ray_client_pb2.DataRequest(
            connection_info=ray_client_pb2.ConnectionInfoRequest()
        )
        resp = self._blocking_send(datareq)
        return resp.connection_info

    def GetObject(
        self, request: ray_client_pb2.GetRequest, context=None
    ) -> ray_client_pb2.GetResponse:
        datareq = ray_client_pb2.DataRequest(
            get=request,
        )
        resp = self._blocking_send(datareq)
        return resp.get

    def RegisterGetCallback(
        self, request: ray_client_pb2.GetRequest, callback: ResponseCallable
    ) -> None:
        """Fetch a single object asynchronously, reassembling its chunks.

        Raises:
            ValueError: if `request` does not name exactly one object id.
        """
        if len(request.ids) != 1:
            raise ValueError(
                "RegisterGetCallback() must have exactly 1 Object ID. "
                f"Actual: {request}"
            )
        datareq = ray_client_pb2.DataRequest(
            get=request,
        )
        collector = ChunkCollector(callback=callback, request=datareq)
        self._async_send(datareq, collector)

    # TODO: convert PutObject to async
    def PutObject(
        self, request: ray_client_pb2.PutRequest, context=None
    ) -> ray_client_pb2.PutResponse:
        datareq = ray_client_pb2.DataRequest(
            put=request,
        )
        resp = self._blocking_send(datareq)
        return resp.put

    def ReleaseObject(
        self, request: ray_client_pb2.ReleaseRequest, context=None
    ) -> None:
        datareq = ray_client_pb2.DataRequest(
            release=request,
        )
        self._async_send(datareq)

    def Schedule(self, request: ray_client_pb2.ClientTask, callback: ResponseCallable):
        datareq = ray_client_pb2.DataRequest(task=request)
        self._async_send(datareq, callback)

    def Terminate(
        self, request: ray_client_pb2.TerminateRequest
    ) -> ray_client_pb2.TerminateResponse:
        req = ray_client_pb2.DataRequest(
            terminate=request,
        )
        resp = self._blocking_send(req)
        return resp.terminate

    def ListNamedActors(
        self, request: ray_client_pb2.ClientListNamedActorsRequest
    ) -> ray_client_pb2.ClientListNamedActorsResponse:
        req = ray_client_pb2.DataRequest(
            list_named_actors=request,
        )
        resp = self._blocking_send(req)
        return resp.list_named_actors
|