grpc_util.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import asyncio
  2. import logging
  3. from copy import deepcopy
  4. from typing import Callable, List, Optional, Sequence, Tuple
  5. from unittest.mock import Mock
  6. import grpc
  7. from grpc.aio._server import Server
  8. from ray.exceptions import RayActorError, RayTaskError
  9. from ray.serve._private.constants import (
  10. DEFAULT_GRPC_SERVER_OPTIONS,
  11. RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S,
  12. SERVE_LOGGER_NAME,
  13. )
  14. from ray.serve._private.proxy_request_response import ResponseStatus
  15. from ray.serve.config import gRPCOptions
  16. from ray.serve.exceptions import (
  17. BackPressureError,
  18. DeploymentUnavailableError,
  19. gRPCStatusError,
  20. )
  21. from ray.serve.generated.serve_pb2_grpc import add_RayServeAPIServiceServicer_to_server
  22. # Maximum length for gRPC status details to avoid hitting HTTP/2 trailer limits.
  23. # gRPC default max metadata size is 8KB, so we use a conservative limit.
  24. GRPC_MAX_STATUS_DETAILS_LENGTH = 4096
  25. logger = logging.getLogger(SERVE_LOGGER_NAME)
  26. class gRPCGenericServer(Server):
  27. """Custom gRPC server that will override all service method handlers.
  28. Original implementation see: https://github.com/grpc/grpc/blob/
  29. 60c1701f87cacf359aa1ad785728549eeef1a4b0/src/python/grpcio/grpc/aio/_server.py
  30. """
  31. def __init__(
  32. self,
  33. service_handler_factory: Callable,
  34. *,
  35. extra_options: Optional[List[Tuple[str, str]]] = None,
  36. ):
  37. super().__init__(
  38. thread_pool=None,
  39. generic_handlers=(),
  40. interceptors=(),
  41. maximum_concurrent_rpcs=None,
  42. compression=None,
  43. options=DEFAULT_GRPC_SERVER_OPTIONS + (extra_options or []),
  44. )
  45. self.generic_rpc_handlers = []
  46. self.service_handler_factory = service_handler_factory
  47. def add_generic_rpc_handlers(
  48. self, generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]
  49. ):
  50. """Override generic_rpc_handlers before adding to the gRPC server.
  51. This function will override all user defined handlers to have
  52. 1. None `response_serializer` so the server can pass back the
  53. raw protobuf bytes to the user.
  54. 2. `unary_unary` is always calling the unary function generated via
  55. `self.service_handler_factory`
  56. 3. `unary_stream` is always calling the streaming function generated via
  57. `self.service_handler_factory`
  58. """
  59. serve_rpc_handlers = {}
  60. rpc_handler = generic_rpc_handlers[0]
  61. for service_method, method_handler in rpc_handler._method_handlers.items():
  62. serve_method_handler = method_handler._replace(
  63. response_serializer=None,
  64. unary_unary=self.service_handler_factory(
  65. service_method=service_method,
  66. stream=False,
  67. ),
  68. unary_stream=self.service_handler_factory(
  69. service_method=service_method,
  70. stream=True,
  71. ),
  72. )
  73. serve_rpc_handlers[service_method] = serve_method_handler
  74. generic_rpc_handlers[0]._method_handlers = serve_rpc_handlers
  75. self.generic_rpc_handlers.append(generic_rpc_handlers)
  76. super().add_generic_rpc_handlers(generic_rpc_handlers)
  77. async def start_grpc_server(
  78. service_handler_factory: Callable,
  79. grpc_options: gRPCOptions,
  80. *,
  81. event_loop: asyncio.AbstractEventLoop,
  82. enable_so_reuseport: bool = False,
  83. ) -> asyncio.Task:
  84. """Start a gRPC server that handles requests with the service handler factory.
  85. Returns a task that blocks until the server exits (e.g., due to error).
  86. """
  87. from ray.serve._private.default_impl import add_grpc_address
  88. server = gRPCGenericServer(
  89. service_handler_factory,
  90. extra_options=[("grpc.so_reuseport", str(int(enable_so_reuseport)))],
  91. )
  92. add_grpc_address(server, f"[::]:{grpc_options.port}")
  93. # Add built-in gRPC service and user-defined services to the server.
  94. # We pass a mock servicer because the actual implementation will be overwritten
  95. # in the gRPCGenericServer implementation.
  96. mock_servicer = Mock()
  97. for servicer_fn in [
  98. add_RayServeAPIServiceServicer_to_server
  99. ] + grpc_options.grpc_servicer_func_callable:
  100. servicer_fn(mock_servicer, server)
  101. await server.start()
  102. return event_loop.create_task(server.wait_for_termination())
  103. def _truncate_message(
  104. message: str, max_length: int = GRPC_MAX_STATUS_DETAILS_LENGTH
  105. ) -> str:
  106. """Truncate a message to avoid exceeding HTTP/2 trailer limits.
  107. gRPC status details are sent as part of HTTP/2 trailers, which have a fixed size limit.
  108. If the message (e.g., a stack trace) is too long, it can cause issues on the client side.
  109. """
  110. if len(message) <= max_length:
  111. return message
  112. truncation_notice = "... [truncated]"
  113. return message[: max_length - len(truncation_notice)] + truncation_notice
  114. def get_grpc_response_status(
  115. exc: BaseException, request_timeout_s: float, request_id: str
  116. ) -> ResponseStatus:
  117. if isinstance(exc, TimeoutError):
  118. message = f"Request timed out after {request_timeout_s}s."
  119. return ResponseStatus(
  120. code=grpc.StatusCode.DEADLINE_EXCEEDED,
  121. is_error=True,
  122. message=message,
  123. )
  124. elif isinstance(exc, asyncio.CancelledError):
  125. message = f"Client for request {request_id} disconnected."
  126. return ResponseStatus(
  127. code=grpc.StatusCode.CANCELLED,
  128. is_error=True,
  129. message=message,
  130. )
  131. elif isinstance(exc, BackPressureError):
  132. return ResponseStatus(
  133. code=grpc.StatusCode.RESOURCE_EXHAUSTED,
  134. is_error=True,
  135. message=exc.message,
  136. )
  137. elif isinstance(exc, DeploymentUnavailableError):
  138. if isinstance(exc, RayTaskError):
  139. logger.warning(f"Request failed: {exc}", extra={"log_to_stderr": False})
  140. return ResponseStatus(
  141. code=grpc.StatusCode.UNAVAILABLE,
  142. is_error=True,
  143. message=exc.message,
  144. )
  145. elif isinstance(exc, gRPCStatusError):
  146. # User set a gRPC status code before raising the exception.
  147. # Respect the user's status code instead of returning INTERNAL.
  148. original_exc = exc.original_exception
  149. if isinstance(original_exc, (RayActorError, RayTaskError)):
  150. logger.warning(
  151. f"Request failed: {original_exc}", extra={"log_to_stderr": False}
  152. )
  153. else:
  154. logger.exception(
  155. f"Request failed with user-set gRPC status code {exc.grpc_code}."
  156. )
  157. # Use user-set details if provided, otherwise use the original exception message.
  158. message = exc.grpc_details if exc.grpc_details else str(original_exc)
  159. return ResponseStatus(
  160. code=exc.grpc_code,
  161. is_error=True,
  162. message=_truncate_message(message),
  163. )
  164. else:
  165. if isinstance(exc, (RayActorError, RayTaskError)):
  166. logger.warning(f"Request failed: {exc}", extra={"log_to_stderr": False})
  167. else:
  168. logger.exception("Request failed due to unexpected error.")
  169. return ResponseStatus(
  170. code=grpc.StatusCode.INTERNAL,
  171. is_error=True,
  172. message=_truncate_message(str(exc)),
  173. )
  174. def set_grpc_code_and_details(
  175. context: grpc._cython.cygrpc._ServicerContext, status: ResponseStatus
  176. ):
  177. # Only the latest code and details will take effect. If the user already
  178. # set them to a truthy value in the context, skip setting them with Serve's
  179. # default values. By default, if nothing is set, the code is 0 and the
  180. # details is "", which both are falsy. So if the user did not set them or
  181. # if they're explicitly set to falsy values, such as None, Serve will
  182. # continue to set them with our default values.
  183. if not context.code():
  184. context.set_code(status.code)
  185. if not context.details():
  186. context.set_details(status.message)
  187. def set_proxy_default_grpc_options(grpc_options) -> gRPCOptions:
  188. grpc_options = deepcopy(grpc_options) or gRPCOptions()
  189. if grpc_options.request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S:
  190. grpc_options.request_timeout_s = (
  191. grpc_options.request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S
  192. )
  193. return grpc_options