| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416 |
- import logging
- import sys
- import time
- from collections import defaultdict
- from queue import Queue
- from threading import Event, Lock, Thread
- from typing import TYPE_CHECKING, Any, Dict, Iterator, Union
- import grpc
- import ray
- import ray.core.generated.ray_client_pb2 as ray_client_pb2
- import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
- from ray._private.client_mode_hook import disable_client_hook
- from ray.util.client.common import (
- CLIENT_SERVER_MAX_THREADS,
- OrderedResponseCache,
- _propagate_error_in_context,
- )
- from ray.util.client.server.server_pickler import loads_from_client
- from ray.util.debug import log_once
- if TYPE_CHECKING:
- from ray.util.client.server.server import RayletServicer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Timeout, in seconds, passed to Thread.join() when waiting for the
# request-reading thread to finish at the end of a data stream.
QUEUE_JOIN_SECONDS = 10
def _get_reconnecting_from_context(context: Any) -> bool:
    """
    Read the `reconnecting` flag from the call's gRPC invocation metadata.

    Returns True only when the metadata value is the string "True". A
    missing or unrecognized value is logged as an error and treated as
    False.
    """
    flag = dict(context.invocation_metadata()).get("reconnecting")
    if flag in ("True", "False"):
        return flag == "True"

    # Absent (None) or anything other than the two accepted spellings.
    logger.error(
        f'Client connecting with invalid value for "reconnecting": {flag}, '
        "This may be because you have a mismatched client and server "
        "version."
    )
    return False
def _should_cache(req: ray_client_pb2.DataRequest) -> bool:
    """
    Decide whether the response to `req` should be stored in the response
    cache. Returns True if it should, False otherwise.

    Responses are cached for every request type except:
        - asynchronous gets: These arrive out of order. Skipping caching
            here is fine, since repeating an async get is idempotent
        - acks: Repeating acks is idempotent
        - clean up requests: Also idempotent, and client has likely already
            wrapped up the data connection by this point.
        - non-final chunks of puts and tasks: only the final chunk
            generates a response, so only the final chunk is cached
    """
    kind = req.WhichOneof("type")

    if kind in ("put", "task"):
        # Chunked requests: cache on the last chunk only.
        chunked = getattr(req, kind)
        return chunked.chunk_id == chunked.total_chunks - 1

    if kind == "get":
        return not req.get.asynchronous

    return kind not in ("acknowledge", "connection_cleanup")
def fill_queue(
    grpc_input_generator: Iterator[ray_client_pb2.DataRequest],
    output_queue: "Queue[Union[ray_client_pb2.DataRequest, ray_client_pb2.DataResponse]]",  # noqa: E501
) -> None:
    """
    Drain the incoming gRPC request stream into the shared output_queue.

    Requests are enqueued in arrival order. When the stream ends, whether
    normally or through an error, a `None` sentinel is enqueued so the
    consumer knows no further requests will follow. A `grpc.RpcError` raised
    while reading is logged at debug level and not re-raised.
    """
    try:
        for incoming in grpc_input_generator:
            output_queue.put(incoming)
    except grpc.RpcError as rpc_err:
        logger.debug(
            "closing dataservicer reader thread grpc error reading "
            f"request_iterator: {rpc_err}"
        )
    finally:
        # Sentinel marking the end of the request stream.
        output_queue.put(None)
