replica_result.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. import asyncio
  2. import concurrent.futures
  3. import inspect
  4. import logging
  5. import pickle
  6. import threading
  7. import time
  8. from abc import ABC, abstractmethod
  9. from asyncio import run_coroutine_threadsafe
  10. from functools import wraps
  11. from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, Optional, Union
  12. import grpc
  13. import ray
  14. from ray.exceptions import ActorUnavailableError, RayTaskError, TaskCancelledError
  15. from ray.serve._private.common import (
  16. OBJ_REF_NOT_SUPPORTED_ERROR,
  17. ReplicaQueueLengthInfo,
  18. RequestMetadata,
  19. )
  20. from ray.serve._private.constants import SERVE_LOGGER_NAME
  21. from ray.serve._private.http_util import MessageQueue
  22. from ray.serve._private.serialization import RPCSerializer
  23. from ray.serve._private.utils import calculate_remaining_timeout, generate_request_id
  24. from ray.serve.exceptions import RequestCancelledError
  25. from ray.serve.generated.serve_pb2 import ASGIResponse
  26. logger = logging.getLogger(SERVE_LOGGER_NAME)
  27. def is_running_in_asyncio_loop() -> bool:
  28. try:
  29. asyncio.get_running_loop()
  30. return True
  31. except RuntimeError:
  32. return False
  33. class ReplicaResult(ABC):
  34. @abstractmethod
  35. async def get_rejection_response(self) -> Optional[ReplicaQueueLengthInfo]:
  36. raise NotImplementedError
  37. @abstractmethod
  38. def get(self, timeout_s: Optional[float]):
  39. raise NotImplementedError
  40. @abstractmethod
  41. async def get_async(self):
  42. raise NotImplementedError
  43. @abstractmethod
  44. def __next__(self):
  45. raise NotImplementedError
  46. @abstractmethod
  47. async def __anext__(self):
  48. raise NotImplementedError
  49. @abstractmethod
  50. def add_done_callback(self, callback: Callable):
  51. raise NotImplementedError
  52. @abstractmethod
  53. def cancel(self):
  54. raise NotImplementedError
  55. @abstractmethod
  56. def to_object_ref(self, timeout_s: Optional[float]) -> ray.ObjectRef:
  57. raise NotImplementedError
  58. @abstractmethod
  59. async def to_object_ref_async(self) -> ray.ObjectRef:
  60. raise NotImplementedError
  61. @abstractmethod
  62. def to_object_ref_gen(self) -> ray.ObjectRefGenerator:
  63. # NOTE(edoakes): there is only a sync version of this method because it
  64. # does not block like `to_object_ref` (so there's also no timeout argument).
  65. raise NotImplementedError
  66. class ActorReplicaResult(ReplicaResult):
  67. def __init__(
  68. self,
  69. obj_ref_or_gen: Union[ray.ObjectRef, ray.ObjectRefGenerator],
  70. metadata: RequestMetadata,
  71. *,
  72. with_rejection: bool = False,
  73. ):
  74. self._obj_ref: Optional[ray.ObjectRef] = None
  75. self._obj_ref_gen: Optional[ray.ObjectRefGenerator] = None
  76. self._is_streaming: bool = metadata.is_streaming
  77. self._request_id: str = metadata.request_id
  78. self._object_ref_or_gen_sync_lock = threading.Lock()
  79. self._with_rejection = with_rejection
  80. self._rejection_response = None
  81. if isinstance(obj_ref_or_gen, ray.ObjectRefGenerator):
  82. self._obj_ref_gen = obj_ref_or_gen
  83. else:
  84. self._obj_ref = obj_ref_or_gen
  85. if self._is_streaming:
  86. assert (
  87. self._obj_ref_gen is not None
  88. ), "An ObjectRefGenerator must be passed for streaming requests."
  89. request_context = ray.serve.context._get_serve_request_context()
  90. if request_context.cancel_on_parent_request_cancel:
  91. # Keep track of in-flight requests.
  92. self._response_id = generate_request_id()
  93. ray.serve.context._add_in_flight_request(
  94. request_context._internal_request_id, self._response_id, self
  95. )
  96. self.add_done_callback(
  97. lambda _: ray.serve.context._remove_in_flight_request(
  98. request_context._internal_request_id, self._response_id
  99. )
  100. )
  101. def _process_response(f: Union[Callable, Coroutine]):
  102. @wraps(f)
  103. def wrapper(self, *args, **kwargs):
  104. try:
  105. return f(self, *args, **kwargs)
  106. except ray.exceptions.TaskCancelledError:
  107. raise RequestCancelledError(self._request_id)
  108. @wraps(f)
  109. async def async_wrapper(self, *args, **kwargs):
  110. try:
  111. return await f(self, *args, **kwargs)
  112. except ray.exceptions.TaskCancelledError:
  113. raise asyncio.CancelledError()
  114. if inspect.iscoroutinefunction(f):
  115. return async_wrapper
  116. else:
  117. return wrapper
  118. @_process_response
  119. async def get_rejection_response(self) -> Optional[ReplicaQueueLengthInfo]:
  120. """Get the queue length info from the replica to handle rejection."""
  121. assert (
  122. self._with_rejection and self._obj_ref_gen is not None
  123. ), "get_rejection_response() can only be called when request rejection is enabled."
  124. try:
  125. if self._rejection_response is None:
  126. response = await (await self._obj_ref_gen.__anext__())
  127. self._rejection_response = pickle.loads(response)
  128. return self._rejection_response
  129. except asyncio.CancelledError as e:
  130. # HTTP client disconnected or request was explicitly canceled.
  131. logger.info(
  132. "Cancelling request that has already been assigned to a replica."
  133. )
  134. self.cancel()
  135. raise e from None
  136. except TaskCancelledError:
  137. raise asyncio.CancelledError()
  138. @_process_response
  139. def get(self, timeout_s: Optional[float]):
  140. assert (
  141. not self._is_streaming
  142. ), "get() can only be called on a unary ActorReplicaResult."
  143. start_time_s = time.time()
  144. object_ref = self.to_object_ref(timeout_s=timeout_s)
  145. remaining_timeout_s = calculate_remaining_timeout(
  146. timeout_s=timeout_s,
  147. start_time_s=start_time_s,
  148. curr_time_s=time.time(),
  149. )
  150. return ray.get(object_ref, timeout=remaining_timeout_s)
  151. @_process_response
  152. async def get_async(self):
  153. assert (
  154. not self._is_streaming
  155. ), "get_async() can only be called on a unary ActorReplicaResult."
  156. return await (await self.to_object_ref_async())
  157. @_process_response
  158. def __next__(self):
  159. assert (
  160. self._is_streaming
  161. ), "next() can only be called on a streaming ActorReplicaResult."
  162. next_obj_ref = self._obj_ref_gen.__next__()
  163. return ray.get(next_obj_ref)
  164. @_process_response
  165. async def __anext__(self):
  166. assert (
  167. self._is_streaming
  168. ), "__anext__() can only be called on a streaming ActorReplicaResult."
  169. next_obj_ref = await self._obj_ref_gen.__anext__()
  170. return await next_obj_ref
  171. def add_done_callback(self, callback: Callable):
  172. if self._obj_ref_gen is not None:
  173. self._obj_ref_gen.completed()._on_completed(callback)
  174. else:
  175. self._obj_ref._on_completed(callback)
  176. def cancel(self):
  177. if self._obj_ref_gen is not None:
  178. ray.cancel(self._obj_ref_gen)
  179. else:
  180. ray.cancel(self._obj_ref)
  181. def to_object_ref(self, *, timeout_s: Optional[float] = None) -> ray.ObjectRef:
  182. assert (
  183. not self._is_streaming
  184. ), "to_object_ref can only be called on a unary ReplicaActorResult."
  185. # NOTE(edoakes): this section needs to be guarded with a lock and the resulting
  186. # object ref cached in order to avoid calling `__next__()` to
  187. # resolve to the underlying object ref more than once.
  188. # See: https://github.com/ray-project/ray/issues/43879.
  189. with self._object_ref_or_gen_sync_lock:
  190. if self._obj_ref is None:
  191. obj_ref = self._obj_ref_gen._next_sync(timeout_s=timeout_s)
  192. if obj_ref.is_nil():
  193. raise TimeoutError("Timed out resolving to ObjectRef.")
  194. self._obj_ref = obj_ref
  195. return self._obj_ref
  196. async def to_object_ref_async(self) -> ray.ObjectRef:
  197. assert (
  198. not self._is_streaming
  199. ), "to_object_ref_async can only be called on a unary ReplicaActorResult."
  200. # NOTE(edoakes): this section needs to be guarded with a lock and the resulting
  201. # object ref cached in order to avoid calling `__anext__()` to
  202. # resolve to the underlying object ref more than once.
  203. # See: https://github.com/ray-project/ray/issues/43879.
  204. #
  205. # IMPORTANT: We use a threading lock instead of asyncio.Lock because this method
  206. # can be called from multiple event loops concurrently:
  207. # 1. From the user's code (on the replica's event loop) when awaiting a response
  208. # 2. From the router's event loop when resolving a DeploymentResponse argument
  209. # asyncio.Lock is NOT thread-safe and NOT designed for cross-loop usage, which
  210. # causes deadlocks.
  211. #
  212. # We use a non-blocking acquire pattern to avoid blocking the event loop:
  213. # - Try to acquire the lock without blocking
  214. # - If already held, yield and retry (allows other async tasks to run)
  215. # - Once acquired, check if result is already available (double-check pattern)
  216. while True:
  217. # Fast path: already computed
  218. if self._obj_ref is not None:
  219. return self._obj_ref
  220. acquired = self._object_ref_or_gen_sync_lock.acquire(blocking=False)
  221. if acquired:
  222. try:
  223. # Double-check under lock
  224. if self._obj_ref is None:
  225. self._obj_ref = await self._obj_ref_gen.__anext__()
  226. return self._obj_ref
  227. finally:
  228. self._object_ref_or_gen_sync_lock.release()
  229. else:
  230. # Lock is held by another task/thread, yield and retry
  231. await asyncio.sleep(0)
  232. def to_object_ref_gen(self) -> ray.ObjectRefGenerator:
  233. assert (
  234. self._is_streaming
  235. ), "to_object_ref_gen can only be called on a streaming ReplicaActorResult."
  236. return self._obj_ref_gen
  237. class gRPCReplicaResult(ReplicaResult):
  238. def __init__(
  239. self,
  240. call: grpc.aio.Call,
  241. metadata: RequestMetadata,
  242. actor_id: ray.ActorID,
  243. loop: asyncio.AbstractEventLoop = None,
  244. *,
  245. with_rejection: bool = False,
  246. ):
  247. self._call: grpc.aio.Call = call
  248. self._actor_id: ray.ActorID = actor_id
  249. self._metadata: RequestMetadata = metadata # Store metadata for serialization
  250. self._result_queue: MessageQueue = MessageQueue()
  251. # This is the asyncio event loop that the gRPC Call object is attached to
  252. self._grpc_call_loop = loop or asyncio._get_running_loop()
  253. self._is_streaming = metadata.is_streaming
  254. self._with_rejection = with_rejection
  255. self._rejection_response = None
  256. self._gen = None
  257. self._fut = None
  258. # NOTE(zcin): for now, these two concepts will be synonymous.
  259. # In other words, using a queue means the router is running on
  260. # a separate thread/event loop, and vice versa not using a queue
  261. # means the router is running on the main event loop, where the
  262. # DeploymentHandle lives.
  263. self._calling_from_same_loop = not metadata._on_separate_loop
  264. if hasattr(self._call, "__aiter__"):
  265. self._gen = self._call.__aiter__()
  266. # If the grpc call IS streaming, AND it was created on a
  267. # a separate loop, then use a queue to fetch the objects
  268. self._use_queue = metadata._on_separate_loop
  269. else:
  270. self._use_queue = False
  271. # Start a background task that continuously fetches from the
  272. # streaming grpc call. This way callbacks will actually be
  273. # called when the request finishes even without the user
  274. # explicitly consuming the response.
  275. self._consume_task = None
  276. if self._use_queue:
  277. self._consume_task = self._grpc_call_loop.create_task(
  278. self.consume_messages_from_gen()
  279. )
  280. # Keep track of in-flight requests.
  281. self._response_id = generate_request_id()
  282. request_context = ray.serve.context._get_serve_request_context()
  283. ray.serve.context._add_in_flight_request(
  284. request_context._internal_request_id, self._response_id, self
  285. )
  286. self.add_done_callback(
  287. lambda _: ray.serve.context._remove_in_flight_request(
  288. request_context._internal_request_id, self._response_id
  289. )
  290. )
  291. def _process_grpc_response(f: Union[Callable, Coroutine]):
  292. def deserialize_or_raise_error(
  293. grpc_response: ASGIResponse,
  294. metadata: RequestMetadata,
  295. ):
  296. # Create serializer with options from metadata
  297. serializer = RPCSerializer(
  298. metadata.request_serialization,
  299. metadata.response_serialization,
  300. )
  301. if grpc_response.is_error:
  302. err = serializer.loads_response(grpc_response.serialized_message)
  303. if isinstance(err, RayTaskError):
  304. raise err.as_instanceof_cause()
  305. else:
  306. raise err
  307. else:
  308. # If it's an HTTP request, then the proxy response generator is
  309. # expecting a pickled dictionary, so we return result directly
  310. # without deserializing. Otherwise, we deserialize the result.
  311. if ray.serve.context._get_serve_request_context().is_http_request:
  312. return grpc_response.serialized_message
  313. else:
  314. return serializer.loads_response(grpc_response.serialized_message)
  315. @wraps(f)
  316. def wrapper(self, *args, **kwargs):
  317. try:
  318. grpc_response = f(self, *args, **kwargs)
  319. except grpc.aio.AioRpcError as e:
  320. if e.code() == grpc.StatusCode.UNAVAILABLE:
  321. raise ActorUnavailableError(
  322. "Actor is unavailable.",
  323. self._actor_id.binary(),
  324. )
  325. raise
  326. except concurrent.futures.CancelledError:
  327. raise RequestCancelledError from None
  328. return deserialize_or_raise_error(grpc_response, self._metadata)
  329. @wraps(f)
  330. async def async_wrapper(self, *args, **kwargs):
  331. try:
  332. grpc_response = await f(self, *args, **kwargs)
  333. except grpc.aio.AioRpcError as e:
  334. if e.code() == grpc.StatusCode.UNAVAILABLE:
  335. raise ActorUnavailableError(
  336. "Actor is unavailable.",
  337. self._actor_id.binary(),
  338. )
  339. raise
  340. return deserialize_or_raise_error(grpc_response, self._metadata)
  341. if inspect.iscoroutinefunction(f):
  342. return async_wrapper
  343. else:
  344. return wrapper
  345. def __aiter__(self) -> AsyncIterator[Any]:
  346. return self
  347. def __iter__(self) -> Iterator[Any]:
  348. return self
  349. async def consume_messages_from_gen(self):
  350. try:
  351. async for resp in self._gen:
  352. self._result_queue.put_nowait(resp)
  353. except BaseException as e:
  354. self._result_queue.set_error(e)
  355. finally:
  356. self._result_queue.close()
  357. async def _get_internal(self):
  358. """Gets the result from the gRPC call object.
  359. If the call object is a UnaryUnaryCall, we await the call.
  360. Otherwise the call object is a UnaryStreamCall.
  361. - If the request was sent on a separate loop, then the
  362. streamed results are being consumed and put onto the in-memory
  363. queue, so we read from that queue.
  364. - Otherwise the request was sent on the current loop, so we
  365. fetch the next object from the async generator.
  366. """
  367. if self._gen is None:
  368. return await self._call
  369. elif self._use_queue:
  370. return await self._result_queue.get_one_message()
  371. else:
  372. return await self._gen.__anext__()
  373. async def get_rejection_response(self) -> Optional[ReplicaQueueLengthInfo]:
  374. """Get the queue length info from the replica to handle rejection."""
  375. assert (
  376. self._with_rejection
  377. ), "get_rejection_response() can only be called when request rejection is enabled."
  378. try:
  379. if self._rejection_response is None:
  380. # NOTE(edoakes): this is required for gRPC to raise an AioRpcError if something
  381. # goes wrong establishing the connection (for example, a bug in our code).
  382. await self._call.wait_for_connection()
  383. metadata = await self._call.initial_metadata()
  384. accepted = metadata.get("accepted", None)
  385. num_ongoing_requests = metadata.get("num_ongoing_requests", None)
  386. if accepted is None or num_ongoing_requests is None:
  387. code = await self._call.code()
  388. details = await self._call.details()
  389. raise RuntimeError(f"Unexpected error ({code}): {details}.")
  390. self._rejection_response = ReplicaQueueLengthInfo(
  391. accepted=bool(int(accepted)),
  392. num_ongoing_requests=int(num_ongoing_requests),
  393. )
  394. return self._rejection_response
  395. except asyncio.CancelledError as e:
  396. # HTTP client disconnected or request was explicitly canceled.
  397. logger.info(
  398. "Cancelling request that has already been assigned to a replica."
  399. )
  400. self.cancel()
  401. raise e from None
  402. except grpc.aio.AioRpcError as e:
  403. # If we received an `UNAVAILABLE` grpc error, that is
  404. # equivalent to `RayActorError`, although we don't know
  405. # whether it's `ActorDiedError` or `ActorUnavailableError`.
  406. # Conservatively, we assume it is `ActorUnavailableError`,
  407. # and we raise it here so that it goes through the unified
  408. # code path for handling RayActorErrors.
  409. # The router will retry scheduling the request with the
  410. # cache invalidated, at which point if the actor is actually
  411. # dead, the router will realize through active probing.
  412. if not self._is_streaming:
  413. # In UnaryUnary calls, initial metadata is sent back with the request
  414. # response, so we can't determine if the request was accepted until
  415. # after the request is handled. If the replica crashed while handling
  416. # the request, we can still get initial metadata via the AioRpcError,
  417. # since the server sets the metadata before handling the request.
  418. # If there is no metadata, we know the replica was already unavailable
  419. # prior to the request being sent. We only raise an ActorUnavailableError
  420. # (and thus retry the request) if the request was rejected or if the
  421. # replica was already unavailable.
  422. metadata = e.initial_metadata()
  423. accepted = metadata.get("accepted", None)
  424. if accepted is not None and bool(int(accepted)):
  425. num_ongoing_requests = metadata.get("num_ongoing_requests", None)
  426. if num_ongoing_requests is None:
  427. raise RuntimeError(
  428. f"Unexpected error ({e.code()}): {e.details()}."
  429. )
  430. return ReplicaQueueLengthInfo(
  431. accepted=True,
  432. num_ongoing_requests=int(num_ongoing_requests),
  433. )
  434. if e.code() == grpc.StatusCode.UNAVAILABLE:
  435. raise ActorUnavailableError(
  436. "Actor is unavailable.",
  437. self._actor_id.binary(),
  438. )
  439. raise e from None
  440. @_process_grpc_response
  441. def get(self, timeout_s: Optional[float]):
  442. if is_running_in_asyncio_loop():
  443. raise RuntimeError(
  444. "Sync method `get()` should not be called from within an `asyncio` "
  445. "event loop. Use `get_async()` instead."
  446. )
  447. if self._fut is None:
  448. self._fut = run_coroutine_threadsafe(
  449. self._get_internal(), self._grpc_call_loop
  450. )
  451. try:
  452. return self._fut.result(timeout=timeout_s)
  453. except concurrent.futures.TimeoutError:
  454. raise TimeoutError("Timed out waiting for result.") from None
  455. @_process_grpc_response
  456. async def get_async(self):
  457. if self._fut is None:
  458. if self._calling_from_same_loop:
  459. return await self._get_internal()
  460. else:
  461. self._fut = run_coroutine_threadsafe(
  462. self._get_internal(), self._grpc_call_loop
  463. )
  464. return await asyncio.wrap_future(self._fut)
  465. @_process_grpc_response
  466. def __next__(self):
  467. if is_running_in_asyncio_loop():
  468. raise RuntimeError(
  469. "Sync method `__next__()` should not be called from within an "
  470. "`asyncio` event loop. Use `__anext__()` instead."
  471. )
  472. fut = run_coroutine_threadsafe(self._get_internal(), loop=self._grpc_call_loop)
  473. try:
  474. return fut.result()
  475. except StopAsyncIteration:
  476. # We need to raise the synchronous version, StopIteration
  477. raise StopIteration
  478. @_process_grpc_response
  479. async def __anext__(self):
  480. if self._calling_from_same_loop:
  481. return await self._get_internal()
  482. else:
  483. fut = run_coroutine_threadsafe(
  484. self._get_internal(), loop=self._grpc_call_loop
  485. )
  486. return await asyncio.wrap_future(fut)
  487. def add_done_callback(self, callback: Callable):
  488. self._call.add_done_callback(callback)
  489. def cancel(self):
  490. self._call.cancel()
  491. def to_object_ref(self, timeout_s: Optional[float]) -> ray.ObjectRef:
  492. raise OBJ_REF_NOT_SUPPORTED_ERROR
  493. async def to_object_ref_async(self) -> ray.ObjectRef:
  494. raise OBJ_REF_NOT_SUPPORTED_ERROR
  495. def to_object_ref_gen(self) -> ray.ObjectRefGenerator:
  496. raise OBJ_REF_NOT_SUPPORTED_ERROR