| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582 |
- import asyncio
- import concurrent.futures
- import inspect
- import logging
- import pickle
- import threading
- import time
- from abc import ABC, abstractmethod
- from asyncio import run_coroutine_threadsafe
- from functools import wraps
- from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, Optional, Union
- import grpc
- import ray
- from ray.exceptions import ActorUnavailableError, RayTaskError, TaskCancelledError
- from ray.serve._private.common import (
- OBJ_REF_NOT_SUPPORTED_ERROR,
- ReplicaQueueLengthInfo,
- RequestMetadata,
- )
- from ray.serve._private.constants import SERVE_LOGGER_NAME
- from ray.serve._private.http_util import MessageQueue
- from ray.serve._private.serialization import RPCSerializer
- from ray.serve._private.utils import calculate_remaining_timeout, generate_request_id
- from ray.serve.exceptions import RequestCancelledError
- from ray.serve.generated.serve_pb2 import ASGIResponse
# Module-level logger, looked up under the name given by SERVE_LOGGER_NAME.
logger = logging.getLogger(SERVE_LOGGER_NAME)
def is_running_in_asyncio_loop() -> bool:
    """Report whether the caller is executing inside a running asyncio loop.

    Returns:
        True when `asyncio.get_running_loop()` succeeds in the current
        thread, False when it raises `RuntimeError`.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # `get_running_loop()` raises when no loop is running in this thread.
        return False
    return True
class ReplicaResult(ABC):
    """Abstract interface for the result of a request sent to a replica.

    Every method is abstract; the base implementations only raise
    `NotImplementedError`. `ActorReplicaResult` and `gRPCReplicaResult` in
    this module are the concrete implementations.
    """

    @abstractmethod
    async def get_rejection_response(self) -> Optional[ReplicaQueueLengthInfo]:
        """Return the replica's queue length info used to handle rejection."""
        raise NotImplementedError

    @abstractmethod
    def get(self, timeout_s: Optional[float]):
        """Synchronously fetch the result, waiting up to `timeout_s` seconds."""
        raise NotImplementedError

    @abstractmethod
    async def get_async(self):
        """Asynchronously fetch the result."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self):
        """Synchronously fetch the next item of the result."""
        raise NotImplementedError

    @abstractmethod
    async def __anext__(self):
        """Asynchronously fetch the next item of the result."""
        raise NotImplementedError

    @abstractmethod
    def add_done_callback(self, callback: Callable):
        """Register `callback` to be invoked when the request is done."""
        raise NotImplementedError

    @abstractmethod
    def cancel(self):
        """Cancel the underlying request."""
        raise NotImplementedError

    @abstractmethod
    def to_object_ref(self, timeout_s: Optional[float]) -> ray.ObjectRef:
        """Synchronously resolve the result to a `ray.ObjectRef`."""
        raise NotImplementedError

    @abstractmethod
    async def to_object_ref_async(self) -> ray.ObjectRef:
        """Asynchronously resolve the result to a `ray.ObjectRef`."""
        raise NotImplementedError

    @abstractmethod
    def to_object_ref_gen(self) -> ray.ObjectRefGenerator:
        """Return the result as a `ray.ObjectRefGenerator`."""
        # NOTE(edoakes): there is only a sync version of this method because it
        # does not block like `to_object_ref` (so there's also no timeout argument).
        raise NotImplementedError
class ActorReplicaResult(ReplicaResult):
    """`ReplicaResult` backed by a Ray `ObjectRef` or `ObjectRefGenerator`.

    Exactly one of the two is supplied at construction time. Streaming
    requests must be given a generator. For unary requests given a generator,
    the underlying `ObjectRef` is resolved lazily (at most once) by
    `to_object_ref` / `to_object_ref_async` and then cached.
    """

    def __init__(
        self,
        obj_ref_or_gen: Union[ray.ObjectRef, ray.ObjectRefGenerator],
        metadata: RequestMetadata,
        *,
        with_rejection: bool = False,
    ):
        """Wrap an object ref (or generator) returned for a request.

        Args:
            obj_ref_or_gen: The `ObjectRef` or `ObjectRefGenerator` to wrap.
            metadata: Request metadata; `is_streaming` and `request_id` are
                read from it.
            with_rejection: Whether request rejection is enabled, which is
                required for `get_rejection_response()`.
        """
        self._obj_ref: Optional[ray.ObjectRef] = None
        self._obj_ref_gen: Optional[ray.ObjectRefGenerator] = None
        self._is_streaming: bool = metadata.is_streaming
        self._request_id: str = metadata.request_id
        # Guards the one-time resolution of `_obj_ref` from `_obj_ref_gen`.
        self._object_ref_or_gen_sync_lock = threading.Lock()
        self._with_rejection = with_rejection
        # Cached result of `get_rejection_response()`.
        self._rejection_response = None

        if isinstance(obj_ref_or_gen, ray.ObjectRefGenerator):
            self._obj_ref_gen = obj_ref_or_gen
        else:
            self._obj_ref = obj_ref_or_gen

        if self._is_streaming:
            assert (
                self._obj_ref_gen is not None
            ), "An ObjectRefGenerator must be passed for streaming requests."

        request_context = ray.serve.context._get_serve_request_context()
        if request_context.cancel_on_parent_request_cancel:
            # Keep track of in-flight requests.
            self._response_id = generate_request_id()
            ray.serve.context._add_in_flight_request(
                request_context._internal_request_id, self._response_id, self
            )
            # Drop the in-flight entry again once the request completes.
            self.add_done_callback(
                lambda _: ray.serve.context._remove_in_flight_request(
                    request_context._internal_request_id, self._response_id
                )
            )

    def _process_response(f: Union[Callable, Coroutine]):
        """Decorator that translates `TaskCancelledError` raised by `f`.

        Sync methods re-raise it as `RequestCancelledError` carrying the
        request id; coroutine methods re-raise it as `asyncio.CancelledError`.
        """

        @wraps(f)
        def wrapper(self, *args, **kwargs):
            try:
                return f(self, *args, **kwargs)
            except ray.exceptions.TaskCancelledError:
                raise RequestCancelledError(self._request_id)

        @wraps(f)
        async def async_wrapper(self, *args, **kwargs):
            try:
                return await f(self, *args, **kwargs)
            except ray.exceptions.TaskCancelledError:
                raise asyncio.CancelledError()

        if inspect.iscoroutinefunction(f):
            return async_wrapper
        else:
            return wrapper

    @_process_response
    async def get_rejection_response(self) -> Optional[ReplicaQueueLengthInfo]:
        """Get the queue length info from the replica to handle rejection.

        The first item of the generator is awaited and unpickled once; later
        calls return the cached value. If this coroutine is cancelled, the
        underlying request is cancelled too before re-raising.
        """
        assert (
            self._with_rejection and self._obj_ref_gen is not None
        ), "get_rejection_response() can only be called when request rejection is enabled."

        try:
            if self._rejection_response is None:
                # Inner await yields the next ObjectRef, outer await its value.
                response = await (await self._obj_ref_gen.__anext__())
                self._rejection_response = pickle.loads(response)

            return self._rejection_response
        except asyncio.CancelledError as e:
            # HTTP client disconnected or request was explicitly canceled.
            logger.info(
                "Cancelling request that has already been assigned to a replica."
            )
            self.cancel()
            raise e from None
        except TaskCancelledError:
            raise asyncio.CancelledError()

    @_process_response
    def get(self, timeout_s: Optional[float]):
        """Block until the unary result is available and return it.

        `timeout_s` covers both resolving the object ref and fetching its
        value: the time spent in `to_object_ref` is subtracted before calling
        `ray.get`.
        """
        assert (
            not self._is_streaming
        ), "get() can only be called on a unary ActorReplicaResult."

        start_time_s = time.time()
        object_ref = self.to_object_ref(timeout_s=timeout_s)
        remaining_timeout_s = calculate_remaining_timeout(
            timeout_s=timeout_s,
            start_time_s=start_time_s,
            curr_time_s=time.time(),
        )
        return ray.get(object_ref, timeout=remaining_timeout_s)

    @_process_response
    async def get_async(self):
        """Await and return the unary result."""
        assert (
            not self._is_streaming
        ), "get_async() can only be called on a unary ActorReplicaResult."

        return await (await self.to_object_ref_async())

    @_process_response
    def __next__(self):
        """Block for and return the next streamed value."""
        assert (
            self._is_streaming
        ), "next() can only be called on a streaming ActorReplicaResult."

        next_obj_ref = self._obj_ref_gen.__next__()
        return ray.get(next_obj_ref)

    @_process_response
    async def __anext__(self):
        """Await and return the next streamed value."""
        assert (
            self._is_streaming
        ), "__anext__() can only be called on a streaming ActorReplicaResult."

        next_obj_ref = await self._obj_ref_gen.__anext__()
        return await next_obj_ref

    def add_done_callback(self, callback: Callable):
        """Register `callback` on completion of the generator or object ref."""
        if self._obj_ref_gen is not None:
            self._obj_ref_gen.completed()._on_completed(callback)
        else:
            self._obj_ref._on_completed(callback)

    def cancel(self):
        """Cancel the underlying Ray task via `ray.cancel`."""
        if self._obj_ref_gen is not None:
            ray.cancel(self._obj_ref_gen)
        else:
            ray.cancel(self._obj_ref)

    def to_object_ref(self, *, timeout_s: Optional[float] = None) -> ray.ObjectRef:
        """Synchronously resolve (and cache) the unary result's `ObjectRef`.

        Raises:
            TimeoutError: If the generator returns a nil ref, i.e. the ref
                could not be resolved within `timeout_s`.
        """
        assert (
            not self._is_streaming
        ), "to_object_ref can only be called on a unary ActorReplicaResult."

        # NOTE(edoakes): this section needs to be guarded with a lock and the resulting
        # object ref cached in order to avoid calling `__next__()` to
        # resolve to the underlying object ref more than once.
        # See: https://github.com/ray-project/ray/issues/43879.
        with self._object_ref_or_gen_sync_lock:
            if self._obj_ref is None:
                obj_ref = self._obj_ref_gen._next_sync(timeout_s=timeout_s)
                if obj_ref.is_nil():
                    raise TimeoutError("Timed out resolving to ObjectRef.")

                self._obj_ref = obj_ref

        return self._obj_ref

    async def to_object_ref_async(self) -> ray.ObjectRef:
        """Asynchronously resolve (and cache) the unary result's `ObjectRef`."""
        assert (
            not self._is_streaming
        ), "to_object_ref_async can only be called on a unary ActorReplicaResult."

        # NOTE(edoakes): this section needs to be guarded with a lock and the resulting
        # object ref cached in order to avoid calling `__anext__()` to
        # resolve to the underlying object ref more than once.
        # See: https://github.com/ray-project/ray/issues/43879.
        #
        # IMPORTANT: We use a threading lock instead of asyncio.Lock because this method
        # can be called from multiple event loops concurrently:
        # 1. From the user's code (on the replica's event loop) when awaiting a response
        # 2. From the router's event loop when resolving a DeploymentResponse argument
        # asyncio.Lock is NOT thread-safe and NOT designed for cross-loop usage, which
        # causes deadlocks.
        #
        # We use a non-blocking acquire pattern to avoid blocking the event loop:
        # - Try to acquire the lock without blocking
        # - If already held, yield and retry (allows other async tasks to run)
        # - Once acquired, check if result is already available (double-check pattern)
        #
        # NOTE(review): while another holder is awaiting `__anext__()`, this
        # retry loop reschedules itself on every event loop iteration with no
        # backoff.
        while True:
            # Fast path: already computed
            if self._obj_ref is not None:
                return self._obj_ref

            acquired = self._object_ref_or_gen_sync_lock.acquire(blocking=False)
            if acquired:
                try:
                    # Double-check under lock
                    if self._obj_ref is None:
                        self._obj_ref = await self._obj_ref_gen.__anext__()
                    return self._obj_ref
                finally:
                    # Always release, including when the await is cancelled.
                    self._object_ref_or_gen_sync_lock.release()
            else:
                # Lock is held by another task/thread, yield and retry
                await asyncio.sleep(0)

    def to_object_ref_gen(self) -> ray.ObjectRefGenerator:
        """Return the wrapped `ObjectRefGenerator` of a streaming result."""
        assert (
            self._is_streaming
        ), "to_object_ref_gen can only be called on a streaming ActorReplicaResult."

        return self._obj_ref_gen