class ChunkCollector:
    """
    Helper class for collecting chunks from PutObject or ClientTask messages.

    Chunks of a single request must be added in order, starting at chunk 0.
    Their payloads are concatenated into `data`. Call `reset()` once the
    assembled payload has been consumed, before collecting the next request.
    """

    def __init__(self):
        # req_id of the request currently being assembled, or None if idle.
        self.curr_req_id = None
        # chunk_id of the most recently accepted chunk; -1 if none yet.
        self.last_seen_chunk_id = -1
        # Concatenated payload of all chunks accepted so far.
        self.data = bytearray()

    def add_chunk(
        self,
        req: ray_client_pb2.DataRequest,
        chunk: Union[ray_client_pb2.PutRequest, ray_client_pb2.ClientTask],
    ) -> bool:
        """
        Append one chunk to the payload being assembled.

        Args:
            req: The request carrying the chunk; its `req_id` must match the
                request currently being collected, if there is one.
            chunk: The chunk message (`chunk_id`, `total_chunks`, `data`).

        Returns:
            True if this chunk was the final chunk of the request, i.e. the
            payload in `data` is complete. False if more chunks are expected
            or if the chunk was a repeat of one already seen (ignored).

        Raises:
            RuntimeError: If the chunk belongs to a different request than
                the one being collected, or if a chunk was skipped.
        """
        if self.curr_req_id is not None and self.curr_req_id != req.req_id:
            raise RuntimeError(
                "Expected to receive a chunk from request with id "
                f"{self.curr_req_id}, but found {req.req_id} instead."
            )
        self.curr_req_id = req.req_id

        expected_chunk_id = self.last_seen_chunk_id + 1
        if chunk.chunk_id < expected_chunk_id:
            # Repeated chunk, ignore. Return an explicit False (rather than
            # falling through with None) so the return type is always bool.
            return False
        if chunk.chunk_id > expected_chunk_id:
            raise RuntimeError(
                f"A chunk {chunk.chunk_id} of request {req.req_id} was "
                "received out of order."
            )

        # chunk.chunk_id == expected_chunk_id: accept it.
        self.data.extend(chunk.data)
        self.last_seen_chunk_id = chunk.chunk_id
        return chunk.chunk_id + 1 == chunk.total_chunks

    def reset(self):
        """Discard any collected state and get ready for a new request."""
        self.curr_req_id = None
        self.last_seen_chunk_id = -1
        self.data = bytearray()
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    """
    gRPC servicer for the Ray client data stream.

    Tracks which clients are connected, their reconnect grace periods and
    their response caches, and dispatches every request received on a
    `Datapath` stream to the wrapped `basic_service`.
    """

    def __init__(self, basic_service: "RayletServicer"):
        self.basic_service = basic_service
        self.clients_lock = Lock()
        self.num_clients = 0  # guarded by self.clients_lock
        # dictionary mapping client_id's to the last time they connected
        self.client_last_seen: Dict[str, float] = {}
        # dictionary mapping client_id's to their reconnect grace periods
        self.reconnect_grace_periods: Dict[str, float] = {}
        # dictionary mapping client_id's to their response cache
        self.response_caches: Dict[str, OrderedResponseCache] = defaultdict(
            OrderedResponseCache
        )
        # stopped event, useful for signals that the server is shut down
        self.stopped = Event()
        # Helper for collecting chunks from PutObject calls. Assumes that
        # put requests from different objects aren't interleaved.
        # NOTE(review): both collectors are shared by every stream handled
        # by this servicer, so chunked requests from concurrent clients
        # would interleave -- confirm a servicer only serves one client at
        # a time.
        self.put_request_chunk_collector = ChunkCollector()
        # Helper for collecting chunks from ClientTask calls. Assumes that
        # schedule requests from different remote calls aren't interleaved.
        self.client_task_chunk_collector = ChunkCollector()

    def Datapath(self, request_iterator, context):
        """
        Serve one data stream from a client (generator).

        Args:
            request_iterator: Iterator over the client's DataRequest
                messages.
            context: The gRPC context of the call. Its invocation metadata
                must contain `client_id`.

        Yields:
            A DataResponse for each request that produces one. Responses to
            asynchronous gets are yielded later, when they are put on the
            internal request queue.
        """
        start_time = time.time()
        # set to True if client shuts down gracefully
        cleanup_requested = False
        metadata = dict(context.invocation_metadata())
        client_id = metadata.get("client_id")
        if client_id is None:
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        if not self._init(client_id, context, start_time):
            return
        # Only look the cache up once the connection has been accepted:
        # response_caches is a defaultdict, so a lookup for a rejected client
        # would create an entry that nothing ever removes.
        response_cache = self.response_caches[client_id]
        # Set to False if client requests a reconnect grace period of 0
        reconnect_enabled = True
        queue_filler_thread = None
        try:
            request_queue = Queue()
            queue_filler_thread = Thread(
                target=fill_queue, daemon=True, args=(request_iterator, request_queue)
            )
            queue_filler_thread.start()
            # For non `async get` requests, this loop yields immediately.
            # For `async get` requests, this loop:
            #   1) does not yield, it just continues
            #   2) When the result is ready, it yields
            for req in iter(request_queue.get, None):
                if isinstance(req, ray_client_pb2.DataResponse):
                    # Early shortcut if this is the result of an async get.
                    yield req
                    continue

                assert isinstance(req, ray_client_pb2.DataRequest)
                cacheable = _should_cache(req)
                if cacheable and reconnect_enabled:
                    cached_resp = response_cache.check_cache(req.req_id)
                    if isinstance(cached_resp, Exception):
                        # Cache state is invalid, raise exception
                        raise cached_resp
                    if cached_resp is not None:
                        yield cached_resp
                        continue

                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "init":
                    resp_init = self.basic_service.Init(req.init)
                    resp = ray_client_pb2.DataResponse(
                        init=resp_init,
                    )
                    with self.clients_lock:
                        self.reconnect_grace_periods[
                            client_id
                        ] = req.init.reconnect_grace_period
                        if req.init.reconnect_grace_period == 0:
                            reconnect_enabled = False
                elif req_type == "get":
                    if req.get.asynchronous:
                        get_resp = self.basic_service._async_get_object(
                            req.get, client_id, req.req_id, request_queue
                        )
                        if get_resp is None:
                            # Skip sending a response for this request and
                            # continue to the next request. The response for
                            # this request will be sent when the object is
                            # ready.
                            continue
                    else:
                        get_resp = self.basic_service._get_object(req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    if not self.put_request_chunk_collector.add_chunk(req, req.put):
                        # Put request still in progress
                        continue
                    put_resp = self.basic_service._put_object(
                        self.put_request_chunk_collector.data,
                        req.put.client_ref_id,
                        client_id,
                        req.put.owner_id,
                    )
                    self.put_request_chunk_collector.reset()
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released)
                    )
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response()
                    )
                elif req_type == "prep_runtime_env":
                    with self.clients_lock:
                        resp_prep = self.basic_service.PrepRuntimeEnv(
                            req.prep_runtime_env
                        )
                        resp = ray_client_pb2.DataResponse(prep_runtime_env=resp_prep)
                elif req_type == "connection_cleanup":
                    cleanup_requested = True
                    cleanup_resp = ray_client_pb2.ConnectionCleanupResponse()
                    resp = ray_client_pb2.DataResponse(connection_cleanup=cleanup_resp)
                elif req_type == "acknowledge":
                    # Clean up acknowledged cache entries
                    response_cache.cleanup(req.acknowledge.req_id)
                    continue
                elif req_type == "task":
                    with self.clients_lock:
                        task = req.task
                        if not self.client_task_chunk_collector.add_chunk(req, task):
                            # Not all serialized arguments have arrived
                            continue
                        arglist, kwargs = loads_from_client(
                            self.client_task_chunk_collector.data, self.basic_service
                        )
                        self.client_task_chunk_collector.reset()
                        resp_ticket = self.basic_service.Schedule(
                            req.task, arglist, kwargs, context
                        )
                        resp = ray_client_pb2.DataResponse(task_ticket=resp_ticket)
                        del arglist
                        del kwargs
                elif req_type == "terminate":
                    with self.clients_lock:
                        response = self.basic_service.Terminate(req.terminate, context)
                        resp = ray_client_pb2.DataResponse(terminate=response)
                elif req_type == "list_named_actors":
                    with self.clients_lock:
                        response = self.basic_service.ListNamedActors(
                            req.list_named_actors
                        )
                        resp = ray_client_pb2.DataResponse(list_named_actors=response)
                else:
                    raise Exception(
                        f"Unreachable code: Request type "
                        f"{req_type} not handled in Datapath"
                    )
                resp.req_id = req.req_id
                # reconnect_enabled is re-read here on purpose: handling an
                # "init" request above may have just turned it off.
                if cacheable and reconnect_enabled:
                    response_cache.update_cache(req.req_id, resp)
                yield resp
        except Exception as e:
            logger.exception("Error in data channel:")
            recoverable = _propagate_error_in_context(e, context)
            invalid_cache = response_cache.invalidate(e)
            if not recoverable or invalid_cache:
                context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
                # Connection isn't recoverable, so skip the reconnect grace
                # period and clean up immediately.
                cleanup_requested = True
        finally:
            logger.debug(f"Stream is broken with client {client_id}")
            # The thread may never have been created or started if setup
            # failed (ident stays None until a thread has started). Joining an
            # unstarted thread raises RuntimeError, which would skip the
            # client cleanup below.
            if (
                queue_filler_thread is not None
                and queue_filler_thread.ident is not None
            ):
                queue_filler_thread.join(QUEUE_JOIN_SECONDS)
                if queue_filler_thread.is_alive():
                    logger.error(
                        "Queue filler thread failed to join before timeout: {}".format(
                            QUEUE_JOIN_SECONDS
                        )
                    )
            cleanup_delay = self.reconnect_grace_periods.get(client_id)
            if not cleanup_requested and cleanup_delay is not None:
                logger.debug(
                    "Cleanup wasn't requested, delaying cleanup by "
                    f"{cleanup_delay} seconds."
                )
                # Delay cleanup, since client may attempt a reconnect
                # Wait on the "stopped" event in case the grpc server is
                # stopped and we can clean up earlier.
                self.stopped.wait(timeout=cleanup_delay)
            else:
                logger.debug("Cleanup was requested, cleaning up immediately.")
            # Kept in a helper so that its early exits are plain returns and
            # not `return` statements inside this `finally` block.
            self._cleanup_client(client_id, start_time)

    def _cleanup_client(self, client_id: str, start_time: float) -> None:
        """
        Release the session state of `client_id`, unless the session was
        already cleaned up or the client has reconnected since `start_time`.
        Shuts Ray down when the last client is removed.
        """
        with self.clients_lock:
            if client_id not in self.client_last_seen:
                logger.debug("Connection already cleaned up.")
                # Some other connection has already cleaned up this
                # client's session. This can happen if the client
                # reconnects and then gracefully shuts down immediately.
                return
            last_seen = self.client_last_seen[client_id]
            if last_seen > start_time:
                # The client successfully reconnected and updated
                # last seen some time during the grace period
                logger.debug("Client reconnected, skipping cleanup")
                return
            # Either the client shut down gracefully, or the client
            # failed to reconnect within the grace period. Clean up
            # the connection.
            self.basic_service.release_all(client_id)
            del self.client_last_seen[client_id]
            self.reconnect_grace_periods.pop(client_id, None)
            self.response_caches.pop(client_id, None)
            self.num_clients -= 1
            logger.debug(
                f"Removed client {client_id}, remaining={self.num_clients}"
            )

            # It's important to keep the Ray shutdown
            # within this locked context or else Ray could hang.
            # NOTE: it is strange to start ray in server.py but shut it
            # down here. Consider consolidating ray lifetime management.
            with disable_client_hook():
                if self.num_clients == 0:
                    logger.debug("Shutting down ray.")
                    ray.shutdown()

    def _init(self, client_id: str, context: Any, start_time: float) -> bool:
        """
        Checks if resources allow for another client.

        A connection is rejected (and a status code is set on `context`) when
        the number of clients has reached half of CLIENT_SERVER_MAX_THREADS,
        or when the client says it is reconnecting but its session is no
        longer known. Otherwise the client's last-seen time is set to
        `start_time`, and the client count is incremented if it is new.

        Returns a boolean indicating if initialization was successful.
        """
        with self.clients_lock:
            reconnecting = _get_reconnecting_from_context(context)
            threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
            if self.num_clients >= threshold:
                logger.warning(
                    f"[Data Servicer]: Num clients {self.num_clients} "
                    f"has reached the threshold {threshold}. "
                    f"Rejecting client: {client_id}. "
                )
                if log_once("client_threshold"):
                    logger.warning(
                        "You can configure the client connection "
                        "threshold by setting the "
                        "RAY_CLIENT_SERVER_MAX_THREADS env var "
                        f"(currently set to {CLIENT_SERVER_MAX_THREADS})."
                    )
                context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                return False
            if reconnecting and client_id not in self.client_last_seen:
                # Client took too long to reconnect, session has been
                # cleaned up.
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Attempted to reconnect to a session that has already "
                    "been cleaned up."
                )
                return False
            if client_id in self.client_last_seen:
                logger.debug(f"Client {client_id} has reconnected.")
            else:
                self.num_clients += 1
                logger.debug(
                    f"Accepted data connection from {client_id}. "
                    f"Total clients: {self.num_clients}"
                )
            self.client_last_seen[client_id] = start_time
            return True

    def _build_connection_response(self):
        """
        Build the ConnectionInfoResponse describing this server: current
        number of clients, Python version, Ray version and Ray commit.
        """
        with self.clients_lock:
            cur_num_clients = self.num_clients
        version = sys.version_info
        return ray_client_pb2.ConnectionInfoResponse(
            num_clients=cur_num_clients,
            python_version=f"{version[0]}.{version[1]}.{version[2]}",
            ray_version=ray.__version__,
            ray_commit=ray.__commit__,
        )
|