import atexit import json import logging import socket import sys import time import traceback import urllib from concurrent import futures from dataclasses import dataclass from itertools import chain from threading import Event, Lock, RLock, Thread from typing import Callable, Dict, List, Optional, Tuple from urllib.parse import urlparse, urlunparse import grpc import ray import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2 from ray._common.network_utils import ( build_address, is_ipv6, is_localhost, ) from ray._private.authentication.http_token_authentication import ( format_authentication_http_error, get_auth_headers_if_auth_enabled, ) from ray._private.client_mode_hook import disable_client_hook from ray._private.grpc_utils import init_grpc_channel from ray._private.parameter import RayParams from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.services import ( ProcessInfo, get_node_with_retry, start_ray_client_server, ) from ray._private.tls_utils import add_port_to_grpc_server from ray._private.utils import detect_fate_sharing_support from ray._raylet import GcsClient from ray.cloudpickle.compat import pickle from ray.exceptions import AuthenticationError from ray.job_config import JobConfig from ray.util.client.common import ( CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS, ClientServerHandle, _get_client_id_from_context, _propagate_error_in_context, ) from ray.util.client.server.dataservicer import _get_reconnecting_from_context # Import psutil after ray so the packaged version is used. import psutil logger = logging.getLogger(__name__) CHECK_PROCESS_INTERVAL_S = 30 MIN_SPECIFIC_SERVER_PORT = 23000 MAX_SPECIFIC_SERVER_PORT = 24000 CHECK_CHANNEL_TIMEOUT_S = 30 LOGSTREAM_RETRIES = 5 LOGSTREAM_RETRY_INTERVAL_SEC = 2 @dataclass class SpecificServer: port: int process_handle_future: futures.Future channel: "grpc._channel.Channel" def is_ready(self) -> bool: """Check if the server is ready or not (doesn't block).""" return self.process_handle_future.done() def wait_ready(self, timeout: Optional[float] = None) -> None: """ Wait for the server to actually start up. """ res = self.process_handle_future.result(timeout=timeout) if res is None: # This is only set to none when server creation specifically fails. raise RuntimeError("Server startup failed.") def poll(self) -> Optional[int]: """Check if the process has exited.""" try: proc = self.process_handle_future.result(timeout=0.1) if proc is not None: return proc.process.poll() except futures.TimeoutError: return def kill(self) -> None: """Try to send a KILL signal to the process.""" try: proc = self.process_handle_future.result(timeout=0.1) if proc is not None: proc.process.kill() except futures.TimeoutError: # Server has not been started yet. pass def set_result(self, proc: Optional[ProcessInfo]) -> None: """Set the result of the internal future if it is currently unset.""" if not self.is_ready(): self.process_handle_future.set_result(proc) def _match_running_client_server(command: List[str]) -> bool: """ Detects if the main process in the given command is the RayClient Server. This works by ensuring that the command is of the form: -m ray.util.client.server """ flattened = " ".join(command) return "-m ray.util.client.server" in flattened class ProxyManager: def __init__( self, address: Optional[str], runtime_env_agent_address: str, *, session_dir: Optional[str] = None, redis_username: Optional[str] = None, redis_password: Optional[str] = None, node_id: Optional[str] = None, ): self.servers: Dict[str, SpecificServer] = dict() self.server_lock = RLock() self._address = address self._redis_username = redis_username self._redis_password = redis_password self._free_ports: List[int] = list( range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT) ) if runtime_env_agent_address: parsed = urlparse(runtime_env_agent_address) # runtime env agent self-assigns a free port, fetch it from GCS if parsed.port is None or parsed.port == 0: if node_id is None: raise ValueError( "node_id is required when runtime_env_agent_address " "has no port specified" ) node_info = get_node_with_retry(address, node_id) runtime_env_agent_address = urlunparse( parsed._replace( netloc=f"{parsed.hostname}:{node_info['runtime_env_agent_port']}" ) ) self._runtime_env_agent_address = runtime_env_agent_address self._check_thread = Thread(target=self._check_processes, daemon=True) self._check_thread.start() self.fate_share = bool(detect_fate_sharing_support()) self._node: Optional[ray._private.node.Node] = None atexit.register(self._cleanup) def _get_unused_port(self, family: int = socket.AF_INET) -> int: """ Search for a port in _free_ports that is unused. """ with self.server_lock: num_ports = len(self._free_ports) for _ in range(num_ports): port = self._free_ports.pop(0) s = socket.socket(family, socket.SOCK_STREAM) try: s.bind(("", port)) except OSError: self._free_ports.append(port) continue finally: s.close() return port raise RuntimeError("Unable to succeed in selecting a random port.") @property def address(self) -> str: """ Returns the provided Ray bootstrap address, or creates a new cluster. """ if self._address: return self._address # Start a new, locally scoped cluster. connection_tuple = ray.init() self._address = connection_tuple["address"] self._session_dir = connection_tuple["session_dir"] return self._address @property def node(self) -> ray._private.node.Node: """Gets a 'ray.Node' object for this node (the head node). If it does not already exist, one is created using the bootstrap address. """ if self._node: return self._node ray_params = RayParams(gcs_address=self.address) self._node = ray._private.node.Node( ray_params, head=False, shutdown_at_exit=False, spawn_reaper=False, connect_only=True, ) return self._node def create_specific_server(self, client_id: str) -> SpecificServer: """ Create, but not start a SpecificServer for a given client. This method must be called once per client. """ with self.server_lock: assert ( self.servers.get(client_id) is None ), f"Server already created for Client: {client_id}" host = "127.0.0.1" port = self._get_unused_port( socket.AF_INET6 if is_ipv6(host) else socket.AF_INET ) server = SpecificServer( port=port, process_handle_future=futures.Future(), channel=init_grpc_channel( build_address(host, port), options=GRPC_OPTIONS ), ) self.servers[client_id] = server return server def _create_runtime_env( self, serialized_runtime_env: str, runtime_env_config: str, specific_server: SpecificServer, ): """Increase the runtime_env reference by sending an RPC to the agent. Includes retry logic to handle the case when the agent is temporarily unreachable (e.g., hasn't been started up yet). """ logger.info( f"Increasing runtime env reference for " f"ray_client_server_{specific_server.port}." f"Serialized runtime env is {serialized_runtime_env}." ) assert ( len(self._runtime_env_agent_address) > 0 ), "runtime_env_agent_address not set" create_env_request = runtime_env_agent_pb2.GetOrCreateRuntimeEnvRequest( serialized_runtime_env=serialized_runtime_env, runtime_env_config=runtime_env_config, job_id=f"ray_client_server_{specific_server.port}".encode("utf-8"), source_process="client_server", ) retries = 0 max_retries = 5 wait_time_s = 0.5 last_exception = None while retries <= max_retries: try: url = urllib.parse.urljoin( self._runtime_env_agent_address, "/get_or_create_runtime_env" ) data = create_env_request.SerializeToString() headers = {"Content-Type": "application/octet-stream"} headers.update(**get_auth_headers_if_auth_enabled(headers)) req = urllib.request.Request( url, data=data, method="POST", headers=headers ) response = urllib.request.urlopen(req, timeout=None) response_data = response.read() r = runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply() r.ParseFromString(response_data) if r.status == runtime_env_agent_pb2.AgentRpcStatus.AGENT_RPC_STATUS_OK: return r.serialized_runtime_env_context elif ( r.status == runtime_env_agent_pb2.AgentRpcStatus.AGENT_RPC_STATUS_FAILED ): raise RuntimeError( "Failed to create runtime_env for Ray client " f"server, it is caused by:\n{r.error_message}" ) else: assert False, f"Unknown status: {r.status}." except urllib.error.HTTPError as e: body = "" try: body = e.read().decode("utf-8", "ignore") except Exception: body = e.reason if hasattr(e, "reason") else str(e) formatted_error = format_authentication_http_error(e.code, body or "") if formatted_error: raise AuthenticationError(formatted_error) from e # Treat non-auth HTTP errors like URLError (retry with backoff) last_exception = e logger.warning( f"GetOrCreateRuntimeEnv request failed with HTTP {e.code}: {body or e}. " f"Retrying after {wait_time_s}s. " f"{max_retries-retries} retries remaining." ) except urllib.error.URLError as e: last_exception = e logger.warning( f"GetOrCreateRuntimeEnv request failed: {e}. " f"Retrying after {wait_time_s}s. " f"{max_retries-retries} retries remaining." ) # Exponential backoff. time.sleep(wait_time_s) retries += 1 wait_time_s *= 2 raise TimeoutError( f"GetOrCreateRuntimeEnv request failed after {max_retries} attempts." f" Last exception: {last_exception}" ) def start_specific_server(self, client_id: str, job_config: JobConfig) -> bool: """ Start up a RayClient Server for an incoming client to communicate with. Returns whether creation was successful. """ specific_server = self._get_server_for_client(client_id) assert specific_server, f"Server has not been created for: {client_id}" output, error = self.node.get_log_file_handles( f"ray_client_server_{specific_server.port}", unique=True ) serialized_runtime_env = job_config._get_serialized_runtime_env() runtime_env_config = job_config._get_proto_runtime_env_config() if not serialized_runtime_env or serialized_runtime_env == "{}": # TODO(edoakes): can we just remove this case and always send it # to the agent? serialized_runtime_env_context = RuntimeEnvContext().serialize() else: serialized_runtime_env_context = self._create_runtime_env( serialized_runtime_env=serialized_runtime_env, runtime_env_config=runtime_env_config, specific_server=specific_server, ) proc = start_ray_client_server( self.address, "127.0.0.1", specific_server.port, stdout_file=output, stderr_file=error, fate_share=self.fate_share, server_type="specific-server", serialized_runtime_env_context=serialized_runtime_env_context, redis_username=self._redis_username, redis_password=self._redis_password, ) # Wait for the process being run transitions from the shim process # to the actual RayClient Server. pid = proc.process.pid if sys.platform != "win32": psutil_proc = psutil.Process(pid) else: psutil_proc = None # Don't use `psutil` on Win32 while psutil_proc is not None: if proc.process.poll() is not None: logger.error(f"SpecificServer startup failed for client: {client_id}") break cmd = psutil_proc.cmdline() if _match_running_client_server(cmd): break logger.debug("Waiting for Process to reach the actual client server.") time.sleep(0.5) specific_server.set_result(proc) logger.info( f"SpecificServer started on port: {specific_server.port} " f"with PID: {pid} for client: {client_id}" ) return proc.process.poll() is None def _get_server_for_client(self, client_id: str) -> Optional[SpecificServer]: with self.server_lock: client = self.servers.get(client_id) if client is None: logger.error(f"Unable to find channel for client: {client_id}") return client def has_channel(self, client_id: str) -> bool: server = self._get_server_for_client(client_id) if server is None: return False return server.is_ready() def get_channel( self, client_id: str, ) -> Optional["grpc._channel.Channel"]: """ Find the gRPC Channel for the given client_id. This will block until the server process has started. """ server = self._get_server_for_client(client_id) if server is None: return None # Wait for the SpecificServer to become ready. server.wait_ready() try: grpc.channel_ready_future(server.channel).result( timeout=CHECK_CHANNEL_TIMEOUT_S ) return server.channel except grpc.FutureTimeoutError: logger.exception(f"Timeout waiting for channel for {client_id}") return None def _check_processes(self): """ Keeps the internal servers dictionary up-to-date with running servers. """ while True: with self.server_lock: for client_id, specific_server in list(self.servers.items()): if specific_server.poll() is not None: logger.info( f"Specific server {client_id} is no longer running" f", freeing its port {specific_server.port}" ) del self.servers[client_id] # Port is available to use again. self._free_ports.append(specific_server.port) time.sleep(CHECK_PROCESS_INTERVAL_S) def _cleanup(self) -> None: """ Forcibly kill all spawned RayClient Servers. This ensures cleanup for platforms where fate sharing is not supported. """ for server in self.servers.values(): server.kill() class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer): def __init__(self, ray_connect_handler: Callable, proxy_manager: ProxyManager): self.proxy_manager = proxy_manager self.ray_connect_handler = ray_connect_handler def _call_inner_function( self, request, context, method: str ) -> Optional[ray_client_pb2_grpc.RayletDriverStub]: client_id = _get_client_id_from_context(context) chan = self.proxy_manager.get_channel(client_id) if not chan: logger.error(f"Channel for Client: {client_id} not found!") context.set_code(grpc.StatusCode.NOT_FOUND) return None stub = ray_client_pb2_grpc.RayletDriverStub(chan) try: metadata = [("client_id", client_id)] if context: metadata = context.invocation_metadata() return getattr(stub, method)(request, metadata=metadata) except Exception as e: # Error while proxying -- propagate the error's context to user logger.exception(f"Proxying call to {method} failed!") _propagate_error_in_context(e, context) def _has_channel_for_request(self, context): client_id = _get_client_id_from_context(context) return self.proxy_manager.has_channel(client_id) def Init(self, request, context=None) -> ray_client_pb2.InitResponse: return self._call_inner_function(request, context, "Init") def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse: """Proxies internal_kv.put. This is used by the working_dir code to upload to the GCS before ray.init is called. In that case (if we don't have a server yet) we directly make the internal KV call from the proxier. Otherwise, we proxy the call to the downstream server as usual. """ if self._has_channel_for_request(context): return self._call_inner_function(request, context, "KVPut") with disable_client_hook(): already_exists = ray.experimental.internal_kv._internal_kv_put( request.key, request.value, overwrite=request.overwrite ) return ray_client_pb2.KVPutResponse(already_exists=already_exists) def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse: """Proxies internal_kv.get. This is used by the working_dir code to upload to the GCS before ray.init is called. In that case (if we don't have a server yet) we directly make the internal KV call from the proxier. Otherwise, we proxy the call to the downstream server as usual. """ if self._has_channel_for_request(context): return self._call_inner_function(request, context, "KVGet") with disable_client_hook(): value = ray.experimental.internal_kv._internal_kv_get(request.key) return ray_client_pb2.KVGetResponse(value=value) def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse: """Proxies internal_kv.delete. This is used by the working_dir code to upload to the GCS before ray.init is called. In that case (if we don't have a server yet) we directly make the internal KV call from the proxier. Otherwise, we proxy the call to the downstream server as usual. """ if self._has_channel_for_request(context): return self._call_inner_function(request, context, "KVDel") with disable_client_hook(): ray.experimental.internal_kv._internal_kv_del(request.key) return ray_client_pb2.KVDelResponse() def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse: """Proxies internal_kv.list. This is used by the working_dir code to upload to the GCS before ray.init is called. In that case (if we don't have a server yet) we directly make the internal KV call from the proxier. Otherwise, we proxy the call to the downstream server as usual. """ if self._has_channel_for_request(context): return self._call_inner_function(request, context, "KVList") with disable_client_hook(): keys = ray.experimental.internal_kv._internal_kv_list(request.prefix) return ray_client_pb2.KVListResponse(keys=keys) def KVExists(self, request, context=None) -> ray_client_pb2.KVExistsResponse: """Proxies internal_kv.exists. This is used by the working_dir code to upload to the GCS before ray.init is called. In that case (if we don't have a server yet) we directly make the internal KV call from the proxier. Otherwise, we proxy the call to the downstream server as usual. """ if self._has_channel_for_request(context): return self._call_inner_function(request, context, "KVExists") with disable_client_hook(): exists = ray.experimental.internal_kv._internal_kv_exists(request.key) return ray_client_pb2.KVExistsResponse(exists=exists) def PinRuntimeEnvURI( self, request, context=None ) -> ray_client_pb2.ClientPinRuntimeEnvURIResponse: """Proxies internal_kv.pin_runtime_env_uri. This is used by the working_dir code to upload to the GCS before ray.init is called. In that case (if we don't have a server yet) we directly make the internal KV call from the proxier. Otherwise, we proxy the call to the downstream server as usual. """ if self._has_channel_for_request(context): return self._call_inner_function(request, context, "PinRuntimeEnvURI") with disable_client_hook(): ray.experimental.internal_kv._pin_runtime_env_uri( request.uri, expiration_s=request.expiration_s ) return ray_client_pb2.ClientPinRuntimeEnvURIResponse() def ListNamedActors( self, request, context=None ) -> ray_client_pb2.ClientListNamedActorsResponse: return self._call_inner_function(request, context, "ListNamedActors") def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse: # NOTE: We need to respond to the PING request here to allow the client # to continue with connecting. if request.type == ray_client_pb2.ClusterInfoType.PING: resp = ray_client_pb2.ClusterInfoResponse(json=json.dumps({})) return resp return self._call_inner_function(request, context, "ClusterInfo") def Terminate(self, req, context=None): return self._call_inner_function(req, context, "Terminate") def GetObject(self, request, context=None): try: yield from self._call_inner_function(request, context, "GetObject") except Exception as e: # Error while iterating over response from GetObject stream logger.exception("Proxying call to GetObject failed!") _propagate_error_in_context(e, context) def PutObject( self, request: ray_client_pb2.PutRequest, context=None ) -> ray_client_pb2.PutResponse: return self._call_inner_function(request, context, "PutObject") def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse: return self._call_inner_function(request, context, "WaitObject") def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket: return self._call_inner_function(task, context, "Schedule") def ray_client_server_env_prep(job_config: JobConfig) -> JobConfig: return job_config def prepare_runtime_init_req( init_request: ray_client_pb2.DataRequest, ) -> Tuple[ray_client_pb2.DataRequest, JobConfig]: """ Extract JobConfig and possibly mutate InitRequest before it is passed to the specific RayClient Server. """ init_type = init_request.WhichOneof("type") assert init_type == "init", ( "Received initial message of type " f"{init_type}, not 'init'." ) req = init_request.init job_config = JobConfig() if req.job_config: job_config = pickle.loads(req.job_config) new_job_config = ray_client_server_env_prep(job_config) modified_init_req = ray_client_pb2.InitRequest( job_config=pickle.dumps(new_job_config), ray_init_kwargs=init_request.init.ray_init_kwargs, reconnect_grace_period=init_request.init.reconnect_grace_period, ) init_request.init.CopyFrom(modified_init_req) return (init_request, new_job_config) class RequestIteratorProxy: def __init__(self, request_iterator): self.request_iterator = request_iterator def __iter__(self): return self def __next__(self): try: return next(self.request_iterator) except grpc.RpcError as e: # To stop proxying already CANCLLED request stream gracefully, # we only translate the exact grpc.RpcError to StopIteration, # not its subsclasses. ex: grpc._Rendezvous # https://github.com/grpc/grpc/blob/v1.43.0/src/python/grpcio/grpc/_server.py#L353-L354 # This fixes the https://github.com/ray-project/ray/issues/23865 if type(e) is not grpc.RpcError: raise e # re-raise other grpc exceptions logger.exception( "Stop iterating cancelled request stream with the following exception:" ) raise StopIteration class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer): def __init__(self, proxy_manager: ProxyManager): self.num_clients = 0 # dictionary mapping client_id's to the last time they connected self.clients_last_seen: Dict[str, float] = {} self.reconnect_grace_periods: Dict[str, float] = {} self.clients_lock = Lock() self.proxy_manager = proxy_manager self.stopped = Event() def modify_connection_info_resp( self, init_resp: ray_client_pb2.DataResponse ) -> ray_client_pb2.DataResponse: """ Modify the `num_clients` returned the ConnectionInfoResponse because individual SpecificServers only have **one** client. """ init_type = init_resp.WhichOneof("type") if init_type != "connection_info": return init_resp modified_resp = ray_client_pb2.DataResponse() modified_resp.CopyFrom(init_resp) with self.clients_lock: modified_resp.connection_info.num_clients = self.num_clients return modified_resp def Datapath(self, request_iterator, context): request_iterator = RequestIteratorProxy(request_iterator) cleanup_requested = False start_time = time.time() client_id = _get_client_id_from_context(context) if client_id == "": return reconnecting = _get_reconnecting_from_context(context) if reconnecting: with self.clients_lock: if client_id not in self.clients_last_seen: # Client took too long to reconnect, session has already # been cleaned up context.set_code(grpc.StatusCode.NOT_FOUND) context.set_details( "Attempted to reconnect a session that has already " "been cleaned up" ) return self.clients_last_seen[client_id] = start_time server = self.proxy_manager._get_server_for_client(client_id) channel = self.proxy_manager.get_channel(client_id) # iterator doesn't need modification on reconnect new_iter = request_iterator else: # Create Placeholder *before* reading the first request. server = self.proxy_manager.create_specific_server(client_id) with self.clients_lock: self.clients_last_seen[client_id] = start_time self.num_clients += 1 try: if not reconnecting: logger.info(f"New data connection from client {client_id}: ") init_req = next(request_iterator) with self.clients_lock: self.reconnect_grace_periods[ client_id ] = init_req.init.reconnect_grace_period try: modified_init_req, job_config = prepare_runtime_init_req(init_req) if not self.proxy_manager.start_specific_server( client_id, job_config ): logger.error( f"Server startup failed for client: {client_id}, " f"using JobConfig: {job_config}!" ) raise RuntimeError( "Starting Ray client server failed. See " f"ray_client_server_{server.port}.err for " "detailed logs." ) channel = self.proxy_manager.get_channel(client_id) if channel is None: logger.error(f"Channel not found for {client_id}") raise RuntimeError( "Proxy failed to Connect to backend! Check " "`ray_client_server.err` and " f"`ray_client_server_{server.port}.err` on the " "head node of the cluster for the relevant logs. " "By default these are located at " "/tmp/ray/session_latest/logs." ) except Exception: init_resp = ray_client_pb2.DataResponse( init=ray_client_pb2.InitResponse( ok=False, msg=traceback.format_exc() ) ) init_resp.req_id = init_req.req_id yield init_resp return None new_iter = chain([modified_init_req], request_iterator) stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel) metadata = [("client_id", client_id), ("reconnecting", str(reconnecting))] resp_stream = stub.Datapath(new_iter, metadata=metadata) for resp in resp_stream: resp_type = resp.WhichOneof("type") if resp_type == "connection_cleanup": # Specific server is skipping cleanup, proxier should too cleanup_requested = True yield self.modify_connection_info_resp(resp) except Exception as e: logger.exception("Proxying Datapath failed!") # Propogate error through context recoverable = _propagate_error_in_context(e, context) if not recoverable: # Client shouldn't attempt to recover, clean up connection cleanup_requested = True finally: cleanup_delay = self.reconnect_grace_periods.get(client_id) if not cleanup_requested and cleanup_delay is not None: # Delay cleanup, since client may attempt a reconnect # Wait on stopped event in case the server closes and we # can clean up earlier self.stopped.wait(timeout=cleanup_delay) with self.clients_lock: if client_id not in self.clients_last_seen: logger.info(f"{client_id} not found. Skipping clean up.") # Connection has already been cleaned up return last_seen = self.clients_last_seen[client_id] logger.info( f"{client_id} last started stream at {last_seen}. Current " f"stream started at {start_time}." ) if last_seen > start_time: logger.info("Client reconnected. Skipping cleanup.") # Client has reconnected, don't clean up return logger.debug(f"Client detached: {client_id}") self.num_clients -= 1 del self.clients_last_seen[client_id] if client_id in self.reconnect_grace_periods: del self.reconnect_grace_periods[client_id] server.set_result(None) class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer): def __init__(self, proxy_manager: ProxyManager): super().__init__() self.proxy_manager = proxy_manager def Logstream(self, request_iterator, context): request_iterator = RequestIteratorProxy(request_iterator) client_id = _get_client_id_from_context(context) if client_id == "": return logger.debug(f"New logstream connection from client {client_id}: ") channel = None # We need to retry a few times because the LogClient *may* connect # Before the DataClient has finished connecting. for i in range(LOGSTREAM_RETRIES): channel = self.proxy_manager.get_channel(client_id) if channel is not None: break logger.warning(f"Retrying Logstream connection. {i+1} attempts failed.") time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC) if channel is None: context.set_code(grpc.StatusCode.NOT_FOUND) context.set_details( "Logstream proxy failed to connect. Channel for client " f"{client_id} not found." ) return None stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel) resp_stream = stub.Logstream( request_iterator, metadata=[("client_id", client_id)] ) try: for resp in resp_stream: yield resp except Exception: logger.exception("Proxying Logstream failed!") def serve_proxier( host: str, port: int, gcs_address: Optional[str], *, redis_username: Optional[str] = None, redis_password: Optional[str] = None, session_dir: Optional[str] = None, runtime_env_agent_address: Optional[str] = None, node_id: Optional[str] = None, ): # Initialize internal KV to be used to upload and download working_dir # before calling ray.init within the RayletServicers. # NOTE(edoakes): redis_address and redis_password should only be None in # tests. if gcs_address is not None: gcs_cli = GcsClient(address=gcs_address) ray.experimental.internal_kv._initialize_internal_kv(gcs_cli) from ray._private.grpc_utils import create_grpc_server_with_interceptors server = create_grpc_server_with_interceptors( max_workers=CLIENT_SERVER_MAX_THREADS, thread_name_prefix="ray_client_proxier", options=GRPC_OPTIONS, asynchronous=False, ) proxy_manager = ProxyManager( gcs_address, session_dir=session_dir, redis_username=redis_username, redis_password=redis_password, runtime_env_agent_address=runtime_env_agent_address, node_id=node_id, ) task_servicer = RayletServicerProxy(None, proxy_manager) data_servicer = DataServicerProxy(proxy_manager) logs_servicer = LogstreamServicerProxy(proxy_manager) ray_client_pb2_grpc.add_RayletDriverServicer_to_server(task_servicer, server) ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(data_servicer, server) ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(logs_servicer, server) if not is_localhost(host): add_port_to_grpc_server(server, f"127.0.0.1:{port}") add_port_to_grpc_server(server, f"{host}:{port}") server.start() return ClientServerHandle( task_servicer=task_servicer, data_servicer=data_servicer, logs_servicer=logs_servicer, grpc_server=server, )