import asyncio import gc import json import logging import os import pickle import time from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple import grpc import starlette import starlette.routing from packaging import version from starlette.types import Receive import ray from ray._common.filters import CoreContextFilter from ray._common.utils import get_or_create_event_loop from ray.serve._private.common import ( DeploymentID, EndpointInfo, NodeId, ReplicaID, RequestMetadata, RequestProtocol, ) from ray.serve._private.constants import ( HEALTHY_MESSAGE, PROXY_MIN_DRAINING_PERIOD_S, RAY_SERVE_ENABLE_PROXY_GC_OPTIMIZATIONS, RAY_SERVE_PROXY_GC_THRESHOLD, RAY_SERVE_REQUEST_PATH_LOG_BUFFER_SIZE, REQUEST_LATENCY_BUCKETS_MS, SERVE_CONTROLLER_NAME, SERVE_HTTP_REQUEST_ID_HEADER, SERVE_LOG_COMPONENT, SERVE_LOG_COMPONENT_ID, SERVE_LOG_REQUEST_ID, SERVE_LOG_ROUTE, SERVE_LOGGER_NAME, SERVE_MULTIPLEXED_MODEL_ID, SERVE_NAMESPACE, ) from ray.serve._private.default_impl import get_proxy_handle from ray.serve._private.event_loop_monitoring import EventLoopMonitor from ray.serve._private.grpc_util import ( get_grpc_response_status, set_grpc_code_and_details, start_grpc_server, ) from ray.serve._private.http_util import ( MessageQueue, configure_http_middlewares, convert_object_to_asgi_messages, get_http_response_status, receive_http_body, send_http_response_on_exception, start_asgi_http_server, ) from ray.serve._private.logging_utils import ( access_log_msg, configure_component_logger, configure_component_memory_profiler, get_component_logger_file_path, ) from ray.serve._private.long_poll import LongPollClient, LongPollNamespace from ray.serve._private.proxy_request_response import ( ASGIProxyRequest, HandlerMetadata, ProxyRequest, ResponseGenerator, ResponseHandlerInfo, ResponseStatus, gRPCProxyRequest, ) from ray.serve._private.proxy_response_generator import ProxyResponseGenerator from ray.serve._private.proxy_router import ProxyRouter from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import ( asyncio_grpc_exception_handler, generate_request_id, get_head_node_id, is_grpc_enabled, ) from ray.serve.config import HTTPOptions, gRPCOptions from ray.serve.generated.serve_pb2 import HealthzResponse, ListApplicationsResponse from ray.serve.handle import DeploymentHandle from ray.serve.schema import EncodingType, LoggingConfig from ray.util import metrics logger = logging.getLogger(SERVE_LOGGER_NAME) SOCKET_REUSE_PORT_ENABLED = ( os.environ.get("SERVE_SOCKET_REUSE_PORT_ENABLED", "1") == "1" ) if os.environ.get("SERVE_REQUEST_PROCESSING_TIMEOUT_S") is not None: logger.warning( "The `SERVE_REQUEST_PROCESSING_TIMEOUT_S` environment variable has " "been deprecated. Please set `request_timeout_s` in your Serve config's " "`http_options` or `grpc_options` field instead. `SERVE_REQUEST_PROCESSING_TIMEOUT_S` will be " "ignored in future versions. See: https://docs.ray.io/en/releases-2.5.1/serve/a" "pi/doc/ray.serve.schema.HTTPOptionsSchema.html#ray.serve.schema.HTTPOptionsSch" "ema.request_timeout_s and https://docs.ray.io/en/latest/serve/api/" "doc/ray.serve.config.gRPCOptions.request_timeout_s.html#" "ray.serve.config.gRPCOptions.request_timeout_s" ) INITIAL_BACKOFF_PERIOD_SEC = 0.05 MAX_BACKOFF_PERIOD_SEC = 5 DRAINING_MESSAGE = "This node is being drained." class GenericProxy(ABC): """This class is served as the base class for different types of proxies. It contains all the common setup and methods required for running a proxy. The proxy subclass need to implement the following methods: - `protocol()` - `not_found_response()` - `routes_response()` - `health_response()` - `setup_request_context_and_handle()` - `send_request_to_replica()` """ def __init__( self, node_id: NodeId, node_ip_address: str, is_head: bool, proxy_router: ProxyRouter, request_timeout_s: Optional[float] = None, access_log_context: Dict[str, Any] = None, ): self.request_timeout_s = request_timeout_s if self.request_timeout_s is not None and self.request_timeout_s < 0: self.request_timeout_s = None self._node_id = node_id self._is_head = is_head self.proxy_router = proxy_router self.request_counter = metrics.Counter( f"serve_num_{self.protocol.lower()}_requests", description=f"The number of {self.protocol} requests processed.", tag_keys=("route", "method", "application", "status_code"), ) self.request_error_counter = metrics.Counter( f"serve_num_{self.protocol.lower()}_error_requests", description=f"The number of errored {self.protocol} responses.", tag_keys=( "route", "error_code", "method", "application", ), ) self.deployment_request_error_counter = metrics.Counter( f"serve_num_deployment_{self.protocol.lower()}_error_requests", description=( f"The number of errored {self.protocol} " "responses returned by each deployment." ), tag_keys=( "deployment", "error_code", "method", "route", "application", ), ) # log REQUEST_LATENCY_BUCKET_MS logger.debug(f"REQUEST_LATENCY_BUCKET_MS: {REQUEST_LATENCY_BUCKETS_MS}") self.processing_latency_tracker = metrics.Histogram( f"serve_{self.protocol.lower()}_request_latency_ms", description=( f"The end-to-end latency of {self.protocol} requests " f"(measured from the Serve {self.protocol} proxy)." ), boundaries=REQUEST_LATENCY_BUCKETS_MS, tag_keys=( "method", "route", "application", "status_code", ), ) self.num_ongoing_requests_gauge = metrics.Gauge( name=f"serve_num_ongoing_{self.protocol.lower()}_requests", description=f"The number of ongoing requests in this {self.protocol} " "proxy.", tag_keys=("node_id", "node_ip_address"), ).set_default_tags( { "node_id": node_id, "node_ip_address": node_ip_address, } ) # `self._ongoing_requests` is used to count the number of ongoing requests self._ongoing_requests = 0 # The time when the node starts to drain. # The node is not draining if it's None. self._draining_start_time: Optional[float] = None self._access_log_context = access_log_context or {} getattr(ServeUsageTag, f"{self.protocol.upper()}_PROXY_USED").record("1") @property @abstractmethod def protocol(self) -> RequestProtocol: """Protocol used in the proxy. Each proxy needs to implement its own logic for setting up the protocol. """ raise NotImplementedError def _is_draining(self) -> bool: """Whether is proxy actor is in the draining status or not.""" return self._draining_start_time is not None def is_drained(self): """Check whether the proxy actor is drained or not. A proxy actor is drained if it has no ongoing requests AND it has been draining for more than `PROXY_MIN_DRAINING_PERIOD_S` seconds. """ if not self._is_draining(): return False return (not self._ongoing_requests) and ( (time.time() - self._draining_start_time) > PROXY_MIN_DRAINING_PERIOD_S ) def update_draining(self, draining: bool): """Update the draining status of the proxy. This is called by the proxy state manager to drain or un-drain the proxy actor. """ if draining and (not self._is_draining()): logger.info( f"Start to drain the proxy actor on node {self._node_id}.", extra={"log_to_stderr": False}, ) self._draining_start_time = time.time() if (not draining) and self._is_draining(): logger.info( f"Stop draining the proxy actor on node {self._node_id}.", extra={"log_to_stderr": False}, ) self._draining_start_time = None @abstractmethod async def not_found_response( self, proxy_request: ProxyRequest ) -> ResponseGenerator: raise NotImplementedError @abstractmethod async def routes_response( self, *, healthy: bool, message: str ) -> ResponseGenerator: raise NotImplementedError @abstractmethod async def health_response( self, *, healthy: bool, message: str ) -> ResponseGenerator: raise NotImplementedError def _ongoing_requests_start(self): """Ongoing requests start. The current autoscale logic can downscale nodes with ongoing requests if the node doesn't have replicas and has no primary copies of objects in the object store. The counter and the dummy object reference will help to keep the node alive while draining requests, so they are not dropped unintentionally. """ self._ongoing_requests += 1 self.num_ongoing_requests_gauge.set(self._ongoing_requests) def _ongoing_requests_end(self): """Ongoing requests end. Decrement the ongoing request counter and drop the dummy object reference signaling that the node can be downscaled safely. """ self._ongoing_requests -= 1 self.num_ongoing_requests_gauge.set(self._ongoing_requests) def _get_health_or_routes_reponse( self, proxy_request: ProxyRequest ) -> ResponseHandlerInfo: """Get the response handler for system health and route endpoints. If the proxy is draining or has not yet received a route table update from the controller, both will return a non-OK status. """ router_ready_for_traffic, router_msg = self.proxy_router.ready_for_traffic( self._is_head ) if self._is_draining(): healthy = False message = DRAINING_MESSAGE elif not router_ready_for_traffic: healthy = False message = router_msg else: healthy = True message = HEALTHY_MESSAGE if proxy_request.is_health_request: response_generator = self.health_response(healthy=healthy, message=message) else: assert proxy_request.is_route_request response_generator = self.routes_response(healthy=healthy, message=message) return ResponseHandlerInfo( response_generator=response_generator, metadata=HandlerMetadata( route=proxy_request.route_path, ), should_record_access_log=False, should_increment_ongoing_requests=False, ) def _get_response_handler_info( self, proxy_request: ProxyRequest ) -> ResponseHandlerInfo: if proxy_request.is_health_request or proxy_request.is_route_request: return self._get_health_or_routes_reponse(proxy_request) matched_route = None if self.protocol == RequestProtocol.HTTP: matched_route = self.proxy_router.match_route(proxy_request.route_path) elif self.protocol == RequestProtocol.GRPC: matched_route = self.proxy_router.get_handle_for_endpoint( proxy_request.route_path ) if matched_route is None: return ResponseHandlerInfo( response_generator=self.not_found_response(proxy_request), metadata=HandlerMetadata( # Don't include the invalid route prefix because it can blow up our # metrics' cardinality. # See: https://github.com/ray-project/ray/issues/47999 route="", ), should_record_access_log=True, should_increment_ongoing_requests=False, ) else: route_prefix, handle, app_is_cross_language = matched_route # Modify the path and root path so that reverse lookups and redirection # work as expected. We do this here instead of in replicas so it can be # changed without restarting the replicas. route_path = proxy_request.route_path if route_prefix != "/" and self.protocol == RequestProtocol.HTTP: assert not route_prefix.endswith("/") proxy_request.set_root_path(proxy_request.root_path + route_prefix) # NOTE(edoakes): starlette<0.33.0 expected the ASGI 'root_prefix' # to be stripped from the 'path', which wasn't technically following # the standard. See https://github.com/encode/starlette/pull/2352. if version.parse(starlette.__version__) < version.parse("0.33.0"): proxy_request.set_path(route_path.replace(route_prefix, "", 1)) # NOTE(abrar): we try to match to a specific route pattern (e.g., /api/{user_id}) # for logs & metrics when available. If no pattern matches, we fall back to the # route_prefix to avoid high cardinality. # See: https://github.com/ray-project/ray/issues/47999 and # https://github.com/ray-project/ray/issues/52212 if self.protocol == RequestProtocol.HTTP: logs_and_metrics_route = self.proxy_router.match_route_pattern( route_prefix, proxy_request.scope ) else: logs_and_metrics_route = handle.deployment_id.app_name internal_request_id = generate_request_id() handle, request_id = self.setup_request_context_and_handle( app_name=handle.deployment_id.app_name, handle=handle, route=logs_and_metrics_route, proxy_request=proxy_request, internal_request_id=internal_request_id, ) response_generator = self.send_request_to_replica( request_id=request_id, internal_request_id=internal_request_id, handle=handle, proxy_request=proxy_request, app_is_cross_language=app_is_cross_language, ) return ResponseHandlerInfo( response_generator=response_generator, metadata=HandlerMetadata( application_name=handle.deployment_id.app_name, deployment_name=handle.deployment_id.name, route=logs_and_metrics_route, ), should_record_access_log=True, should_increment_ongoing_requests=True, ) async def proxy_request(self, proxy_request: ProxyRequest) -> ResponseGenerator: """Wrapper for proxy request. This method is served as common entry point by the proxy. It handles the routing, including routes and health checks, ongoing request counter, and metrics. """ assert proxy_request.request_type in {"http", "websocket", "grpc"} response_handler_info = self._get_response_handler_info(proxy_request) start_time = time.time() if response_handler_info.should_increment_ongoing_requests: self._ongoing_requests_start() try: # The final message yielded must always be the `ResponseStatus`. status: Optional[ResponseStatus] = None async for message in response_handler_info.response_generator: if isinstance(message, ResponseStatus): status = message yield message assert status is not None and isinstance(status, ResponseStatus) finally: # If anything during the request failed, we still want to ensure the ongoing # request counter is decremented. if response_handler_info.should_increment_ongoing_requests: self._ongoing_requests_end() latency_ms = (time.time() - start_time) * 1000.0 if response_handler_info.should_record_access_log: request_context = ray.serve.context._get_serve_request_context() self._access_log_context[SERVE_LOG_ROUTE] = request_context.route self._access_log_context[SERVE_LOG_REQUEST_ID] = request_context.request_id logger.info( access_log_msg( method=proxy_request.method, route=request_context.route, status=str(status.code), latency_ms=latency_ms, ), extra=self._access_log_context, ) self.request_counter.inc( tags={ "route": response_handler_info.metadata.route, "method": proxy_request.method, "application": response_handler_info.metadata.application_name, "status_code": str(status.code), } ) self.processing_latency_tracker.observe( latency_ms, tags={ "route": response_handler_info.metadata.route, "method": proxy_request.method, "application": response_handler_info.metadata.application_name, "status_code": str(status.code), }, ) if status.is_error: self.request_error_counter.inc( tags={ "route": response_handler_info.metadata.route, "method": proxy_request.method, "application": response_handler_info.metadata.application_name, "error_code": str(status.code), } ) self.deployment_request_error_counter.inc( tags={ "route": response_handler_info.metadata.route, "method": proxy_request.method, "application": response_handler_info.metadata.application_name, "error_code": str(status.code), "deployment": response_handler_info.metadata.deployment_name, } ) @abstractmethod def setup_request_context_and_handle( self, app_name: str, handle: DeploymentHandle, route: str, proxy_request: ProxyRequest, internal_request_id: str, ) -> Tuple[DeploymentHandle, str]: """Setup the request context and handle for the request. Each proxy needs to implement its own logic for setting up the request context and handle. """ raise NotImplementedError @abstractmethod async def send_request_to_replica( self, request_id: str, internal_request_id: str, handle: DeploymentHandle, proxy_request: ProxyRequest, app_is_cross_language: bool = False, ) -> ResponseGenerator: """Send the request to the replica and handle streaming response. Each proxy needs to implement its own logic for sending the request and handling the streaming response. """ raise NotImplementedError class gRPCProxy(GenericProxy): """This class is meant to be instantiated and run by an gRPC server. This is the servicer class for the gRPC server. It implements `unary_unary` as the entry point for unary gRPC request and `unary_stream` as the entry point for streaming gRPC request. """ @property def protocol(self) -> RequestProtocol: return RequestProtocol.GRPC async def not_found_response( self, proxy_request: ProxyRequest ) -> ResponseGenerator: if not proxy_request.app_name: application_message = "Application metadata not set." else: application_message = f"Application '{proxy_request.app_name}' not found." not_found_message = ( f"{application_message} Ping " "/ray.serve.RayServeAPIService/ListApplications for available applications." ) yield ResponseStatus( code=grpc.StatusCode.NOT_FOUND, message=not_found_message, is_error=True, ) async def routes_response( self, *, healthy: bool, message: str ) -> ResponseGenerator: yield ListApplicationsResponse( application_names=[ endpoint.app_name for endpoint in self.proxy_router.endpoints ], ).SerializeToString() yield ResponseStatus( code=grpc.StatusCode.OK if healthy else grpc.StatusCode.UNAVAILABLE, message=message, is_error=not healthy, ) async def health_response(self, *, healthy: bool, message) -> ResponseGenerator: yield HealthzResponse(message=message).SerializeToString() yield ResponseStatus( code=grpc.StatusCode.OK if healthy else grpc.StatusCode.UNAVAILABLE, message=message, is_error=not healthy, ) def service_handler_factory(self, service_method: str, stream: bool) -> Callable: async def unary_unary( request_proto: Any, context: grpc._cython.cygrpc._ServicerContext ) -> bytes: """Entry point of the gRPC proxy unary request. This method is called by the gRPC server when a unary request is received. It wraps the request in a ProxyRequest object and calls proxy_request. The return value is serialized user defined protobuf bytes. """ proxy_request = gRPCProxyRequest( request_proto=request_proto, context=context, service_method=service_method, stream=False, ) status = None response = None async for message in self.proxy_request(proxy_request=proxy_request): if isinstance(message, ResponseStatus): status = message else: response = message set_grpc_code_and_details(context, status) return response async def unary_stream( request_proto: Any, context: grpc._cython.cygrpc._ServicerContext ) -> Generator[bytes, None, None]: """Entry point of the gRPC proxy streaming request. This method is called by the gRPC server when a streaming request is received. It wraps the request in a ProxyRequest object and calls proxy_request. The return value is a generator of serialized user defined protobuf bytes. """ proxy_request = gRPCProxyRequest( request_proto=request_proto, context=context, service_method=service_method, stream=True, ) status = None async for message in self.proxy_request(proxy_request=proxy_request): if isinstance(message, ResponseStatus): status = message else: yield message set_grpc_code_and_details(context, status) return unary_stream if stream else unary_unary def setup_request_context_and_handle( self, app_name: str, handle: DeploymentHandle, route: str, proxy_request: ProxyRequest, internal_request_id: str, ) -> Tuple[DeploymentHandle, str]: """Setup request context and handle for the request. Unpack gRPC request metadata and extract info to set up request context and handle. """ multiplexed_model_id = proxy_request.multiplexed_model_id request_id = proxy_request.request_id if not request_id: request_id = generate_request_id() proxy_request.request_id = request_id handle = handle.options( stream=proxy_request.stream, multiplexed_model_id=multiplexed_model_id, method_name=proxy_request.method_name, ) request_context_info = { "route": route, "request_id": request_id, "_internal_request_id": internal_request_id, "app_name": app_name, "multiplexed_model_id": multiplexed_model_id, "grpc_context": proxy_request.ray_serve_grpc_context, } ray.serve.context._serve_request_context.set( ray.serve.context._RequestContext(**request_context_info) ) proxy_request.send_request_id(request_id=request_id) return handle, request_id async def send_request_to_replica( self, request_id: str, internal_request_id: str, handle: DeploymentHandle, proxy_request: ProxyRequest, app_is_cross_language: bool = False, ) -> ResponseGenerator: response_generator = ProxyResponseGenerator( handle.remote(proxy_request.serialized_replica_arg()), timeout_s=self.request_timeout_s, ) try: async for context, result in response_generator: context._set_on_grpc_context(proxy_request.context) yield result status = ResponseStatus(code=grpc.StatusCode.OK) except BaseException as e: status = get_grpc_response_status(e, self.request_timeout_s, request_id) # The status code should always be set. assert status is not None yield status class HTTPProxy(GenericProxy): """This class is meant to be instantiated and run by an ASGI HTTP server.""" def __init__( self, node_id: NodeId, node_ip_address: str, is_head: bool, proxy_router: ProxyRouter, self_actor_name: str, request_timeout_s: Optional[float] = None, access_log_context: Dict[str, Any] = None, ): super().__init__( node_id, node_ip_address, is_head, proxy_router, request_timeout_s=request_timeout_s, access_log_context=access_log_context, ) self.self_actor_name = self_actor_name self.asgi_receive_queues: Dict[str, MessageQueue] = dict() @property def protocol(self) -> RequestProtocol: return RequestProtocol.HTTP async def not_found_response( self, proxy_request: ProxyRequest ) -> ResponseGenerator: status_code = 404 for message in convert_object_to_asgi_messages( f"Path '{proxy_request.path}' not found. " "Ping http://.../-/routes for available routes.", status_code=status_code, ): yield message yield ResponseStatus(code=status_code, is_error=True) async def routes_response( self, *, healthy: bool, message: str ) -> ResponseGenerator: status_code = 200 if healthy else 503 if healthy: response = dict() for endpoint, info in self.proxy_router.endpoints.items(): # For 2.x deployments, return {route -> app name} if endpoint.app_name: response[info.route] = endpoint.app_name # Keep compatibility with 1.x deployments. else: response[info.route] = endpoint.name else: response = message for asgi_message in convert_object_to_asgi_messages( response, status_code=status_code, ): yield asgi_message yield ResponseStatus( code=status_code, message=message, is_error=not healthy, ) async def health_response( self, *, healthy: bool, message: str = "" ) -> ResponseGenerator: status_code = 200 if healthy else 503 for asgi_message in convert_object_to_asgi_messages( message, status_code=status_code, ): yield asgi_message yield ResponseStatus( code=status_code, is_error=not healthy, message=message, ) async def receive_asgi_messages( self, request_metadata: RequestMetadata ) -> ResponseGenerator: queue = self.asgi_receive_queues.get(request_metadata.internal_request_id, None) if queue is None: raise KeyError(f"Request ID {request_metadata.request_id} not found.") await queue.wait_for_message() return queue.get_messages_nowait() async def __call__(self, scope, receive, send): """Implements the ASGI protocol. See details at: https://asgi.readthedocs.io/en/latest/specs/index.html. """ proxy_request = ASGIProxyRequest(scope=scope, receive=receive, send=send) async for message in self.proxy_request(proxy_request): if not isinstance(message, ResponseStatus): await send(message) async def proxy_asgi_receive( self, receive: Receive, queue: MessageQueue ) -> Optional[int]: """Proxies the `receive` interface, placing its messages into the queue. Once a disconnect message is received, the call exits and `receive` is no longer called. For HTTP messages, `None` is always returned. For websocket messages, the disconnect code is returned if a disconnect code is received. """ try: while True: msg = await receive() await queue(msg) if msg["type"] == "http.disconnect": return None if msg["type"] == "websocket.disconnect": return msg["code"] finally: # Close the queue so any subsequent calls to fetch messages return # immediately: https://github.com/ray-project/ray/issues/38368. queue.close() def setup_request_context_and_handle( self, app_name: str, handle: DeploymentHandle, route: str, proxy_request: ProxyRequest, internal_request_id: str, ) -> Tuple[DeploymentHandle, str]: """Setup request context and handle for the request. Unpack HTTP request headers and extract info to set up request context and handle. """ request_context_info = { "route": route, "app_name": app_name, "_internal_request_id": internal_request_id, "is_http_request": True, } for key, value in proxy_request.headers: if key.decode() == SERVE_MULTIPLEXED_MODEL_ID: multiplexed_model_id = value.decode() handle = handle.options(multiplexed_model_id=multiplexed_model_id) request_context_info["multiplexed_model_id"] = multiplexed_model_id if key.decode() == SERVE_HTTP_REQUEST_ID_HEADER: request_context_info["request_id"] = value.decode() ray.serve.context._serve_request_context.set( ray.serve.context._RequestContext(**request_context_info) ) return handle, request_context_info["request_id"] async def _format_handle_arg_for_java( self, proxy_request: ProxyRequest, ) -> bytes: """Convert an HTTP request to the Java-accepted format (single byte string).""" query_string = proxy_request.scope.get("query_string") http_body_bytes = await receive_http_body( proxy_request.scope, proxy_request.receive, proxy_request.send ) if query_string: arg = query_string.decode().split("=", 1)[1] else: arg = http_body_bytes.decode() return arg async def send_request_to_replica( self, request_id: str, internal_request_id: str, handle: DeploymentHandle, proxy_request: ProxyRequest, app_is_cross_language: bool = False, ) -> ResponseGenerator: """Send the request to the replica and yield its response messages. The yielded values will be ASGI messages until the final one, which will be the status code. """ if app_is_cross_language: handle_arg_bytes = await self._format_handle_arg_for_java(proxy_request) # Response is returned as raw bytes, convert it to ASGI messages. result_callback = convert_object_to_asgi_messages else: handle_arg_bytes = proxy_request.serialized_replica_arg( proxy_actor_name=self.self_actor_name, ) # Messages are returned as pickled dictionaries. result_callback = pickle.loads # Proxy the receive interface by placing the received messages on a queue. # The downstream replica must call back into `receive_asgi_messages` on this # actor to receive the messages. receive_queue = MessageQueue() self.asgi_receive_queues[internal_request_id] = receive_queue proxy_asgi_receive_task = get_or_create_event_loop().create_task( self.proxy_asgi_receive(proxy_request.receive, receive_queue) ) response_generator = ProxyResponseGenerator( handle.remote(handle_arg_bytes), timeout_s=self.request_timeout_s, disconnected_task=proxy_asgi_receive_task, result_callback=result_callback, ) status: Optional[ResponseStatus] = None response_started = False expecting_trailers = False try: async for asgi_message_batch in response_generator: # See the ASGI spec for message details: # https://asgi.readthedocs.io/en/latest/specs/www.html. for asgi_message in asgi_message_batch: if asgi_message["type"] == "http.response.start": # HTTP responses begin with exactly one # "http.response.start" message containing the "status" # field. Other response types (e.g., WebSockets) may not. status_code = str(asgi_message["status"]) status = ResponseStatus( code=status_code, is_error=status_code.startswith(("4", "5")), ) expecting_trailers = asgi_message.get("trailers", False) elif asgi_message["type"] == "websocket.accept": # Websocket code explicitly handles client disconnects, # so let the ASGI disconnect message propagate instead of # cancelling the handler. response_generator.stop_checking_for_disconnect() elif ( asgi_message["type"] == "http.response.body" and not asgi_message.get("more_body", False) and not expecting_trailers ): # If the body is completed and we aren't expecting trailers, the # response is done so we should stop listening for disconnects. response_generator.stop_checking_for_disconnect() elif asgi_message["type"] == "http.response.trailers": # If we are expecting trailers, the response is only done when # the trailers message has been sent. if not asgi_message.get("more_trailers", False): response_generator.stop_checking_for_disconnect() elif asgi_message["type"] in [ "websocket.close", "websocket.disconnect", ]: status_code = str(asgi_message["code"]) status = ResponseStatus( code=status_code, # All status codes are considered errors aside from: # 1000 (CLOSE_NORMAL), 1001 (CLOSE_GOING_AWAY). is_error=status_code not in ["1000", "1001"], ) response_generator.stop_checking_for_disconnect() yield asgi_message response_started = True except BaseException as e: status = get_http_response_status(e, self.request_timeout_s, request_id) for asgi_message in send_http_response_on_exception( status, response_started ): yield asgi_message finally: # For websocket connection, queue receive task is done when receiving # disconnect message from client. receive_client_disconnect_msg = False if not proxy_asgi_receive_task.done(): proxy_asgi_receive_task.cancel() else: receive_client_disconnect_msg = True # If the server disconnects, status_code can be set above from the # disconnect message. # If client disconnects, the disconnect code comes from # a client message via the receive interface. if status is None and proxy_request.request_type == "websocket": if receive_client_disconnect_msg: # The disconnect message is sent from the client. status = ResponseStatus( code=str(proxy_asgi_receive_task.result()), is_error=True, ) else: # The server disconnect without sending a disconnect message # (otherwise the `status` would be set). status = ResponseStatus( code="1000", # [Sihan] is there a better code for this? is_error=True, ) del self.asgi_receive_queues[internal_request_id] # The status code should always be set. assert status is not None yield status class ProxyActorInterface(ABC): """Abstract interface for proxy actors in Ray Serve. This interface defines the contract that all proxy actor implementations must follow, allowing for different proxy backends (Ray HTTP/gRPC proxies, HAProxy, etc.). """ def __init__( self, *, node_id: NodeId, node_ip_address: str, logging_config: LoggingConfig, log_buffer_size: int = RAY_SERVE_REQUEST_PATH_LOG_BUFFER_SIZE, ): """Initialize the proxy actor. Args: node_id: ID of the node this proxy is running on node_ip_address: IP address of the node logging_config: Logging configuration log_buffer_size: Size of the log buffer """ self._node_id = node_id self._node_ip_address = node_ip_address self._logging_config = logging_config self._log_buffer_size = log_buffer_size self._update_logging_config(logging_config) @abstractmethod async def ready(self) -> str: """Blocks until the proxy is ready to serve requests. Returns: JSON-serialized metadata containing proxy information (worker ID, log file path, etc.) """ pass @abstractmethod async def serving(self, wait_for_applications_running: bool = True) -> None: """Wait for the proxy to be ready to serve requests. Args: wait_for_applications_running: Whether to wait for the applications to be running Returns: None """ pass @abstractmethod async def update_draining( self, draining: bool, _after: Optional[Any] = None ) -> None: """Update the draining status of the proxy. Args: draining: Whether the proxy should be draining _after: Optional ObjectRef for scheduling dependency """ pass @abstractmethod async def is_drained(self, _after: Optional[Any] = None) -> bool: """Check whether the proxy is drained. Args: _after: Optional ObjectRef for scheduling dependency Returns: True if the proxy is drained, False otherwise """ pass @abstractmethod async def check_health(self) -> bool: """Check the health of the proxy. Returns: True if the proxy is healthy, False otherwise """ pass @abstractmethod def pong(self) -> str: """Respond to ping from replicas. Returns: A response string """ pass @abstractmethod async def receive_asgi_messages(self, request_metadata: RequestMetadata) -> bytes: """Handle ASGI messages for HTTP requests. Args: request_metadata: Metadata about the request Returns: Serialized ASGI messages """ pass # Testing and debugging methods @abstractmethod def _get_http_options(self) -> HTTPOptions: """Get HTTP options used by the proxy.""" pass @abstractmethod def _get_logging_config(self) -> Optional[str]: """Get the file path for the logger (for testing purposes).""" pass @abstractmethod def _dump_ingress_replicas_for_testing(self, route: str) -> Set: """Get replicas for a route (for testing).""" pass def _update_logging_config(self, logging_config: LoggingConfig): configure_component_logger( component_name="proxy", component_id=self._node_ip_address, logging_config=logging_config, buffer_size=self._log_buffer_size, ) @ray.remote(num_cpus=0) class ProxyActor(ProxyActorInterface): def __init__( self, http_options: HTTPOptions, grpc_options: gRPCOptions, *, node_id: NodeId, node_ip_address: str, logging_config: LoggingConfig, long_poll_client: Optional[LongPollClient] = None, ): # noqa: F821 super().__init__( node_id=node_id, node_ip_address=node_ip_address, logging_config=logging_config, ) self._grpc_options = grpc_options self._http_options = configure_http_middlewares(http_options) grpc_enabled = is_grpc_enabled(self._grpc_options) event_loop = get_or_create_event_loop() self.long_poll_client = long_poll_client or LongPollClient( ray.get_actor(SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE), { LongPollNamespace.GLOBAL_LOGGING_CONFIG: self._update_logging_config, LongPollNamespace.ROUTE_TABLE: self._update_routes_in_proxies, }, call_in_event_loop=event_loop, ) startup_msg = f"Proxy starting on node {self._node_id} (HTTP port: {self._http_options.port}" if grpc_enabled: startup_msg += f", gRPC port: {self._grpc_options.port})." else: startup_msg += ")." logger.info(startup_msg) logger.debug( f"Configure Proxy actor {ray.get_runtime_context().get_actor_id()} " f"logger with logging config: {logging_config}" ) configure_component_memory_profiler( component_name="proxy", component_id=node_ip_address ) if logging_config.encoding == EncodingType.JSON: # Create logging context for access logs as a performance optimization. # While logging_utils can automatically add Ray core and Serve access log context, # we pre-compute it here since context evaluation is expensive and this context # will be reused for multiple access log entries. ray_core_logging_context = CoreContextFilter.get_ray_core_logging_context() # remove task level log keys from ray core logging context, it would be nice # to have task level log keys here but we are letting those go in favor of # performance optimization. Also we cannot include task level log keys here because # they would referance the current task (__init__) and not the task that is logging. for key in CoreContextFilter.TASK_LEVEL_LOG_KEYS: ray_core_logging_context.pop(key, None) access_log_context = { **ray_core_logging_context, SERVE_LOG_COMPONENT: "proxy", SERVE_LOG_COMPONENT_ID: self._node_ip_address, "log_to_stderr": False, "skip_context_filter": True, "serve_access_log": True, } else: access_log_context = { "log_to_stderr": False, "skip_context_filter": True, "serve_access_log": True, } is_head = self._node_id == get_head_node_id() self.proxy_router = ProxyRouter(get_proxy_handle) self.http_proxy = HTTPProxy( node_id=self._node_id, node_ip_address=self._node_ip_address, is_head=is_head, self_actor_name=ray.get_runtime_context().get_actor_name(), proxy_router=self.proxy_router, request_timeout_s=self._http_options.request_timeout_s, access_log_context=access_log_context, ) self.grpc_proxy = ( gRPCProxy( node_id=self._node_id, node_ip_address=self._node_ip_address, is_head=is_head, proxy_router=self.proxy_router, request_timeout_s=self._grpc_options.request_timeout_s, access_log_context=access_log_context, ) if grpc_enabled else None ) if self.grpc_proxy: get_or_create_event_loop().set_exception_handler( asyncio_grpc_exception_handler ) # Start a task to initialize the HTTP server. # The result of this task is checked in the `ready` method. self._start_http_server_task = event_loop.create_task( start_asgi_http_server( self.http_proxy, self._http_options, event_loop=event_loop, enable_so_reuseport=SOCKET_REUSE_PORT_ENABLED, ) ) # A task that runs the HTTP server until it exits (currently runs forever). # Populated with the result of self._start_http_server_task. self._running_http_server_task: Optional[asyncio.Task] = None # Start a task to initialize the gRPC server. # The result of this task is checked in the `ready` method. self._start_grpc_server_task: Optional[asyncio.Task] = None if grpc_enabled: self._start_grpc_server_task = event_loop.create_task( start_grpc_server( self.grpc_proxy.service_handler_factory, self._grpc_options, event_loop=event_loop, enable_so_reuseport=SOCKET_REUSE_PORT_ENABLED, ), ) # A task that runs the gRPC server until it exits (currently runs forever). # Populated with the result of self._start_grpc_server_task. self._running_grpc_server_task: Optional[asyncio.Task] = None _configure_gc_options() # Start event loop monitoring for the proxy's main event loop. self._event_loop_monitor = EventLoopMonitor( component=EventLoopMonitor.COMPONENT_PROXY, loop_type=EventLoopMonitor.LOOP_TYPE_MAIN, actor_id=ray.get_runtime_context().get_actor_id(), ) self._event_loop_monitor.start(event_loop) def _update_routes_in_proxies(self, endpoints: Dict[DeploymentID, EndpointInfo]): self.proxy_router.update_routes(endpoints) def _get_logging_config(self) -> Tuple: """Get the logging configuration (for testing purposes).""" log_file_path = None for handler in logger.handlers: if isinstance(handler, logging.handlers.MemoryHandler): log_file_path = handler.target.baseFilename return log_file_path def _dump_ingress_replicas_for_testing(self, route: str) -> Set[ReplicaID]: _, handle, _ = self.http_proxy.proxy_router.match_route(route) return handle._router._asyncio_router._request_router._replica_id_set def _dump_ingress_cache_for_testing(self, route: str) -> Set[ReplicaID]: """Get replica IDs that have entries in the queue length cache (for testing).""" _, handle, _ = self.http_proxy.proxy_router.match_route(route) request_router = handle._router._asyncio_router._request_router cache = request_router.replica_queue_len_cache return { replica_id for replica_id in request_router._replica_id_set if cache.get(replica_id) is not None } async def ready(self) -> str: """Blocks until the proxy HTTP (and optionally gRPC) servers are running. Returns JSON-serialized metadata containing the proxy's worker ID and log file path. Raises any exceptions that occur setting up the HTTP or gRPC server. """ try: self._running_http_server_task = await self._start_http_server_task except Exception as e: logger.exception("Failed to start proxy HTTP server.") raise e from None try: if self._start_grpc_server_task is not None: self._running_grpc_server_task = await self._start_grpc_server_task except Exception as e: logger.exception("Failed to start proxy gRPC server.") raise e from None # Return proxy metadata used by the controller. # NOTE(zcin): We need to convert the metadata to a json string because # of cross-language scenarios. Java can't deserialize a Python tuple. return json.dumps( [ ray.get_runtime_context().get_worker_id(), get_component_logger_file_path(), ] ) async def serving(self, wait_for_applications_running: bool = True) -> None: """Wait for the proxy to be ready to serve requests.""" return async def update_draining(self, draining: bool, _after: Optional[Any] = None): """Update the draining status of the HTTP and gRPC proxies. Unused `_after` argument is for scheduling: passing an ObjectRef allows delaying this call until after the `_after` call has returned. """ self.http_proxy.update_draining(draining) if self.grpc_proxy: self.grpc_proxy.update_draining(draining) async def is_drained(self, _after: Optional[Any] = None): """Check whether both HTTP and gRPC proxies are drained or not. Unused `_after` argument is for scheduling: passing an ObjectRef allows delaying this call until after the `_after` call has returned. """ return self.http_proxy.is_drained() and ( self.grpc_proxy is None or self.grpc_proxy.is_drained() ) async def check_health(self) -> bool: """No-op method to check on the health of the HTTP Proxy. Make sure the async event loop is not blocked. """ logger.debug("Received health check.", extra={"log_to_stderr": False}) return True def pong(self): """Called by the replica to initialize its handle to the proxy.""" pass async def receive_asgi_messages(self, request_metadata: RequestMetadata) -> bytes: """Get ASGI messages for the provided `request_metadata`. After the proxy has stopped receiving messages for this `request_metadata`, this will always return immediately. Raises `KeyError` if this request ID is not found. This will happen when the request is no longer being handled (e.g., the user disconnects). """ return pickle.dumps( await self.http_proxy.receive_asgi_messages(request_metadata) ) def _get_http_options(self) -> HTTPOptions: """Internal method to get HTTP options used by the proxy.""" return self._http_options def _configure_gc_options(): if not RAY_SERVE_ENABLE_PROXY_GC_OPTIMIZATIONS: return # Collect any objects that exist already and exclude them from future GC. gc.collect(2) gc.freeze() # Tune the GC threshold to run less frequently (default is 700). gc.set_threshold(RAY_SERVE_PROXY_GC_THRESHOLD)