http_util.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910
  1. import asyncio
  2. import inspect
  3. import json
  4. import logging
  5. import pickle
  6. import socket
  7. from collections import deque
  8. from copy import deepcopy
  9. from dataclasses import dataclass
  10. from typing import (
  11. Any,
  12. AsyncGenerator,
  13. Awaitable,
  14. Callable,
  15. List,
  16. Optional,
  17. Tuple,
  18. Type,
  19. Union,
  20. )
  21. import starlette
  22. import uvicorn
  23. from fastapi import FastAPI
  24. from fastapi.encoders import jsonable_encoder
  25. from packaging import version
  26. from starlette.datastructures import MutableHeaders
  27. from starlette.middleware import Middleware
  28. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  29. from uvicorn.config import Config
  30. from uvicorn.lifespan.on import LifespanOn
  31. from ray._common.network_utils import is_ipv6
  32. from ray._common.pydantic_compat import IS_PYDANTIC_2
  33. from ray.exceptions import RayActorError, RayTaskError
  34. from ray.serve._private.common import RequestMetadata
  35. from ray.serve._private.constants import (
  36. RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S,
  37. RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH,
  38. RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S,
  39. SERVE_HTTP_REQUEST_ID_HEADER,
  40. SERVE_LOGGER_NAME,
  41. )
  42. from ray.serve._private.constants_utils import warn_if_deprecated_env_var_set
  43. from ray.serve._private.proxy_request_response import ResponseStatus
  44. from ray.serve._private.utils import (
  45. call_function_from_import_path,
  46. generate_request_id,
  47. serve_encoders,
  48. )
  49. from ray.serve.config import HTTPOptions
  50. from ray.serve.exceptions import (
  51. BackPressureError,
  52. DeploymentUnavailableError,
  53. RayServeException,
  54. )
  55. logger = logging.getLogger(SERVE_LOGGER_NAME)
  56. @dataclass(frozen=True)
  57. class ASGIArgs:
  58. scope: Scope
  59. receive: Receive
  60. send: Send
  61. def to_args_tuple(self) -> Tuple[Scope, Receive, Send]:
  62. return (self.scope, self.receive, self.send)
  63. def to_starlette_request(self) -> starlette.requests.Request:
  64. return starlette.requests.Request(
  65. *self.to_args_tuple(),
  66. )
  67. def make_buffered_asgi_receive(serialized_body: bytes) -> Receive:
  68. """Returns an ASGI receiver that returns the provided buffered body."""
  69. # Simulates receiving HTTP body from TCP socket. In reality, the body has
  70. # already been streamed in chunks and stored in serialized_body.
  71. received = False
  72. async def mock_receive():
  73. nonlocal received
  74. # If the request has already been received, starlette will keep polling
  75. # for HTTP disconnect. We will pause forever. The coroutine should be
  76. # cancelled by starlette after the response has been sent.
  77. if received:
  78. block_forever = asyncio.Event()
  79. await block_forever.wait()
  80. received = True
  81. return {"body": serialized_body, "type": "http.request", "more_body": False}
  82. return mock_receive
  83. def convert_object_to_asgi_messages(
  84. obj: Optional[Any] = None, status_code: int = 200
  85. ) -> List[Message]:
  86. """Serializes the provided object and converts it to ASGI messages.
  87. These ASGI messages can be sent via an ASGI `send` interface to comprise an HTTP
  88. response.
  89. """
  90. body = None
  91. content_type = None
  92. if obj is None:
  93. body = b""
  94. content_type = b"text/plain"
  95. elif isinstance(obj, bytes):
  96. body = obj
  97. content_type = b"text/plain"
  98. elif isinstance(obj, str):
  99. body = obj.encode("utf-8")
  100. content_type = b"text/plain; charset=utf-8"
  101. else:
  102. # `separators=(",", ":")` will remove all whitespaces between separators in the
  103. # json string and return a minimized json string. This helps to reduce the size
  104. # of the response similar to Starlette's JSONResponse.
  105. body = json.dumps(
  106. jsonable_encoder(obj, custom_encoder=serve_encoders),
  107. separators=(",", ":"),
  108. ).encode()
  109. content_type = b"application/json"
  110. return [
  111. {
  112. "type": "http.response.start",
  113. "status": status_code,
  114. "headers": [[b"content-type", content_type]],
  115. },
  116. {"type": "http.response.body", "body": body},
  117. ]
  118. class Response:
  119. """ASGI compliant response class.
  120. It is expected to be called in async context and pass along
  121. `scope, receive, send` as in ASGI spec.
  122. >>> from ray.serve.http_util import Response # doctest: +SKIP
  123. >>> scope, receive = ... # doctest: +SKIP
  124. >>> await Response({"k": "v"}).send(scope, receive, send) # doctest: +SKIP
  125. """
  126. def __init__(self, content=None, status_code=200):
  127. """Construct a HTTP Response based on input type.
  128. Args:
  129. content: Any JSON serializable object.
  130. status_code (int, optional): Default status code is 200.
  131. """
  132. self._messages = convert_object_to_asgi_messages(
  133. obj=content,
  134. status_code=status_code,
  135. )
  136. async def send(self, scope, receive, send):
  137. for message in self._messages:
  138. await send(message)
  139. async def receive_http_body(scope, receive, send):
  140. body_buffer = []
  141. more_body = True
  142. while more_body:
  143. message = await receive()
  144. assert message["type"] == "http.request"
  145. more_body = message["more_body"]
  146. body_buffer.append(message["body"])
  147. return b"".join(body_buffer)
  148. class MessageQueue(Send):
  149. """Queue enables polling for received or sent messages.
  150. Implements the ASGI `Send` interface.
  151. This class:
  152. - Is *NOT* thread safe and should only be accessed from a single asyncio
  153. event loop.
  154. - Assumes a single consumer of the queue (concurrent calls to
  155. `get_messages_nowait` and `wait_for_message` is undefined behavior).
  156. """
  157. def __init__(self):
  158. self._message_queue = deque()
  159. self._new_message_event = asyncio.Event()
  160. self._closed = False
  161. self._error = None
  162. def close(self):
  163. """Close the queue, rejecting new messages.
  164. Once the queue is closed, existing messages will be returned from
  165. `get_messages_nowait` and subsequent calls to `wait_for_message` will
  166. always return immediately.
  167. """
  168. self._closed = True
  169. self._new_message_event.set()
  170. def set_error(self, e: BaseException):
  171. self._error = e
  172. def put_nowait(self, message: Message):
  173. self._message_queue.append(message)
  174. self._new_message_event.set()
  175. async def __call__(self, message: Message):
  176. """Send a message, putting it on the queue.
  177. `RuntimeError` is raised if the queue has been closed using `.close()`.
  178. """
  179. if self._closed:
  180. raise RuntimeError("New messages cannot be sent after the queue is closed.")
  181. self.put_nowait(message)
  182. async def wait_for_message(self):
  183. """Wait until at least one new message is available.
  184. If a message is available, this method will return immediately on each call
  185. until `get_messages_nowait` is called.
  186. After the queue is closed using `.close()`, this will always return
  187. immediately.
  188. """
  189. if not self._closed:
  190. await self._new_message_event.wait()
  191. def get_messages_nowait(self) -> List[Message]:
  192. """Returns all messages that are currently available (non-blocking).
  193. At least one message will be present if `wait_for_message` had previously
  194. returned and a subsequent call to `wait_for_message` blocks until at
  195. least one new message is available.
  196. """
  197. messages = []
  198. while len(self._message_queue) > 0:
  199. messages.append(self._message_queue.popleft())
  200. self._new_message_event.clear()
  201. return messages
  202. async def get_one_message(self) -> Message:
  203. """This blocks until a message is ready.
  204. This method should not be used together with get_messages_nowait.
  205. Please use either `get_one_message` or `get_messages_nowait`.
  206. Raises:
  207. StopAsyncIteration: if the queue is closed and there are no
  208. more messages.
  209. Exception (self._error): if there are no more messages in
  210. the queue and an error has been set.
  211. """
  212. if self._error:
  213. raise self._error
  214. await self._new_message_event.wait()
  215. if len(self._message_queue) > 0:
  216. msg = self._message_queue.popleft()
  217. if len(self._message_queue) == 0 and not self._closed:
  218. self._new_message_event.clear()
  219. return msg
  220. elif len(self._message_queue) == 0 and self._error:
  221. raise self._error
  222. elif len(self._message_queue) == 0 and self._closed:
  223. raise StopAsyncIteration
  224. async def fetch_messages_from_queue(
  225. self, call_fut: asyncio.Future
  226. ) -> AsyncGenerator[List[Any], None]:
  227. """Repeatedly consume messages from the queue and yield them.
  228. This is used to fetch queue messages in the system event loop in
  229. a thread-safe manner.
  230. Args:
  231. call_fut: The async Future pointing to the task from the user
  232. code event loop that is pushing messages onto the queue.
  233. Yields:
  234. List[Any]: Messages from the queue.
  235. """
  236. # Repeatedly consume messages from the queue.
  237. wait_for_msg_task = None
  238. try:
  239. while True:
  240. wait_for_msg_task = asyncio.create_task(self.wait_for_message())
  241. done, _ = await asyncio.wait(
  242. [call_fut, wait_for_msg_task], return_when=asyncio.FIRST_COMPLETED
  243. )
  244. messages = self.get_messages_nowait()
  245. if messages:
  246. yield messages
  247. # Exit once `call_fut` has finished. In this case, all
  248. # messages must have already been sent.
  249. if call_fut in done:
  250. break
  251. e = call_fut.exception()
  252. if e is not None:
  253. raise e from None
  254. finally:
  255. if not call_fut.done():
  256. call_fut.cancel()
  257. if wait_for_msg_task is not None and not wait_for_msg_task.done():
  258. wait_for_msg_task.cancel()
  259. class ASGIReceiveProxy:
  260. """Proxies ASGI receive from an actor.
  261. The `receive_asgi_messages` callback will be called repeatedly to fetch messages
  262. until a disconnect message is received.
  263. """
  264. def __init__(
  265. self,
  266. scope: Scope,
  267. request_metadata: RequestMetadata,
  268. receive_asgi_messages: Callable[[RequestMetadata], Awaitable[bytes]],
  269. ):
  270. self._type = scope["type"] # Either 'http' or 'websocket'.
  271. # Lazy init the queue to ensure it is created in the user code event loop.
  272. self._queue = None
  273. self._request_metadata = request_metadata
  274. self._receive_asgi_messages = receive_asgi_messages
  275. self._disconnect_message = None
  276. def _get_default_disconnect_message(self) -> Message:
  277. """Return the appropriate disconnect message based on the connection type.
  278. HTTP ASGI spec:
  279. https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event
  280. WS ASGI spec:
  281. https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event-ws
  282. """
  283. if self._type == "websocket":
  284. return {
  285. "type": "websocket.disconnect",
  286. # 1005 is the default disconnect code according to the ASGI spec.
  287. "code": 1005,
  288. }
  289. else:
  290. return {"type": "http.disconnect"}
  291. @property
  292. def queue(self) -> asyncio.Queue:
  293. if self._queue is None:
  294. self._queue = asyncio.Queue()
  295. return self._queue
  296. async def fetch_until_disconnect(self):
  297. """Fetch messages repeatedly until a disconnect message is received.
  298. If a disconnect message is received, this function exits and returns it.
  299. If an exception occurs, it will be raised on the next __call__ and no more
  300. messages will be received.
  301. """
  302. while True:
  303. try:
  304. pickled_messages = await self._receive_asgi_messages(
  305. self._request_metadata
  306. )
  307. for message in pickle.loads(pickled_messages):
  308. self.queue.put_nowait(message)
  309. if message["type"] in {"http.disconnect", "websocket.disconnect"}:
  310. self._disconnect_message = message
  311. return
  312. except KeyError:
  313. # KeyError can be raised if the request is no longer active in the proxy
  314. # (i.e., the user disconnects). This is expected behavior and we should
  315. # not log an error: https://github.com/ray-project/ray/issues/43290.
  316. message = self._get_default_disconnect_message()
  317. self.queue.put_nowait(message)
  318. self._disconnect_message = message
  319. return
  320. except Exception as e:
  321. # Raise unexpected exceptions in the next `__call__`.
  322. self.queue.put_nowait(e)
  323. return
  324. async def __call__(self) -> Message:
  325. """Return the next message once available.
  326. This will repeatedly return a disconnect message once it's been received.
  327. """
  328. if self.queue.empty() and self._disconnect_message is not None:
  329. return self._disconnect_message
  330. message = await self.queue.get()
  331. if isinstance(message, Exception):
  332. raise message
  333. return message
  334. def make_fastapi_class_based_view(fastapi_app, cls: Type) -> None:
  335. """Transform the `cls`'s methods and class annotations to FastAPI routes.
  336. Modified from
  337. https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py
  338. Usage:
  339. >>> from fastapi import FastAPI
  340. >>> app = FastAPI() # doctest: +SKIP
  341. >>> class A: # doctest: +SKIP
  342. ... @app.route("/{i}") # doctest: +SKIP
  343. ... def func(self, i: int) -> str: # doctest: +SKIP
  344. ... return self.dep + i # doctest: +SKIP
  345. >>> # just running the app won't work, here.
  346. >>> make_fastapi_class_based_view(app, A) # doctest: +SKIP
  347. >>> # now app can be run properly
  348. """
  349. # Delayed import to prevent ciruclar imports in workers.
  350. from fastapi import APIRouter, Depends
  351. from fastapi.routing import APIRoute, APIWebSocketRoute
  352. async def get_current_servable_instance():
  353. from ray import serve
  354. return serve.get_replica_context().servable_object
  355. # Find all the class method routes
  356. class_method_routes = [
  357. route
  358. for route in fastapi_app.routes
  359. if
  360. # User defined routes must all be APIRoute or APIWebSocketRoute.
  361. isinstance(route, (APIRoute, APIWebSocketRoute))
  362. # We want to find the route that's bound to the `cls`.
  363. # NOTE(simon): we can't use `route.endpoint in inspect.getmembers(cls)`
  364. # because the FastAPI supports different routes for the methods with
  365. # same name. See #17559.
  366. # NOTE: We check against all classes in the MRO to handle inherited
  367. # methods. When a method is inherited, its __qualname__ still references
  368. # the parent class (e.g., "ParentClass.method" not "ChildClass.method").
  369. # We use "ClassName." prefix matching (not substring) to avoid false
  370. # positives where class "A" would incorrectly match routes from "AA".
  371. and any(
  372. route.endpoint.__qualname__.startswith(base.__qualname__ + ".")
  373. for base in cls.__mro__
  374. if base is not object
  375. )
  376. ]
  377. # Modify these routes and mount it to a new APIRouter.
  378. # We need to to this (instead of modifying in place) because we want to use
  379. # the laster fastapi_app.include_router to re-run the dependency analysis
  380. # for each routes.
  381. new_router = APIRouter()
  382. for route in class_method_routes:
  383. fastapi_app.routes.remove(route)
  384. # This block just adds a default values to the self parameters so that
  385. # FastAPI knows to inject the object when calling the route.
  386. # Before: def method(self, i): ...
  387. # After: def method(self=Depends(...), *, i):...
  388. old_endpoint = route.endpoint
  389. old_signature = inspect.signature(old_endpoint)
  390. old_parameters = list(old_signature.parameters.values())
  391. if len(old_parameters) == 0:
  392. # TODO(simon): make it more flexible to support no arguments.
  393. raise RayServeException(
  394. "Methods in FastAPI class-based view must have ``self`` as "
  395. "their first argument."
  396. )
  397. old_self_parameter = old_parameters[0]
  398. new_self_parameter = old_self_parameter.replace(
  399. default=Depends(get_current_servable_instance)
  400. )
  401. new_parameters = [new_self_parameter] + [
  402. # Make the rest of the parameters keyword only because
  403. # the first argument is no longer positional.
  404. parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY)
  405. for parameter in old_parameters[1:]
  406. ]
  407. new_signature = old_signature.replace(parameters=new_parameters)
  408. route.endpoint.__signature__ = new_signature
  409. route.endpoint._serve_cls = cls
  410. new_router.routes.append(route)
  411. fastapi_app.include_router(new_router)
  412. routes_to_remove = list()
  413. for route in fastapi_app.routes:
  414. if not isinstance(route, (APIRoute, APIWebSocketRoute)):
  415. continue
  416. # If there is a response model, FastAPI creates a copy of the fields.
  417. # But FastAPI creates the field incorrectly by missing the outer_type_.
  418. if (
  419. # TODO(edoakes): I don't think this check is complete because we need
  420. # to support v1 models in v2 (from pydantic.v1 import *).
  421. not IS_PYDANTIC_2
  422. and isinstance(route, APIRoute)
  423. and route.response_model
  424. ):
  425. route.secure_cloned_response_field.outer_type_ = (
  426. route.response_field.outer_type_
  427. )
  428. # Remove endpoints that belong to other class based views.
  429. serve_cls = getattr(route.endpoint, "_serve_cls", None)
  430. if serve_cls is not None and serve_cls != cls:
  431. routes_to_remove.append(route)
  432. fastapi_app.routes[:] = [r for r in fastapi_app.routes if r not in routes_to_remove]
  433. def set_socket_reuse_port(sock: socket.socket) -> bool:
  434. """Mutate a socket object to allow multiple process listening on the same port.
  435. Returns:
  436. success: whether the setting was successful.
  437. """
  438. try:
  439. # These two socket options will allow multiple process to bind the the
  440. # same port. Kernel will evenly load balance among the port listeners.
  441. # Note: this will only work on Linux.
  442. sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  443. if hasattr(socket, "SO_REUSEPORT"):
  444. sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
  445. # In some Python binary distribution (e.g., conda py3.6), this flag
  446. # was not present at build time but available in runtime. But
  447. # Python relies on compiler flag to include this in binary.
  448. # Therefore, in the absence of socket.SO_REUSEPORT, we try
  449. # to use `15` which is value in linux kernel.
  450. # https://github.com/torvalds/linux/blob/master/tools/include/uapi/asm-generic/socket.h#L27
  451. else:
  452. sock.setsockopt(socket.SOL_SOCKET, 15, 1)
  453. return True
  454. except Exception as e:
  455. logger.debug(
  456. f"Setting SO_REUSEPORT failed because of {e}. SO_REUSEPORT is disabled."
  457. )
  458. return False
  459. class ASGIAppReplicaWrapper:
  460. """Provides a common wrapper for replicas running an ASGI app."""
  461. def __init__(self, app_or_func: Union[ASGIApp, Callable]):
  462. if inspect.isfunction(app_or_func):
  463. self._asgi_app = app_or_func()
  464. else:
  465. self._asgi_app = app_or_func
  466. # Use uvicorn's lifespan handling code to properly deal with
  467. # startup and shutdown event.
  468. # If log_config is not None, uvicorn will use the default logger.
  469. # and that interferes with our logging setup.
  470. self._serve_asgi_lifespan = LifespanOn(
  471. Config(
  472. self._asgi_app,
  473. lifespan="on",
  474. log_level=None,
  475. log_config=None,
  476. access_log=False,
  477. )
  478. )
  479. # Replace uvicorn logger with our own.
  480. self._serve_asgi_lifespan.logger = logger
  481. @property
  482. def app(self) -> ASGIApp:
  483. return self._asgi_app
  484. @property
  485. def docs_path(self) -> Optional[str]:
  486. if isinstance(self._asgi_app, FastAPI):
  487. return self._asgi_app.docs_url
  488. async def _run_asgi_lifespan_startup(self):
  489. # LifespanOn's logger logs in INFO level thus becomes spammy
  490. # Within this block we temporarily uplevel for cleaner logging
  491. from ray.serve._private.logging_utils import LoggingContext
  492. with LoggingContext(self._serve_asgi_lifespan.logger, level=logging.WARNING):
  493. await self._serve_asgi_lifespan.startup()
  494. if self._serve_asgi_lifespan.should_exit:
  495. raise RuntimeError(
  496. "ASGI lifespan startup failed. Check replica logs for details."
  497. )
  498. async def __call__(
  499. self,
  500. scope: Scope,
  501. receive: Receive,
  502. send: Send,
  503. ) -> Optional[ASGIApp]:
  504. """Calls into the wrapped ASGI app."""
  505. await self._asgi_app(
  506. scope,
  507. receive,
  508. send,
  509. )
  510. # NOTE: __del__ must be async so that we can run ASGI shutdown
  511. # in the same event loop.
  512. async def __del__(self):
  513. # LifespanOn's logger logs in INFO level thus becomes spammy.
  514. # Within this block we temporarily uplevel for cleaner logging.
  515. from ray.serve._private.logging_utils import LoggingContext
  516. with LoggingContext(self._serve_asgi_lifespan.logger, level=logging.WARNING):
  517. await self._serve_asgi_lifespan.shutdown()
  518. def validate_http_proxy_callback_return(
  519. middlewares: Any,
  520. ) -> [Middleware]:
  521. """Validate the return value of HTTP proxy callback.
  522. Middlewares should be a list of Starlette middlewares. If it is None, we
  523. will treat it as an empty list. If it is not a list, we will raise an
  524. error. If it is a list, we will check if all the items in the list are
  525. Starlette middlewares.
  526. """
  527. if middlewares is None:
  528. middlewares = []
  529. if not isinstance(middlewares, list):
  530. raise ValueError(
  531. "HTTP proxy callback must return a list of Starlette middlewares."
  532. )
  533. else:
  534. # All middlewares must be Starlette middlewares.
  535. # https://www.starlette.io/middleware/#using-pure-asgi-middleware
  536. for middleware in middlewares:
  537. if not issubclass(type(middleware), Middleware):
  538. raise ValueError(
  539. "HTTP proxy callback must return a list of Starlette middlewares, "
  540. f"instead got {type(middleware)} type item in the list."
  541. )
  542. return middlewares
  543. class RequestIdMiddleware:
  544. def __init__(self, app: ASGIApp):
  545. self._app = app
  546. async def __call__(self, scope: Scope, receive: Receive, send: Send):
  547. headers = MutableHeaders(scope=scope)
  548. request_id = headers.get(SERVE_HTTP_REQUEST_ID_HEADER)
  549. if request_id is None:
  550. request_id = generate_request_id()
  551. headers.append(SERVE_HTTP_REQUEST_ID_HEADER, request_id)
  552. async def send_with_request_id(message: Message):
  553. if message["type"] == "http.response.start":
  554. headers = MutableHeaders(scope=message)
  555. headers.append("X-Request-ID", request_id)
  556. if message["type"] == "websocket.accept":
  557. message["X-Request-ID"] = request_id
  558. await send(message)
  559. await self._app(scope, receive, send_with_request_id)
  560. def _apply_middlewares(app: ASGIApp, middlewares: List[Callable]) -> ASGIApp:
  561. """Wrap the ASGI app with the provided middlewares.
  562. The built-in RequestIdMiddleware will always be applied first.
  563. """
  564. for middleware in [Middleware(RequestIdMiddleware)] + middlewares:
  565. if version.parse(starlette.__version__) < version.parse("0.35.0"):
  566. app = middleware.cls(app, **middleware.options)
  567. else:
  568. # In starlette >= 0.35.0, middleware.options does not exist:
  569. # https://github.com/encode/starlette/pull/2381.
  570. app = middleware.cls(
  571. app,
  572. *middleware.args,
  573. **middleware.kwargs,
  574. )
  575. return app
  576. def _inject_root_path(app: ASGIApp, root_path: str):
  577. """Middleware to inject root_path to the ASGI app."""
  578. if not root_path:
  579. return app
  580. async def scope_root_path_middleware(scope, receive, send):
  581. if scope["type"] in ("http", "websocket"):
  582. scope["root_path"] = root_path
  583. await app(scope, receive, send)
  584. return scope_root_path_middleware
  585. def _apply_root_path(app: ASGIApp, root_path: str):
  586. """Handle root_path parameter across different uvicorn versions.
  587. For uvicorn >= 0.26.0, root_path must be injected into the ASGI scope
  588. rather than passed to uvicorn.Config, as uvicorn changed its behavior
  589. in version 0.26.0.
  590. Reference: https://uvicorn.dev/release-notes/#0260-january-16-2024
  591. Args:
  592. app: The ASGI application
  593. root_path: The root path prefix for all routes
  594. Returns:
  595. Tuple of (app, root_path) where:
  596. - app may be wrapped with middleware (for uvicorn >= 0.26.0)
  597. - root_path is "" for uvicorn >= 0.26.0, unchanged otherwise
  598. """
  599. if not root_path:
  600. return app, root_path
  601. uvicorn_version = version.parse(uvicorn.__version__)
  602. if uvicorn_version < version.parse("0.26.0"):
  603. return app, root_path
  604. else:
  605. app = _inject_root_path(app, root_path)
  606. return app, ""
  607. async def start_asgi_http_server(
  608. app: ASGIApp,
  609. http_options: HTTPOptions,
  610. *,
  611. event_loop: asyncio.AbstractEventLoop,
  612. enable_so_reuseport: bool = False,
  613. ) -> asyncio.Task:
  614. """Start an HTTP server to run the ASGI app.
  615. Returns a task that blocks until the server exits (e.g., due to error).
  616. """
  617. app = _apply_middlewares(app, http_options.middlewares)
  618. app, root_path = _apply_root_path(app, http_options.root_path)
  619. sock = socket.socket(
  620. socket.AF_INET6 if is_ipv6(http_options.host) else socket.AF_INET,
  621. socket.SOCK_STREAM,
  622. )
  623. if enable_so_reuseport:
  624. set_socket_reuse_port(sock)
  625. try:
  626. sock.bind((http_options.host, http_options.port))
  627. except OSError as e:
  628. raise RuntimeError(
  629. f"Failed to bind to address '{http_options.host}:{http_options.port}'."
  630. ) from e
  631. # Even though we set log_level=None, uvicorn adds MessageLoggerMiddleware
  632. # if log level for uvicorn.error is not set. And MessageLoggerMiddleware
  633. # has no use to us.
  634. logging.getLogger("uvicorn.error").level = logging.CRITICAL
  635. # Configure SSL if certificates are provided
  636. ssl_kwargs = {}
  637. if http_options.ssl_keyfile and http_options.ssl_certfile:
  638. ssl_kwargs = {
  639. "ssl_keyfile": http_options.ssl_keyfile,
  640. "ssl_certfile": http_options.ssl_certfile,
  641. }
  642. if http_options.ssl_keyfile_password:
  643. ssl_kwargs["ssl_keyfile_password"] = http_options.ssl_keyfile_password
  644. if http_options.ssl_ca_certs:
  645. ssl_kwargs["ssl_ca_certs"] = http_options.ssl_ca_certs
  646. logger.info(
  647. f"Starting HTTPS server on {http_options.host}:{http_options.port} "
  648. f"with SSL certificate: {http_options.ssl_certfile}"
  649. )
  650. # NOTE: We have to use lower level uvicorn Config and Server
  651. # class because we want to run the server as a coroutine. The only
  652. # alternative is to call uvicorn.run which is blocking.
  653. server = uvicorn.Server(
  654. config=uvicorn.Config(
  655. lambda: app,
  656. factory=True,
  657. host=http_options.host,
  658. port=http_options.port,
  659. root_path=root_path,
  660. timeout_keep_alive=http_options.keep_alive_timeout_s,
  661. loop=event_loop,
  662. lifespan="off",
  663. access_log=False,
  664. log_level=None,
  665. log_config=None,
  666. **ssl_kwargs,
  667. )
  668. )
  669. # NOTE(edoakes): we need to override install_signal_handlers here
  670. # because the existing implementation fails if it isn't running in
  671. # the main thread and uvicorn doesn't expose a way to configure it.
  672. server.install_signal_handlers = lambda: None
  673. return event_loop.create_task(server.serve(sockets=[sock]))
  674. def get_http_response_status(
  675. exc: BaseException, request_timeout_s: float, request_id: str
  676. ) -> ResponseStatus:
  677. if isinstance(exc, TimeoutError):
  678. return ResponseStatus(
  679. code=408,
  680. is_error=True,
  681. message=f"Request {request_id} timed out after {request_timeout_s}s.",
  682. )
  683. elif isinstance(exc, asyncio.CancelledError):
  684. message = f"Client for request {request_id} disconnected, cancelling request."
  685. logger.info(message)
  686. return ResponseStatus(
  687. code=499,
  688. is_error=True,
  689. message=message,
  690. )
  691. elif isinstance(exc, (BackPressureError, DeploymentUnavailableError)):
  692. if isinstance(exc, RayTaskError):
  693. logger.warning(f"Request failed: {exc}", extra={"log_to_stderr": False})
  694. return ResponseStatus(
  695. code=503,
  696. is_error=True,
  697. message=exc.message,
  698. )
  699. else:
  700. if isinstance(exc, (RayActorError, RayTaskError)):
  701. logger.warning(f"Request failed: {exc}", extra={"log_to_stderr": False})
  702. else:
  703. logger.exception("Request failed due to unexpected error.")
  704. return ResponseStatus(
  705. code=500,
  706. is_error=True,
  707. message=str(exc),
  708. )
  709. def send_http_response_on_exception(
  710. status: ResponseStatus, response_started: bool
  711. ) -> List[Message]:
  712. if response_started or status.code not in (408, 503):
  713. return []
  714. return convert_object_to_asgi_messages(
  715. status.message,
  716. status_code=status.code,
  717. )
  718. def configure_http_options_with_defaults(http_options: HTTPOptions) -> HTTPOptions:
  719. """Enhanced configuration with component-specific options."""
  720. http_options = deepcopy(http_options)
  721. # Warn if deprecated env var is set
  722. warn_if_deprecated_env_var_set("RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S")
  723. # Apply environment defaults
  724. if (RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S or 0) > 0:
  725. http_options.keep_alive_timeout_s = RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S
  726. # TODO: Deprecate SERVE_REQUEST_PROCESSING_TIMEOUT_S env var
  727. if http_options.request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S:
  728. http_options.request_timeout_s = (
  729. http_options.request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S
  730. )
  731. http_options.middlewares = http_options.middlewares or []
  732. return http_options
  733. def configure_http_middlewares(http_options: HTTPOptions) -> HTTPOptions:
  734. http_options = deepcopy(http_options)
  735. # Add environment variable middleware
  736. if RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH:
  737. logger.info(
  738. f"Calling user-provided callback from import path "
  739. f"'{RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH}'."
  740. )
  741. # noinspection PyTypeChecker
  742. http_options.middlewares.extend(
  743. validate_http_proxy_callback_return(
  744. call_function_from_import_path(
  745. RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH
  746. )
  747. )
  748. )
  749. return http_options