class gRPCReplicaResult(ReplicaResult):
    """`ReplicaResult` backed by a `grpc.aio.Call`.

    The call object is bound to one asyncio event loop (`_grpc_call_loop`).
    Depending on whether the caller is on that loop, results are either
    awaited directly or fetched by scheduling a coroutine onto that loop with
    `run_coroutine_threadsafe`. Object-ref conversion is not supported.
    """

    def __init__(
        self,
        call: grpc.aio.Call,
        metadata: RequestMetadata,
        actor_id: ray.ActorID,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *,
        with_rejection: bool = False,
    ):
        """Wrap a gRPC call made to a replica.

        Args:
            call: The gRPC call object. It is treated as streaming if it
                exposes `__aiter__`.
            metadata: Request metadata; `is_streaming`, `_on_separate_loop`,
                `request_id` and the serialization options are read from it.
            actor_id: ID of the replica actor, reported in
                `ActorUnavailableError`.
            loop: Event loop the call is attached to. When omitted, the
                result of `asyncio._get_running_loop()` is used.
            with_rejection: Whether request rejection is enabled, which is
                required for `get_rejection_response()`.
        """
        self._call: grpc.aio.Call = call
        self._actor_id: ray.ActorID = actor_id
        self._metadata: RequestMetadata = metadata  # Store metadata for serialization
        self._result_queue: MessageQueue = MessageQueue()
        # This is the asyncio event loop that the gRPC Call object is attached to
        self._grpc_call_loop = loop or asyncio._get_running_loop()
        self._is_streaming = metadata.is_streaming
        self._with_rejection = with_rejection
        # Cached result of `get_rejection_response()`.
        self._rejection_response = None
        # Async iterator over the call; stays None for non-iterable calls.
        self._gen = None
        # Cross-thread future for the unary result, created at most once.
        self._fut = None

        # NOTE(zcin): for now, these two concepts will be synonymous.
        # In other words, using a queue means the router is running on
        # a separate thread/event loop, and vice versa not using a queue
        # means the router is running on the main event loop, where the
        # DeploymentHandle lives.
        self._calling_from_same_loop = not metadata._on_separate_loop

        if hasattr(self._call, "__aiter__"):
            self._gen = self._call.__aiter__()
            # If the grpc call IS streaming, AND it was created on a
            # a separate loop, then use a queue to fetch the objects
            self._use_queue = metadata._on_separate_loop
        else:
            self._use_queue = False

        # Start a background task that continuously fetches from the
        # streaming grpc call. This way callbacks will actually be
        # called when the request finishes even without the user
        # explicitly consuming the response.
        self._consume_task = None
        if self._use_queue:
            self._consume_task = self._grpc_call_loop.create_task(
                self.consume_messages_from_gen()
            )

        # Keep track of in-flight requests.
        self._response_id = generate_request_id()
        request_context = ray.serve.context._get_serve_request_context()
        ray.serve.context._add_in_flight_request(
            request_context._internal_request_id, self._response_id, self
        )
        # Drop the in-flight entry again once the call completes.
        self.add_done_callback(
            lambda _: ray.serve.context._remove_in_flight_request(
                request_context._internal_request_id, self._response_id
            )
        )

    def _process_grpc_response(f: Union[Callable, Coroutine]):
        """Decorator that post-processes the raw gRPC response returned by `f`.

        An `AioRpcError` with status `UNAVAILABLE` is re-raised as
        `ActorUnavailableError`; other `AioRpcError`s propagate unchanged.
        For sync methods, `concurrent.futures.CancelledError` is re-raised as
        `RequestCancelledError` carrying the request id. A successful
        response is passed through `deserialize_or_raise_error`.
        """

        def deserialize_or_raise_error(
            grpc_response: ASGIResponse,
            metadata: RequestMetadata,
        ):
            """Deserialize `grpc_response`, raising if it carries an error."""
            # Create serializer with options from metadata
            serializer = RPCSerializer(
                metadata.request_serialization,
                metadata.response_serialization,
            )

            if grpc_response.is_error:
                err = serializer.loads_response(grpc_response.serialized_message)
                if isinstance(err, RayTaskError):
                    # Surface the underlying cause type to the caller.
                    raise err.as_instanceof_cause()
                else:
                    raise err
            else:
                # If it's an HTTP request, then the proxy response generator is
                # expecting a pickled dictionary, so we return result directly
                # without deserializing. Otherwise, we deserialize the result.
                if ray.serve.context._get_serve_request_context().is_http_request:
                    return grpc_response.serialized_message
                else:
                    return serializer.loads_response(grpc_response.serialized_message)

        @wraps(f)
        def wrapper(self, *args, **kwargs):
            try:
                grpc_response = f(self, *args, **kwargs)
            except grpc.aio.AioRpcError as e:
                if e.code() == grpc.StatusCode.UNAVAILABLE:
                    raise ActorUnavailableError(
                        "Actor is unavailable.",
                        self._actor_id.binary(),
                    )
                raise
            except concurrent.futures.CancelledError:
                # Include the request id so the error identifies the request,
                # matching what ActorReplicaResult raises on cancellation.
                raise RequestCancelledError(self._metadata.request_id) from None

            return deserialize_or_raise_error(grpc_response, self._metadata)

        @wraps(f)
        async def async_wrapper(self, *args, **kwargs):
            try:
                grpc_response = await f(self, *args, **kwargs)
            except grpc.aio.AioRpcError as e:
                if e.code() == grpc.StatusCode.UNAVAILABLE:
                    raise ActorUnavailableError(
                        "Actor is unavailable.",
                        self._actor_id.binary(),
                    )
                raise

            return deserialize_or_raise_error(grpc_response, self._metadata)

        if inspect.iscoroutinefunction(f):
            return async_wrapper
        else:
            return wrapper

    def __aiter__(self) -> AsyncIterator[Any]:
        """Return `self`; items are produced by `__anext__`."""
        return self

    def __iter__(self) -> Iterator[Any]:
        """Return `self`; items are produced by `__next__`."""
        return self

    async def consume_messages_from_gen(self):
        """Drain the streaming call into the in-memory result queue.

        Any exception raised while iterating (including `BaseException`
        subclasses) is recorded on the queue via `set_error`. The queue is
        always closed at the end.
        """
        try:
            async for resp in self._gen:
                self._result_queue.put_nowait(resp)
        except BaseException as e:
            self._result_queue.set_error(e)
        finally:
            self._result_queue.close()

    async def _get_internal(self):
        """Gets the result from the gRPC call object.

        If the call object is a UnaryUnaryCall, we await the call.
        Otherwise the call object is a UnaryStreamCall.
        - If the request was sent on a separate loop, then the
          streamed results are being consumed and put onto the in-memory
          queue, so we read from that queue.
        - Otherwise the request was sent on the current loop, so we
          fetch the next object from the async generator.
        """
        if self._gen is None:
            return await self._call
        elif self._use_queue:
            return await self._result_queue.get_one_message()
        else:
            return await self._gen.__anext__()

    async def get_rejection_response(self) -> Optional[ReplicaQueueLengthInfo]:
        """Get the queue length info from the replica to handle rejection.

        The info is read from the call's initial metadata (keys `accepted`
        and `num_ongoing_requests`) and cached after the first success.

        Raises:
            RuntimeError: If the expected metadata keys are missing.
            ActorUnavailableError: If the call failed with status
                `UNAVAILABLE` and the request was not already accepted.
            asyncio.CancelledError: If cancelled; the call is cancelled first.
        """
        assert (
            self._with_rejection
        ), "get_rejection_response() can only be called when request rejection is enabled."

        try:
            if self._rejection_response is None:
                # NOTE(edoakes): this is required for gRPC to raise an AioRpcError if something
                # goes wrong establishing the connection (for example, a bug in our code).
                await self._call.wait_for_connection()
                metadata = await self._call.initial_metadata()

                accepted = metadata.get("accepted", None)
                num_ongoing_requests = metadata.get("num_ongoing_requests", None)
                if accepted is None or num_ongoing_requests is None:
                    code = await self._call.code()
                    details = await self._call.details()
                    raise RuntimeError(f"Unexpected error ({code}): {details}.")

                # Metadata values are converted with int() before use.
                self._rejection_response = ReplicaQueueLengthInfo(
                    accepted=bool(int(accepted)),
                    num_ongoing_requests=int(num_ongoing_requests),
                )

            return self._rejection_response
        except asyncio.CancelledError as e:
            # HTTP client disconnected or request was explicitly canceled.
            logger.info(
                "Cancelling request that has already been assigned to a replica."
            )
            self.cancel()
            raise e from None
        except grpc.aio.AioRpcError as e:
            # If we received an `UNAVAILABLE` grpc error, that is
            # equivalent to `RayActorError`, although we don't know
            # whether it's `ActorDiedError` or `ActorUnavailableError`.
            # Conservatively, we assume it is `ActorUnavailableError`,
            # and we raise it here so that it goes through the unified
            # code path for handling RayActorErrors.
            # The router will retry scheduling the request with the
            # cache invalidated, at which point if the actor is actually
            # dead, the router will realize through active probing.
            if not self._is_streaming:
                # In UnaryUnary calls, initial metadata is sent back with the request
                # response, so we can't determine if the request was accepted until
                # after the request is handled. If the replica crashed while handling
                # the request, we can still get initial metadata via the AioRpcError,
                # since the server sets the metadata before handling the request.
                # If there is no metadata, we know the replica was already unavailable
                # prior to the request being sent. We only raise an ActorUnavailableError
                # (and thus retry the request) if the request was rejected or if the
                # replica was already unavailable.
                metadata = e.initial_metadata()
                accepted = metadata.get("accepted", None)
                if accepted is not None and bool(int(accepted)):
                    num_ongoing_requests = metadata.get("num_ongoing_requests", None)
                    if num_ongoing_requests is None:
                        raise RuntimeError(
                            f"Unexpected error ({e.code()}): {e.details()}."
                        )
                    return ReplicaQueueLengthInfo(
                        accepted=True,
                        num_ongoing_requests=int(num_ongoing_requests),
                    )

            if e.code() == grpc.StatusCode.UNAVAILABLE:
                raise ActorUnavailableError(
                    "Actor is unavailable.",
                    self._actor_id.binary(),
                )
            raise e from None

    @_process_grpc_response
    def get(self, timeout_s: Optional[float]):
        """Block until the result is available and return it.

        Must not be called from within a running asyncio event loop.

        Raises:
            RuntimeError: If called from within a running event loop.
            TimeoutError: If the result is not ready within `timeout_s`.
        """
        if is_running_in_asyncio_loop():
            raise RuntimeError(
                "Sync method `get()` should not be called from within an `asyncio` "
                "event loop. Use `get_async()` instead."
            )

        # Schedule the fetch on the call's loop once; later calls reuse it.
        if self._fut is None:
            self._fut = run_coroutine_threadsafe(
                self._get_internal(), self._grpc_call_loop
            )

        try:
            return self._fut.result(timeout=timeout_s)
        except concurrent.futures.TimeoutError:
            raise TimeoutError("Timed out waiting for result.") from None

    @_process_grpc_response
    async def get_async(self):
        """Await and return the result.

        When no cross-thread future exists yet and the caller is on the
        call's loop, the call is awaited directly. Otherwise the fetch is
        scheduled onto the call's loop and its future is awaited here.
        """
        if self._fut is None:
            if self._calling_from_same_loop:
                return await self._get_internal()
            else:
                self._fut = run_coroutine_threadsafe(
                    self._get_internal(), self._grpc_call_loop
                )

        return await asyncio.wrap_future(self._fut)

    @_process_grpc_response
    def __next__(self):
        """Block for and return the next streamed item.

        Must not be called from within a running asyncio event loop.

        Raises:
            RuntimeError: If called from within a running event loop.
            StopIteration: When the stream is exhausted.
        """
        if is_running_in_asyncio_loop():
            raise RuntimeError(
                "Sync method `__next__()` should not be called from within an "
                "`asyncio` event loop. Use `__anext__()` instead."
            )

        fut = run_coroutine_threadsafe(self._get_internal(), loop=self._grpc_call_loop)
        try:
            return fut.result()
        except StopAsyncIteration:
            # We need to raise the synchronous version, StopIteration
            raise StopIteration

    @_process_grpc_response
    async def __anext__(self):
        """Await and return the next streamed item."""
        if self._calling_from_same_loop:
            return await self._get_internal()
        else:
            # Run the fetch on the call's loop and await its future here.
            fut = run_coroutine_threadsafe(
                self._get_internal(), loop=self._grpc_call_loop
            )
            return await asyncio.wrap_future(fut)

    def add_done_callback(self, callback: Callable):
        """Register `callback` on the underlying gRPC call."""
        self._call.add_done_callback(callback)

    def cancel(self):
        """Cancel the underlying gRPC call."""
        self._call.cancel()

    def to_object_ref(self, timeout_s: Optional[float]) -> ray.ObjectRef:
        """Not supported for gRPC results; always raises."""
        raise OBJ_REF_NOT_SUPPORTED_ERROR

    async def to_object_ref_async(self) -> ray.ObjectRef:
        """Not supported for gRPC results; always raises."""
        raise OBJ_REF_NOT_SUPPORTED_ERROR

    def to_object_ref_gen(self) -> ray.ObjectRefGenerator:
        """Not supported for gRPC results; always raises."""
        raise OBJ_REF_NOT_SUPPORTED_ERROR
|