| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763277327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397 |
- import asyncio
- import concurrent.futures
- import errno
- import functools
- import inspect
- import logging
- import math
- import os
- import pickle
- import threading
- import time
- import traceback
- import warnings
- from abc import ABC, abstractmethod
- from collections import defaultdict, deque
- from contextlib import asynccontextmanager, contextmanager
- from dataclasses import dataclass
- from functools import wraps
- from importlib import import_module
- from typing import (
- Any,
- AsyncGenerator,
- Callable,
- Dict,
- Generator,
- List,
- Optional,
- Set,
- Tuple,
- Union,
- )
- import grpc
- import starlette.responses
- from anyio import to_thread
- from fastapi import Request
- from starlette.applications import Starlette
- from starlette.types import ASGIApp, Receive, Scope, Send
- import ray
- from ray import cloudpickle
- from ray._common.filters import CoreContextFilter
- from ray._common.utils import get_or_create_event_loop
- from ray.actor import ActorClass, ActorHandle
- from ray.dag.py_obj_scanner import _PyObjScanner
- from ray.remote_function import RemoteFunction
- from ray.serve import metrics
- from ray.serve._private.common import (
- RUNNING_REQUESTS_KEY,
- DeploymentID,
- ReplicaID,
- ReplicaMetricReport,
- ReplicaQueueLengthInfo,
- RequestMetadata,
- RequestProtocol,
- ServeComponentType,
- StreamingHTTPRequest,
- gRPCRequest,
- )
- from ray.serve._private.config import DeploymentConfig
- from ray.serve._private.constants import (
- GRPC_CONTEXT_ARG_NAME,
- HEALTH_CHECK_METHOD,
- HEALTHY_MESSAGE,
- RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
- RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S,
- RAY_SERVE_DIRECT_INGRESS_PORT_RETRY_COUNT,
- RAY_SERVE_ENABLE_DIRECT_INGRESS,
- RAY_SERVE_METRICS_EXPORT_INTERVAL_MS,
- RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S,
- RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
- RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
- RAY_SERVE_REQUEST_PATH_LOG_BUFFER_SIZE,
- RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
- RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING,
- RAY_SERVE_RUN_USER_CODE_IN_SEPARATE_THREAD,
- RECONFIGURE_METHOD,
- REQUEST_LATENCY_BUCKETS_MS,
- REQUEST_ROUTING_STATS_METHOD,
- SERVE_CONTROLLER_NAME,
- SERVE_HTTP_REQUEST_DISCONNECT_DISABLED_HEADER,
- SERVE_HTTP_REQUEST_ID_HEADER,
- SERVE_HTTP_REQUEST_TIMEOUT_S_HEADER,
- SERVE_LOG_APPLICATION,
- SERVE_LOG_COMPONENT,
- SERVE_LOG_DEPLOYMENT,
- SERVE_LOG_REPLICA,
- SERVE_LOG_REQUEST_ID,
- SERVE_LOG_ROUTE,
- SERVE_LOGGER_NAME,
- SERVE_NAMESPACE,
- )
- from ray.serve._private.default_impl import (
- create_replica_impl,
- create_replica_metrics_manager,
- )
- from ray.serve._private.direct_ingress_http_util import ASGIDIReceiveProxy
- from ray.serve._private.event_loop_monitoring import EventLoopMonitor
- from ray.serve._private.grpc_util import (
- get_grpc_response_status,
- set_grpc_code_and_details,
- start_grpc_server,
- )
- from ray.serve._private.http_util import (
- ASGIAppReplicaWrapper,
- ASGIArgs,
- ASGIReceiveProxy,
- MessageQueue,
- Response,
- configure_http_middlewares,
- configure_http_options_with_defaults,
- convert_object_to_asgi_messages,
- start_asgi_http_server,
- )
- from ray.serve._private.logging_utils import (
- access_log_msg,
- configure_component_logger,
- configure_component_memory_profiler,
- get_component_logger_file_path,
- )
- from ray.serve._private.metrics_utils import InMemoryMetricsStore, MetricsPusher
- from ray.serve._private.proxy_request_response import ResponseStatus
- from ray.serve._private.replica_response_generator import ReplicaResponseGenerator
- from ray.serve._private.serialization import RPCSerializer
- from ray.serve._private.task_consumer import TaskConsumerWrapper
- from ray.serve._private.thirdparty.get_asgi_route_name import (
- extract_route_patterns,
- get_asgi_route_name,
- )
- from ray.serve._private.usage import ServeUsageTag
- from ray.serve._private.utils import (
- Semaphore,
- asyncio_grpc_exception_handler,
- generate_request_id,
- get_component_file_name, # noqa: F401
- is_grpc_enabled,
- parse_import_path,
- )
- from ray.serve._private.version import DeploymentVersion
- from ray.serve.config import AutoscalingConfig, HTTPOptions, gRPCOptions
- from ray.serve.context import _get_in_flight_requests
- from ray.serve.deployment import Deployment
- from ray.serve.exceptions import (
- BackPressureError,
- DeploymentUnavailableError,
- RayServeException,
- gRPCStatusError,
- )
- from ray.serve.generated.serve_pb2 import (
- ASGIRequest,
- ASGIResponse,
- HealthzResponse,
- ListApplicationsResponse,
- )
- from ray.serve.generated.serve_pb2_grpc import add_ASGIServiceServicer_to_server
- from ray.serve.grpc_util import RayServegRPCContext
- from ray.serve.handle import DeploymentHandle
- from ray.serve.schema import EncodingType, LoggingConfig, ReplicaRank
- from ray.util import metrics as ray_metrics
- logger = logging.getLogger(SERVE_LOGGER_NAME)
- def _wrap_grpc_call(f):
- """Decorator that processes grpc methods."""
- def serialize(result, metadata):
- if metadata.is_streaming and metadata.is_http_request:
- return result
- else:
- # Use cached serializer to avoid per-request instantiation overhead
- serializer = RPCSerializer.get_cached_serializer(
- metadata.request_serialization,
- metadata.response_serialization,
- )
- return serializer.dumps_response(result)
- @wraps(f)
- async def wrapper(
- self,
- request: ASGIRequest,
- context: grpc.aio.ServicerContext,
- ):
- request_metadata = pickle.loads(request.pickled_request_metadata)
- # Get cached serializer with options from metadata
- serializer = RPCSerializer.get_cached_serializer(
- request_metadata.request_serialization,
- request_metadata.response_serialization,
- )
- request_args = serializer.loads_request(request.request_args)
- request_kwargs = serializer.loads_request(request.request_kwargs)
- if request_metadata.is_http_request or request_metadata.is_grpc_request:
- request_args = (pickle.loads(request_args[0]),)
- try:
- result = await f(
- self, context, request_metadata, *request_args, **request_kwargs
- )
- return ASGIResponse(serialized_message=serialize(result, request_metadata))
- except (Exception, asyncio.CancelledError) as e:
- return ASGIResponse(
- serialized_message=serializer.dumps_response(e),
- is_error=True,
- )
- @wraps(f)
- async def gen_wrapper(
- self,
- request: ASGIRequest,
- context: grpc.aio.ServicerContext,
- ):
- request_metadata = pickle.loads(request.pickled_request_metadata)
- # Get cached serializer with options from metadata
- serializer = RPCSerializer.get_cached_serializer(
- request_metadata.request_serialization,
- request_metadata.response_serialization,
- )
- request_args = serializer.loads_request(request.request_args)
- request_kwargs = serializer.loads_request(request.request_kwargs)
- if request_metadata.is_http_request or request_metadata.is_grpc_request:
- request_args = (pickle.loads(request_args[0]),)
- try:
- async for result in f(
- self, context, request_metadata, *request_args, **request_kwargs
- ):
- yield ASGIResponse(
- serialized_message=serialize(result, request_metadata)
- )
- except (Exception, asyncio.CancelledError) as e:
- yield ASGIResponse(
- serialized_message=serializer.dumps_response(e),
- is_error=True,
- )
- if inspect.isasyncgenfunction(f):
- return gen_wrapper
- else:
- return wrapper
- ReplicaMetadata = Tuple[
- DeploymentConfig,
- DeploymentVersion,
- Optional[float],
- Optional[int],
- Optional[str],
- int,
- int,
- ReplicaRank, # rank
- Optional[List[str]], # route_patterns
- Optional[List[DeploymentID]], # outbound_deployments
- ]
- def _load_deployment_def_from_import_path(import_path: str) -> Callable:
- module_name, attr_name = parse_import_path(import_path)
- deployment_def = getattr(import_module(module_name), attr_name)
- # For ray or serve decorated class or function, strip to return
- # original body.
- if isinstance(deployment_def, RemoteFunction):
- deployment_def = deployment_def._function
- elif isinstance(deployment_def, ActorClass):
- deployment_def = deployment_def.__ray_metadata__.modified_class
- elif isinstance(deployment_def, Deployment):
- logger.warning(
- f'The import path "{import_path}" contains a '
- "decorated Serve deployment. The decorator's settings "
- "are ignored when deploying via import path."
- )
- deployment_def = deployment_def.func_or_class
- return deployment_def
- class ReplicaMetricsManager:
- """Manages metrics for the replica.
- A variety of metrics are managed:
- - Fine-grained metrics are set for every request.
- - Autoscaling statistics are periodically pushed to the controller.
- - Queue length metrics are periodically recorded as user-facing gauges.
- """
- PUSH_METRICS_TO_CONTROLLER_TASK_NAME = "push_metrics_to_controller"
- RECORD_METRICS_TASK_NAME = "record_metrics"
- SET_REPLICA_REQUEST_METRIC_GAUGE_TASK_NAME = "set_replica_request_metric_gauge"
- def __init__(
- self,
- replica_id: ReplicaID,
- event_loop: asyncio.BaseEventLoop,
- autoscaling_config: Optional[AutoscalingConfig],
- ingress: bool,
- ):
- self._replica_id = replica_id
- self._deployment_id = replica_id.deployment_id
- self._metrics_pusher = MetricsPusher()
- self._metrics_store = InMemoryMetricsStore()
- self._ingress = ingress
- self._controller_handle = ray.get_actor(
- SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE
- )
- self._num_ongoing_requests = 0
- # Store event loop for scheduling async tasks from sync context
- self._event_loop = event_loop or asyncio.get_event_loop()
- # Cache user_callable_wrapper initialization state to avoid repeated runtime checks
- self._custom_metrics_enabled = False
- # On first call to _fetch_custom_autoscaling_metrics. Failing validation disables _custom_metrics_enabled
- self._checked_custom_metrics = False
- self._record_autoscaling_stats_fn = None
- # If the interval is set to 0, eagerly sets all metrics.
- self._cached_metrics_enabled = RAY_SERVE_METRICS_EXPORT_INTERVAL_MS != 0
- self._cached_metrics_interval_s = RAY_SERVE_METRICS_EXPORT_INTERVAL_MS / 1000
- # Request counter (only set on replica startup).
- self._restart_counter = metrics.Counter(
- "serve_deployment_replica_starts",
- description=(
- "The number of times this replica has been restarted due to failure."
- ),
- )
- self._restart_counter.inc()
- # Per-request metrics.
- self._request_counter = metrics.Counter(
- "serve_deployment_request_counter",
- description=(
- "The number of queries that have been processed in this replica."
- ),
- tag_keys=("route",),
- )
- if self._cached_metrics_enabled:
- self._cached_request_counter = defaultdict(int)
- self._error_counter = metrics.Counter(
- "serve_deployment_error_counter",
- description=(
- "The number of exceptions that have occurred in this replica."
- ),
- tag_keys=("route",),
- )
- if self._cached_metrics_enabled:
- self._cached_error_counter = defaultdict(int)
- # log REQUEST_LATENCY_BUCKET_MS
- logger.debug(f"REQUEST_LATENCY_BUCKETS_MS: {REQUEST_LATENCY_BUCKETS_MS}")
- self._processing_latency_tracker = metrics.Histogram(
- "serve_deployment_processing_latency_ms",
- description="The latency for queries to be processed.",
- boundaries=REQUEST_LATENCY_BUCKETS_MS,
- tag_keys=("route",),
- )
- if self._cached_metrics_enabled:
- self._cached_latencies = defaultdict(deque)
- self._event_loop.create_task(self._report_cached_metrics_forever())
- self._num_ongoing_requests_gauge = metrics.Gauge(
- "serve_replica_processing_queries",
- description="The current number of queries being processed.",
- )
- self.record_autoscaling_stats_failed_counter = metrics.Counter(
- "serve_record_autoscaling_stats_failed",
- tag_keys=("exception_name",),
- description="The number of errored record_autoscaling_stats invocations.",
- )
- self.user_autoscaling_stats_latency_tracker = metrics.Histogram(
- "serve_user_autoscaling_stats_latency_ms",
- description=(
- "Time taken to execute the user-defined autoscaling stats function "
- "in milliseconds."
- ),
- boundaries=REQUEST_LATENCY_BUCKETS_MS,
- )
- self.set_autoscaling_config(autoscaling_config)
- if self._is_direct_ingress:
- # TODO(alexyang): De-duplicate these metrics from those collected by
- # the proxy.
- self.ingress_http_request_counter = self._init_ingress_request_counter(
- "HTTP"
- )
- self.ingress_http_request_error_counter = (
- self._init_ingress_request_error_counter("HTTP")
- )
- self.deployment_http_request_error_counter = (
- self._init_deployment_request_error_counter("HTTP")
- )
- logger.debug(f"REQUEST_LATENCY_BUCKETS_MS: {REQUEST_LATENCY_BUCKETS_MS}")
- self.ingress_http_processing_latency_tracker = (
- self._init_ingress_processing_latency_tracker("HTTP")
- )
- node_id = ray.get_runtime_context().get_node_id()
- node_ip_address = ray.util.get_node_ip_address()
- self.ingress_num_ongoing_http_requests_gauge = (
- self._init_ingress_num_ongoing_requests_gauge(
- "HTTP", node_id, node_ip_address
- )
- )
- self._ingress_ongoing_http_requests = 0
- if self._cached_metrics_enabled:
- self._cached_ingress_request_counter = defaultdict(
- lambda: defaultdict(int)
- )
- self._cached_ingress_request_error_counter = defaultdict(
- lambda: defaultdict(int)
- )
- self._cached_deployment_request_error_counter = defaultdict(
- lambda: defaultdict(int)
- )
- self._cached_ingress_processing_latencies = defaultdict(
- lambda: defaultdict(deque)
- )
- @property
- def _is_direct_ingress(self) -> bool:
- return self._ingress and RAY_SERVE_ENABLE_DIRECT_INGRESS
- def _init_ingress_request_counter(self, protocol: str):
- return ray_metrics.Counter(
- f"serve_num_{protocol.lower()}_requests",
- description=f"The number of {protocol} requests processed.",
- tag_keys=("route", "method", "application", "status_code"),
- )
- def _init_ingress_request_error_counter(self, protocol: str):
- return ray_metrics.Counter(
- f"serve_num_{protocol.lower()}_error_requests",
- description=(f"The number of errored {protocol} responses."),
- tag_keys=(
- "route",
- "error_code",
- "method",
- "application",
- ),
- )
- def _init_deployment_request_error_counter(self, protocol: str):
- return ray_metrics.Counter(
- f"serve_num_deployment_{protocol.lower()}_error_requests",
- description=(
- f"The number of errored {protocol} responses returned by each deployment."
- ),
- tag_keys=(
- "deployment",
- "error_code",
- "method",
- "route",
- "application",
- ),
- )
- def _init_ingress_processing_latency_tracker(self, protocol: str):
- return ray_metrics.Histogram(
- f"serve_{protocol.lower()}_request_latency_ms",
- description=(
- f"The end-to-end latency of {protocol} requests "
- f"(measured from the Serve ingress)."
- ),
- boundaries=REQUEST_LATENCY_BUCKETS_MS,
- tag_keys=(
- "method",
- "route",
- "application",
- "status_code",
- ),
- )
- def _init_ingress_num_ongoing_requests_gauge(
- self, protocol: str, node_id: str, node_ip_address: str
- ):
- return ray_metrics.Gauge(
- name=f"serve_num_ongoing_{protocol.lower()}_requests",
- description=f"The number of ongoing requests in this {protocol} ingress.",
- tag_keys=("node_id", "node_ip_address"),
- ).set_default_tags(
- {
- "node_id": node_id,
- "node_ip_address": node_ip_address,
- }
- )
- def _report_cached_metrics(self):
- for route, count in self._cached_request_counter.items():
- self._request_counter.inc(count, tags={"route": route})
- self._cached_request_counter.clear()
- for route, count in self._cached_error_counter.items():
- self._error_counter.inc(count, tags={"route": route})
- self._cached_error_counter.clear()
- for route, latencies in self._cached_latencies.items():
- for latency_ms in latencies:
- self._processing_latency_tracker.observe(
- latency_ms, tags={"route": route}
- )
- self._cached_latencies.clear()
- self._num_ongoing_requests_gauge.set(self._num_ongoing_requests)
- if not self._is_direct_ingress:
- return
- for protocol in [RequestProtocol.HTTP]:
- if protocol == RequestProtocol.HTTP:
- ingress_request_counter = self.ingress_http_request_counter
- ingress_request_error_counter = self.ingress_http_request_error_counter
- deployment_request_error_counter = (
- self.deployment_http_request_error_counter
- )
- ingress_processing_latencies = (
- self.ingress_http_processing_latency_tracker
- )
- self.ingress_num_ongoing_http_requests_gauge.set(
- self._ingress_ongoing_http_requests
- )
- else:
- # TODO(alexyang): Add metrics for gRPC.
- continue
- for request_tags, count in self._cached_ingress_request_counter[
- protocol
- ].items():
- ingress_request_counter.inc(count, tags=dict(request_tags))
- for request_tags, count in self._cached_ingress_request_error_counter[
- protocol
- ].items():
- ingress_request_error_counter.inc(count, tags=dict(request_tags))
- for request_tags, count in self._cached_deployment_request_error_counter[
- protocol
- ].items():
- deployment_request_error_counter.inc(count, tags=dict(request_tags))
- for latency_tags, latencies in self._cached_ingress_processing_latencies[
- protocol
- ].items():
- for latency_ms in latencies:
- ingress_processing_latencies.observe(
- latency_ms, tags=dict(latency_tags)
- )
- self._cached_ingress_request_counter.clear()
- self._cached_ingress_request_error_counter.clear()
- self._cached_deployment_request_error_counter.clear()
- self._cached_ingress_processing_latencies.clear()
- async def _report_cached_metrics_forever(self):
- assert self._cached_metrics_interval_s > 0
- consecutive_errors = 0
- while True:
- try:
- await asyncio.sleep(self._cached_metrics_interval_s)
- self._report_cached_metrics()
- consecutive_errors = 0
- except Exception:
- logger.exception("Unexpected error reporting metrics.")
- # Exponential backoff starting at 1s and capping at 10s.
- backoff_time_s = min(10, 2**consecutive_errors)
- consecutive_errors += 1
- await asyncio.sleep(backoff_time_s)
- async def shutdown(self):
- """Stop periodic background tasks."""
- await self._metrics_pusher.graceful_shutdown()
- def start_metrics_pusher(self):
- self._metrics_pusher.start()
- # Push autoscaling metrics to the controller periodically.
- self._metrics_pusher.register_or_update_task(
- self.PUSH_METRICS_TO_CONTROLLER_TASK_NAME,
- self._push_autoscaling_metrics,
- self._autoscaling_config.metrics_interval_s,
- )
- # Collect autoscaling metrics locally periodically.
- self._metrics_pusher.register_or_update_task(
- self.RECORD_METRICS_TASK_NAME,
- self._add_autoscaling_metrics_point_async,
- min(
- RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
- self._autoscaling_config.metrics_interval_s,
- ),
- )
- def should_collect_ongoing_requests(self) -> bool:
- """Determine if replicas should collect ongoing request metrics.
- ┌────────────────────────────────────────────────────────────────┐
- │ Replica-based metrics collection │
- ├────────────────────────────────────────────────────────────────┤
- │ │
- │ Client Handle Replicas │
- │ ┌──────┐ ┌────────┐ │
- │ │ App │─────>│ Handle │────┬───>┌─────────┐ │
- │ │ │ │ Tracks │ │ │ Replica │ │
- │ └──────┘ │ Queued │ │ │ 1 │ │
- │ │Requests│ │ │ Tracks │ │
- │ └────────┘ │ │ Running │ │
- │ │ │ └─────────┘ │
- │ │ │ │ │
- │ │ │ │ │
- │ │ │ ┌─────────┐ │
- │ │ └───>│ Replica │ │
- │ │ │ 2 │ │
- │ │ │ Tracks │ │
- │ │ │ Running │ │
- │ │ └─────────┘ │
- │ │ │ │
- │ │ │ │
- │ ▼ ▼ │
- │ ┌──────────────────────────────┐ │
- │ │ Controller │ │
- │ │ • Queued metrics (handle) │ │
- │ │ • Running metrics (replica1)│ │
- │ │ • Running metrics (replica2)│ │
- │ └──────────────────────────────┘ │
- │ │
- └────────────────────────────────────────────────────────────────┘
- For direct ingress deployments, metrics must be collected from replicas regardless
- of whether autoscaling metrics are being collected via handles. This is necessary
- because direct ingress traffic bypasses deployment handles and goes directly to
- the replicas.
- """
- if self._is_direct_ingress and self._autoscaling_config:
- return True
- return not RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE
- def set_autoscaling_config(self, autoscaling_config: Optional[AutoscalingConfig]):
- """Dynamically update autoscaling config."""
- self._autoscaling_config = autoscaling_config
- if self._autoscaling_config and self.should_collect_ongoing_requests():
- self.start_metrics_pusher()
- def enable_custom_autoscaling_metrics(
- self,
- custom_metrics_enabled: bool,
- record_autoscaling_stats_fn: Callable[[], Optional[concurrent.futures.Future]],
- ):
- """Runs after the user callable wrapper is initialized to enable autoscaling metrics collection."""
- if custom_metrics_enabled:
- self._custom_metrics_enabled = custom_metrics_enabled
- self._record_autoscaling_stats_fn = record_autoscaling_stats_fn
- self.start_metrics_pusher()
- def inc_num_ongoing_requests(self, request_metadata: RequestMetadata) -> int:
- self._num_ongoing_requests += 1
- if self._is_direct_ingress and request_metadata.is_direct_ingress:
- self._ingress_ongoing_http_requests += 1
- if not self._cached_metrics_enabled:
- self._num_ongoing_requests_gauge.set(self._num_ongoing_requests)
- if self._is_direct_ingress and request_metadata.is_direct_ingress:
- if request_metadata.is_http_request:
- self.ingress_num_ongoing_http_requests_gauge.set(
- self._ingress_ongoing_http_requests
- )
- def dec_num_ongoing_requests(self, request_metadata: RequestMetadata) -> int:
- self._num_ongoing_requests -= 1
- if self._is_direct_ingress and request_metadata.is_direct_ingress:
- self._ingress_ongoing_http_requests -= 1
- if not self._cached_metrics_enabled:
- self._num_ongoing_requests_gauge.set(self._num_ongoing_requests)
- if self._is_direct_ingress and request_metadata.is_direct_ingress:
- if request_metadata.is_http_request:
- self.ingress_num_ongoing_http_requests_gauge.set(
- self._ingress_ongoing_http_requests
- )
- def get_num_ongoing_requests(self) -> int:
- """Get current total queue length of requests for this replica."""
- return self._num_ongoing_requests
- def record_request_metrics(self, *, route: str, latency_ms: float, was_error: bool):
- """Records per-request metrics."""
- if self._cached_metrics_enabled:
- self._cached_latencies[route].append(latency_ms)
- if was_error:
- self._cached_error_counter[route] += 1
- else:
- self._cached_request_counter[route] += 1
- else:
- self._processing_latency_tracker.observe(latency_ms, tags={"route": route})
- if was_error:
- self._error_counter.inc(tags={"route": route})
- else:
- self._request_counter.inc(tags={"route": route})
- def record_ingress_request_metrics(
- self,
- *,
- protocol: RequestProtocol,
- method: str,
- route: str,
- app_name: str,
- deployment_name: str,
- latency_ms: float,
- was_error: bool,
- status_code: str,
- ):
- """Record per-request metrics."""
- if not self._is_direct_ingress:
- return
- if protocol == RequestProtocol.HTTP:
- latency_tracker = self.ingress_http_processing_latency_tracker
- request_error_counter = self.ingress_http_request_error_counter
- deployment_error_counter = self.deployment_http_request_error_counter
- request_counter = self.ingress_http_request_counter
- else:
- # TODO(alexyang): Add metrics for gRPC.
- return
- request_tags = {
- "route": route,
- "method": method,
- "application": app_name,
- "status_code": status_code,
- }
- latency_tags = request_tags
- request_error_tags = {
- "route": route,
- "method": method,
- "application": app_name,
- "error_code": status_code,
- }
- deployment_error_tags = {
- "route": route,
- "method": method,
- "application": app_name,
- "error_code": status_code,
- "deployment": deployment_name,
- }
- if self._cached_metrics_enabled:
- self._cached_ingress_request_counter[protocol][
- frozenset(request_tags.items())
- ] += 1
- self._cached_ingress_processing_latencies[protocol][
- frozenset(latency_tags.items())
- ].append(latency_ms)
- if was_error:
- self._cached_ingress_request_error_counter[protocol][
- frozenset(request_error_tags.items())
- ] += 1
- self._cached_deployment_request_error_counter[protocol][
- frozenset(deployment_error_tags.items())
- ] += 1
- else:
- request_counter.inc(tags=request_tags)
- latency_tracker.observe(latency_ms, tags=latency_tags)
- if was_error:
- request_error_counter.inc(tags=request_error_tags)
- deployment_error_counter.inc(tags=deployment_error_tags)
- def _push_autoscaling_metrics(self) -> Dict[str, Any]:
- look_back_period = self._autoscaling_config.look_back_period_s
- self._metrics_store.prune_keys_and_compact_data(time.time() - look_back_period)
- new_aggregated_metrics = {}
- new_metrics = {**self._metrics_store.data}
- if self.should_collect_ongoing_requests():
- # Keep the legacy window_avg ongoing requests in the merged metrics dict
- window_avg = (
- self._metrics_store.aggregate_avg([RUNNING_REQUESTS_KEY])[0] or 0.0
- )
- new_aggregated_metrics.update({RUNNING_REQUESTS_KEY: window_avg})
- replica_metric_report = ReplicaMetricReport(
- replica_id=self._replica_id,
- timestamp=time.time(),
- aggregated_metrics=new_aggregated_metrics,
- metrics=new_metrics,
- )
- self._controller_handle.record_autoscaling_metrics_from_replica.remote(
- replica_metric_report
- )
- async def _fetch_custom_autoscaling_metrics(
- self,
- ) -> Optional[Dict[str, Union[int, float]]]:
- try:
- start_time = time.time()
- res = await asyncio.wait_for(
- self._record_autoscaling_stats_fn(),
- timeout=RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S,
- )
- latency_ms = (time.time() - start_time) * 1000
- self.user_autoscaling_stats_latency_tracker.observe(latency_ms)
- # Perform validation only first call
- if not self._checked_custom_metrics:
- # Enforce return type to be Dict[str, Union[int, float]]
- if not isinstance(res, dict):
- logger.error(
- f"User autoscaling stats method returned {type(res).__name__}, "
- f"expected Dict[str, Union[int, float]]. Disabling autoscaling stats."
- )
- self._custom_metrics_enabled = False
- return None
- for key, value in res.items():
- if not isinstance(value, (int, float)):
- logger.error(
- f"User autoscaling stats method returned invalid value type "
- f"{type(value).__name__} for key '{key}', expected int or float. "
- f"Disabling autoscaling stats."
- )
- self._custom_metrics_enabled = False
- return None
- self._checked_custom_metrics = True
- return res
- except asyncio.TimeoutError as e:
- logger.error(
- f"Replica autoscaling stats timed out after {RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S}s."
- )
- self.record_autoscaling_stats_failed_counter.inc(
- tags={"exception_name": e.__class__.__name__}
- )
- except Exception as e:
- logger.error(f"Replica autoscaling stats failed. {e}")
- self.record_autoscaling_stats_failed_counter.inc(
- tags={"exception_name": e.__class__.__name__}
- )
- return None
- async def _add_autoscaling_metrics_point_async(self) -> None:
- metrics_dict = {}
- if self.should_collect_ongoing_requests():
- metrics_dict = {RUNNING_REQUESTS_KEY: self._num_ongoing_requests}
- # Use cached availability flag to avoid repeated runtime checks
- if self._custom_metrics_enabled:
- custom_metrics = await self._fetch_custom_autoscaling_metrics()
- if custom_metrics:
- metrics_dict.update(custom_metrics)
- self._metrics_store.add_metrics_point(
- metrics_dict,
- time.time(),
- )
- StatusCodeCallback = Callable[[str], None]
- class ReplicaBase(ABC):
- def __init__(
- self,
- replica_id: ReplicaID,
- deployment_def: Callable,
- init_args: Tuple,
- init_kwargs: Dict,
- deployment_config: DeploymentConfig,
- version: DeploymentVersion,
- ingress: bool,
- route_prefix: str,
- ):
- self._version = version
- self._replica_id = replica_id
- self._deployment_id = replica_id.deployment_id
- self._deployment_config = deployment_config
- self._ingress = ingress
- self._route_prefix = route_prefix
- self._component_name = f"{self._deployment_id.name}"
- if self._deployment_id.app_name:
- self._component_name = (
- f"{self._deployment_id.app_name}_" + self._component_name
- )
- self._component_id = self._replica_id.unique_id
- self._configure_logger_and_profilers(self._deployment_config.logging_config)
- self._event_loop = get_or_create_event_loop()
- actor_id = ray.get_runtime_context().get_actor_id()
- self._user_callable_wrapper = UserCallableWrapper(
- deployment_def,
- init_args,
- init_kwargs,
- deployment_id=self._deployment_id,
- run_sync_methods_in_threadpool=RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
- run_user_code_in_separate_thread=RAY_SERVE_RUN_USER_CODE_IN_SEPARATE_THREAD,
- local_testing_mode=False,
- deployment_config=deployment_config,
- actor_id=actor_id,
- ray_actor_options=self._version.ray_actor_options,
- )
- self._semaphore = Semaphore(lambda: self.max_ongoing_requests)
- # Guards against calling the user's callable constructor multiple times.
- self._user_callable_initialized = False
- self._user_callable_initialized_lock = asyncio.Lock()
- self._initialization_latency: Optional[float] = None
- # Track deployment handles created dynamically via get_deployment_handle()
- self._dynamically_created_handles: Set[DeploymentID] = set()
- # Flipped to `True` when health checks pass and `False` when they fail. May be
- # used by replica subclass implementations.
- self._healthy = False
- # Flipped to `True` once graceful shutdown is initiated. May be used by replica
- # subclass implementations.
- self._shutting_down = False
- # Will be populated with the wrapped ASGI app if the user callable is an
- # `ASGIAppReplicaWrapper` (i.e., they are using the FastAPI integration).
- self._user_callable_asgi_app: Optional[ASGIApp] = None
- # Set metadata for logs and metrics.
- # servable_object will be populated in `initialize_and_get_metadata`.
- self._set_internal_replica_context(servable_object=None, rank=None)
- self._metrics_manager = create_replica_metrics_manager(
- replica_id=replica_id,
- event_loop=self._event_loop,
- autoscaling_config=self._deployment_config.autoscaling_config,
- ingress=ingress,
- )
- # Start event loop monitoring for the replica's main event loop.
- self._main_loop_monitor = EventLoopMonitor(
- component=EventLoopMonitor.COMPONENT_REPLICA,
- loop_type=EventLoopMonitor.LOOP_TYPE_MAIN,
- actor_id=actor_id,
- extra_tags={
- "deployment": self._deployment_id.name,
- "application": self._deployment_id.app_name,
- },
- )
- self._main_loop_monitor.start(self._event_loop)
- self._internal_grpc_port: Optional[int] = None
- self._docs_path: Optional[str] = None
- self._http_port: Optional[int] = None
- self._grpc_port: Optional[int] = None
- self._rank: Optional[ReplicaRank] = None
- # gRPC server for inter-deployment communication
- self._server = grpc.aio.server(
- options=[
- (
- "grpc.max_receive_message_length",
- RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
- )
- ]
- )
- # Silence spammy false positive errors from gRPC Python
- self._event_loop.set_exception_handler(asyncio_grpc_exception_handler)
- @property
- def max_ongoing_requests(self) -> int:
- return self._deployment_config.max_ongoing_requests
- def get_num_ongoing_requests(self) -> int:
- return self._metrics_manager.get_num_ongoing_requests()
- def get_metadata(self) -> ReplicaMetadata:
- current_rank = ray.serve.context._get_internal_replica_context().rank
- # Extract route patterns from ASGI app if available
- route_patterns = None
- if self._user_callable_asgi_app is not None:
- # _user_callable_asgi_app is the actual ASGI app (FastAPI/Starlette)
- # It's set when initialize_callable() returns an ASGI app
- if hasattr(self._user_callable_asgi_app, "routes"):
- route_patterns = extract_route_patterns(self._user_callable_asgi_app)
- return (
- self._version.deployment_config,
- self._version,
- self._initialization_latency,
- self._internal_grpc_port,
- self._docs_path,
- self._http_port,
- self._grpc_port,
- current_rank,
- route_patterns,
- self.list_outbound_deployments(),
- )
- def get_dynamically_created_handles(self) -> Set[DeploymentID]:
- return self._dynamically_created_handles
- def list_outbound_deployments(self) -> List[DeploymentID]:
- """List all outbound deployment IDs this replica calls into.
- This includes:
- - Handles created via get_deployment_handle()
- - Handles passed as init args/kwargs to the deployment constructor
- This is used to determine which deployments are reachable from this replica.
- The list of DeploymentIDs can change over time as new handles can be created at runtime.
- Also its not guaranteed that the list of DeploymentIDs are identical across replicas
- because it depends on user code.
- Returns:
- A list of DeploymentIDs that this replica calls into.
- """
- seen_deployment_ids: Set[DeploymentID] = set()
- # First, collect dynamically created handles
- for deployment_id in self.get_dynamically_created_handles():
- seen_deployment_ids.add(deployment_id)
- # Get the init args/kwargs
- init_args = self._user_callable_wrapper._init_args
- init_kwargs = self._user_callable_wrapper._init_kwargs
- # Use _PyObjScanner to find all DeploymentHandle objects in:
- # The init_args and init_kwargs (handles might be passed as init args)
- scanner = _PyObjScanner(source_type=DeploymentHandle)
- try:
- handles = scanner.find_nodes((init_args, init_kwargs))
- for handle in handles:
- deployment_id = handle.deployment_id
- seen_deployment_ids.add(deployment_id)
- finally:
- scanner.clear()
- return list(seen_deployment_ids)
- def _set_internal_replica_context(
- self, *, servable_object: Callable = None, rank: ReplicaRank = None
- ):
- # Calculate world_size from deployment config instead of storing it
- world_size = self._deployment_config.num_replicas
- # Create callback for registering dynamically created handles
- def register_handle_callback(deployment_id: DeploymentID) -> None:
- self._dynamically_created_handles.add(deployment_id)
- ray.serve.context._set_internal_replica_context(
- replica_id=self._replica_id,
- servable_object=servable_object,
- _deployment_config=self._deployment_config,
- rank=rank,
- world_size=world_size,
- handle_registration_callback=register_handle_callback,
- )
- def _configure_logger_and_profilers(
- self, logging_config: Union[None, Dict, LoggingConfig]
- ):
- if logging_config is None:
- logging_config = {}
- if isinstance(logging_config, dict):
- logging_config = LoggingConfig(**logging_config)
- configure_component_logger(
- component_type=ServeComponentType.REPLICA,
- component_name=self._component_name,
- component_id=self._component_id,
- logging_config=logging_config,
- buffer_size=RAY_SERVE_REQUEST_PATH_LOG_BUFFER_SIZE,
- )
- configure_component_memory_profiler(
- component_type=ServeComponentType.REPLICA,
- component_name=self._component_name,
- component_id=self._component_id,
- )
- if logging_config.encoding == EncodingType.JSON:
- # Create logging context for access logs as a performance optimization.
- # While logging_utils can automatically add Ray core and Serve access log context,
- # we pre-compute it here since context evaluation is expensive and this context
- # will be reused for multiple access log entries.
- ray_core_logging_context = CoreContextFilter.get_ray_core_logging_context()
- # remove task level log keys from ray core logging context, it would be nice
- # to have task level log keys here but we are letting those go in favor of
- # performance optimization. Also we cannot include task level log keys here because
- # they would referance the current task (__init__) and not the task that is logging.
- for key in CoreContextFilter.TASK_LEVEL_LOG_KEYS:
- ray_core_logging_context.pop(key, None)
- self._access_log_context = {
- **ray_core_logging_context,
- SERVE_LOG_DEPLOYMENT: self._component_name,
- SERVE_LOG_REPLICA: self._component_id,
- SERVE_LOG_COMPONENT: ServeComponentType.REPLICA,
- SERVE_LOG_APPLICATION: self._deployment_id.app_name,
- "skip_context_filter": True,
- "serve_access_log": True,
- }
- else:
- self._access_log_context = {
- "skip_context_filter": True,
- "serve_access_log": True,
- }
- def _can_accept_request(self, request_metadata: RequestMetadata) -> bool:
- # This replica gates concurrent request handling with an asyncio.Semaphore.
- # Each in-flight request acquires the semaphore. When the number of ongoing
- # requests reaches max_ongoing_requests, the semaphore becomes locked.
- # A new request can be accepted if the semaphore is currently unlocked.
- return not self._semaphore.locked()
- @contextmanager
- def _handle_errors_and_metrics(
- self, request_metadata: RequestMetadata
- ) -> Generator[StatusCodeCallback, None, None]:
- start_time = time.time()
- user_exception = None
- status_code = None
- def _status_code_callback(s: str):
- nonlocal status_code
- status_code = s
- try:
- yield _status_code_callback
- except asyncio.CancelledError as e:
- user_exception = e
- self._on_request_cancelled(request_metadata, e)
- except Exception as e:
- user_exception = e
- logger.exception("Request failed.")
- self._on_request_failed(request_metadata, e)
- latency_ms = (time.time() - start_time) * 1000
- self._record_errors_and_metrics(
- user_exception, status_code, latency_ms, request_metadata
- )
- if user_exception is not None:
- raise user_exception from None
- def _record_errors_and_metrics(
- self,
- user_exception: Optional[BaseException],
- status_code: Optional[str],
- latency_ms: float,
- request_metadata: RequestMetadata,
- ):
- http_method = request_metadata._http_method
- http_route = request_metadata.route
- call_method = request_metadata.call_method
- if user_exception is None:
- status_str = "OK"
- elif isinstance(user_exception, asyncio.CancelledError):
- status_str = "CANCELLED"
- else:
- status_str = "ERROR"
- # Mutating self._access_log_context is not thread safe, but since this
- # is only called from the same thread, it is safe. Mutating the same object
- # because creating a new dict is expensive.
- self._access_log_context[SERVE_LOG_ROUTE] = http_route
- self._access_log_context[SERVE_LOG_REQUEST_ID] = request_metadata.request_id
- logger.info(
- access_log_msg(
- method=http_method or "CALL",
- route=http_route if self._ingress and http_route else call_method,
- # Prefer the HTTP status code if it was populated.
- status=status_code or status_str,
- latency_ms=latency_ms,
- ),
- extra=self._access_log_context,
- )
- self._metrics_manager.record_request_metrics(
- route=http_route,
- latency_ms=latency_ms,
- was_error=user_exception is not None,
- )
- # Record ingress metrics for direct ingress HTTP requests
- if request_metadata.is_direct_ingress and status_code is not None:
- self._metrics_manager.record_ingress_request_metrics(
- protocol=RequestProtocol.HTTP,
- method=request_metadata._http_method,
- route=self._route_prefix,
- app_name=self._deployment_id.app_name,
- deployment_name=self._deployment_id.name,
- latency_ms=latency_ms,
- was_error=status_code.startswith(("4", "5")),
- status_code=status_code,
- )
- def _unpack_proxy_args(
- self,
- request_metadata: RequestMetadata,
- request_args: Tuple[Any],
- request_kwargs: Dict[str, Any],
- ) -> Tuple[Tuple[Any], Dict[str, Any], Any]:
- # Extract _ray_trace_ctx from kwargs at the entry point.
- #
- # Context: When tracing is enabled, Ray's tracing decorators inject
- # _ray_trace_ctx into ServeReplica actor method calls. The ServeReplica
- # actor methods properly handle this, but we
- # need to extract it before calling user-defined deployment methods.
- #
- # Design: We return it so it can be passed to _wrap_request() which
- # stores it in _RequestContext. Users can then access it via serve.context
- # if needed (advanced use case), while keeping it out of their method signatures.
- ray_trace_ctx = request_kwargs.pop("_ray_trace_ctx", None)
- if request_metadata.is_http_request:
- assert len(request_args) == 1 and isinstance(
- request_args[0], StreamingHTTPRequest
- )
- request: StreamingHTTPRequest = request_args[0]
- scope = request.asgi_scope
- receive = ASGIReceiveProxy(
- scope, request_metadata, request.receive_asgi_messages
- )
- request_metadata._http_method = scope.get("method", "WS")
- request_args = (scope, receive)
- elif request_metadata.is_grpc_request:
- assert len(request_args) == 1 and isinstance(request_args[0], gRPCRequest)
- request: gRPCRequest = request_args[0]
- method_info = self._user_callable_wrapper.get_user_method_info(
- request_metadata.call_method
- )
- request_args = (request.user_request_proto,)
- request_kwargs = (
- {GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context}
- if method_info.takes_grpc_context_kwarg
- else {}
- )
- return request_args, request_kwargs, ray_trace_ctx
- async def handle_request(
- self, request_metadata: RequestMetadata, *request_args, **request_kwargs
- ) -> Tuple[bytes, Any]:
- request_args, request_kwargs, ray_trace_ctx = self._unpack_proxy_args(
- request_metadata, request_args, request_kwargs
- )
- with self._wrap_request(request_metadata, ray_trace_ctx):
- async with self._start_request(request_metadata):
- try:
- return await self._user_callable_wrapper.call_user_method(
- request_metadata, request_args, request_kwargs
- )
- except Exception as e:
- # For gRPC requests, wrap exception with user-set status code
- raise self._maybe_wrap_grpc_exception(e, request_metadata) from e
- async def handle_request_streaming(
- self, request_metadata: RequestMetadata, *request_args, **request_kwargs
- ) -> AsyncGenerator[Any, None]:
- """Generator that is the entrypoint for all `stream=True` handle calls."""
- request_args, request_kwargs, ray_trace_ctx = self._unpack_proxy_args(
- request_metadata, request_args, request_kwargs
- )
- with self._wrap_request(
- request_metadata, ray_trace_ctx
- ) as status_code_callback:
- async with self._start_request(request_metadata):
- try:
- if request_metadata.is_http_request:
- scope, receive = request_args
- async for msgs in self._user_callable_wrapper.call_http_entrypoint(
- request_metadata,
- status_code_callback,
- scope,
- receive,
- ):
- yield pickle.dumps(msgs)
- else:
- async for result in self._user_callable_wrapper.call_user_generator(
- request_metadata,
- request_args,
- request_kwargs,
- ):
- yield result
- except Exception as e:
- # For gRPC requests, wrap exception with user-set status code
- raise self._maybe_wrap_grpc_exception(e, request_metadata) from e
- def _maybe_wrap_grpc_exception(
- self, e: BaseException, request_metadata: RequestMetadata
- ) -> BaseException:
- """Wrap exception with gRPCStatusError if user set a status code.
- For gRPC requests, if the user set a status code on the grpc_context before
- raising an exception, we wrap the exception with gRPCStatusError to preserve
- the user's intended status code through the error handling path.
- """
- if request_metadata.is_grpc_request:
- grpc_context = request_metadata.grpc_context
- if grpc_context and grpc_context.code():
- return gRPCStatusError(
- original_exception=e,
- code=grpc_context.code(),
- details=grpc_context.details(),
- )
- return e
- async def handle_request_with_rejection(
- self, request_metadata: RequestMetadata, *request_args, **request_kwargs
- ):
- # Check if the replica has capacity for the request.
- if not self._can_accept_request(request_metadata):
- limit = self.max_ongoing_requests
- logger.warning(
- f"Replica at capacity of max_ongoing_requests={limit}, "
- f"rejecting request {request_metadata.request_id}.",
- extra={"log_to_stderr": False},
- )
- yield ReplicaQueueLengthInfo(False, self.get_num_ongoing_requests())
- return
- request_args, request_kwargs, ray_trace_ctx = self._unpack_proxy_args(
- request_metadata, request_args, request_kwargs
- )
- with self._wrap_request(
- request_metadata, ray_trace_ctx
- ) as status_code_callback:
- async with self._start_request(request_metadata):
- yield ReplicaQueueLengthInfo(
- accepted=True,
- # NOTE(edoakes): `_wrap_request` will increment the number
- # of ongoing requests to include this one, so re-fetch the value.
- num_ongoing_requests=self.get_num_ongoing_requests(),
- )
- try:
- if request_metadata.is_http_request:
- scope, receive = request_args
- async for msgs in self._user_callable_wrapper.call_http_entrypoint(
- request_metadata,
- status_code_callback,
- scope,
- receive,
- ):
- yield pickle.dumps(msgs)
- elif request_metadata.is_streaming:
- async for result in self._user_callable_wrapper.call_user_generator(
- request_metadata,
- request_args,
- request_kwargs,
- ):
- yield result
- else:
- yield await self._user_callable_wrapper.call_user_method(
- request_metadata, request_args, request_kwargs
- )
- except Exception as e:
- # For gRPC requests, wrap exception with user-set status code
- raise self._maybe_wrap_grpc_exception(e, request_metadata) from e
- @abstractmethod
- async def _on_initialized(self):
- raise NotImplementedError
- async def initialize(
- self, deployment_config: Optional[DeploymentConfig], rank: Optional[ReplicaRank]
- ):
- if rank is not None:
- self._rank = rank
- self._set_internal_replica_context(
- servable_object=self._user_callable_wrapper.user_callable, rank=rank
- )
- try:
- # Ensure that initialization is only performed once.
- # When controller restarts, it will call this method again.
- async with self._user_callable_initialized_lock:
- self._initialization_start_time = time.time()
- if not self._user_callable_initialized:
- self._user_callable_asgi_app = (
- await self._user_callable_wrapper.initialize_callable()
- )
- if self._user_callable_asgi_app:
- self._docs_path = (
- self._user_callable_wrapper._callable.docs_path
- )
- await self._on_initialized()
- self._user_callable_initialized = True
- if self._user_callable_wrapper is not None:
- initialized = (
- hasattr(
- self._user_callable_wrapper, "_user_autoscaling_stats"
- )
- and self._user_callable_wrapper._user_autoscaling_stats
- is not None
- )
- self._metrics_manager.enable_custom_autoscaling_metrics(
- custom_metrics_enabled=initialized,
- record_autoscaling_stats_fn=self._user_callable_wrapper.call_record_autoscaling_stats,
- )
- if deployment_config is not None:
- await self._user_callable_wrapper.set_sync_method_threadpool_limit(
- deployment_config.max_ongoing_requests
- )
- rank = ray.serve.context._get_internal_replica_context().rank
- await self._user_callable_wrapper.call_reconfigure(
- deployment_config.user_config,
- rank=rank,
- )
- # A new replica should not be considered healthy until it passes
- # an initial health check. If an initial health check fails,
- # consider it an initialization failure.
- await self.check_health()
- except Exception:
- raise RuntimeError(traceback.format_exc()) from None
- async def reconfigure(
- self,
- deployment_config: DeploymentConfig,
- rank: ReplicaRank,
- route_prefix: Optional[str] = None,
- ):
- try:
- user_config_changed = (
- deployment_config.user_config != self._deployment_config.user_config
- )
- rank_changed = rank != self._rank
- self._rank = rank
- logging_config_changed = (
- deployment_config.logging_config
- != self._deployment_config.logging_config
- )
- self._deployment_config = deployment_config
- self._version = DeploymentVersion.from_deployment_version(
- self._version, deployment_config, route_prefix
- )
- self._metrics_manager.set_autoscaling_config(
- deployment_config.autoscaling_config
- )
- if logging_config_changed:
- self._configure_logger_and_profilers(deployment_config.logging_config)
- await self._user_callable_wrapper.set_sync_method_threadpool_limit(
- deployment_config.max_ongoing_requests
- )
- if user_config_changed or rank_changed:
- await self._user_callable_wrapper.call_reconfigure(
- deployment_config.user_config,
- rank=rank,
- )
- # We need to update internal replica context to reflect the new
- # deployment_config and rank.
- self._set_internal_replica_context(
- servable_object=self._user_callable_wrapper.user_callable,
- rank=rank,
- )
- self._route_prefix = self._version.route_prefix
- except Exception:
- raise RuntimeError(traceback.format_exc()) from None
- @abstractmethod
- def _on_request_cancelled(
- self, request_metadata: RequestMetadata, e: asyncio.CancelledError
- ):
- pass
- @abstractmethod
- def _on_request_failed(self, request_metadata: RequestMetadata, e: Exception):
- pass
- @abstractmethod
- @contextmanager
- def _wrap_request(
- self, request_metadata: RequestMetadata
- ) -> Generator[StatusCodeCallback, None, None]:
- pass
- @asynccontextmanager
- async def _start_request(self, request_metadata: RequestMetadata):
- async with self._semaphore:
- try:
- self._metrics_manager.inc_num_ongoing_requests(request_metadata)
- yield
- finally:
- self._metrics_manager.dec_num_ongoing_requests(request_metadata)
- async def _drain_ongoing_requests(self):
- """Wait for any ongoing requests to finish.
- Sleep for a grace period before the first time we check the number of ongoing
- requests to allow the notification to remove this replica to propagate to
- callers first.
- """
- wait_loop_period_s = self._deployment_config.graceful_shutdown_wait_loop_s
- while True:
- await asyncio.sleep(wait_loop_period_s)
- num_ongoing_requests = self._metrics_manager.get_num_ongoing_requests()
- if num_ongoing_requests > 0:
- logger.info(
- f"Waiting for an additional {wait_loop_period_s}s to shut down "
- f"because there are {num_ongoing_requests} ongoing requests."
- )
- else:
- logger.info(
- "Graceful shutdown complete; replica exiting.",
- extra={"log_to_stderr": False},
- )
- break
- async def shutdown(self):
- try:
- await self._user_callable_wrapper.call_destructor()
- except: # noqa: E722
- # We catch a blanket exception since the constructor may still be
- # running, so instance variables used by the destructor may not exist.
- if self._user_callable_initialized:
- logger.exception(
- "__del__ ran before replica finished initializing, and "
- "raised an exception."
- )
- else:
- logger.exception("__del__ raised an exception.")
- await self._metrics_manager.shutdown()
- async def perform_graceful_shutdown(self):
- self._shutting_down = True
- # If the replica was never initialized it never served traffic, so we
- # can skip the wait period.
- if self._user_callable_initialized:
- await self._drain_ongoing_requests()
- await self.shutdown()
- async def check_health(self):
- try:
- # If there's no user-defined health check, nothing runs on the user code event
- # loop and no future is returned.
- f = self._user_callable_wrapper.call_user_health_check()
- if f is not None:
- await f
- self._healthy = True
- except Exception as e:
- logger.warning("Replica health check failed.")
- self._healthy = False
- raise e from None
- async def record_routing_stats(self) -> Dict[str, Any]:
- try:
- f = self._user_callable_wrapper.call_user_record_routing_stats()
- if f is not None:
- return await f
- return {}
- except Exception as e:
- logger.warning("Replica record routing stats failed.")
- raise e from None
- async def send_http_response(message, status_code, send):
- for msg in convert_object_to_asgi_messages(
- message,
- status_code=status_code,
- ):
- await send(msg)
- class Replica(ReplicaBase):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._controller_handle = ray.get_actor(
- SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE
- )
- # get node ID
- self._node_id = ray.get_runtime_context().get_node_id()
- self._http_options: Optional[HTTPOptions] = None
- self._grpc_options: Optional[gRPCOptions] = None
- self._direct_ingress_http_server_task: Optional[asyncio.Task] = None
- self._direct_ingress_grpc_server_task: Optional[asyncio.Task] = None
- self._num_queued_requests = 0
- @property
- def max_queued_requests(self) -> int:
- return self._deployment_config.max_queued_requests
- async def _maybe_start_direct_ingress_servers(self):
- if not RAY_SERVE_ENABLE_DIRECT_INGRESS:
- return
- if not self._ingress:
- return
- async def allocate_and_start_server(start_server_fn, protocol):
- """Attempt to allocate a port and start the server with retries."""
- is_port_in_use = False
- for _ in range(RAY_SERVE_DIRECT_INGRESS_PORT_RETRY_COUNT):
- port = await self._controller_handle.allocate_replica_port.remote(
- self._node_id, self._replica_id.unique_id, protocol
- )
- logger.info(f"Allocated port {port} for {protocol}")
- try:
- server_task = await start_server_fn(port)
- logger.info(
- f"Successfully started {protocol} server on port {port}"
- )
- return port, server_task
- except RuntimeError as e:
- logger.warning(
- f"Failed to start {protocol} server on port {port}: {e}. Retrying..."
- )
- # `start_asgi_http_server` raises a RuntimeError with the original OSError as the cause.
- if isinstance(e.__cause__, OSError) and e.__cause__.errno in (
- errno.EADDRINUSE,
- errno.EADDRNOTAVAIL,
- ):
- is_port_in_use = True
- else:
- is_port_in_use = False
- # setting block_port to True because we are concluding that the port is
- # in use by another service on the same node. Blocking port here is a small
- # optimization to avoid trying to start the server on a the same port
- # multiple times by other replicas.
- await self._controller_handle.release_replica_port.remote(
- self._node_id,
- self._replica_id.unique_id,
- port,
- protocol,
- block_port=True,
- )
- err_msg = f"Failed to allocate and start {protocol} server after retries"
- if is_port_in_use:
- err_msg = f"""
- Failed to start {protocol} server: port already in use. Suggestion: Ensure that the Ray Serve direct ingress port ranges do not overlap with the Ray worker port range (min_worker_port to max_worker_port).
- """
- raise RuntimeError(err_msg)
- # Fetch configs
- self._http_options, self._grpc_options = ray.get(
- [
- self._controller_handle.get_http_config.remote(),
- self._controller_handle.get_grpc_config.remote(),
- ]
- )
- grpc_enabled = is_grpc_enabled(self._grpc_options)
- # Allocate and start HTTP server
- async def start_http_server(port):
- options = configure_http_middlewares(
- configure_http_options_with_defaults(
- HTTPOptions(**{**self._http_options.dict(), "port": port})
- )
- )
- return await start_asgi_http_server(
- self._direct_ingress_asgi,
- options,
- event_loop=self._event_loop,
- enable_so_reuseport=False,
- )
- (
- self._http_port,
- self._direct_ingress_http_server_task,
- ) = await allocate_and_start_server(
- start_server_fn=start_http_server,
- protocol=RequestProtocol.HTTP,
- )
- # Allocate and start gRPC server if enabled
- if grpc_enabled:
- async def start_grpc_server_fn(port):
- options = gRPCOptions(**{**self._grpc_options.dict(), "port": port})
- return await start_grpc_server(
- self._direct_ingress_service_handler_factory,
- options,
- event_loop=self._event_loop,
- enable_so_reuseport=False,
- )
- (
- self._grpc_port,
- self._direct_ingress_grpc_server_task,
- ) = await allocate_and_start_server(
- start_server_fn=start_grpc_server_fn,
- protocol=RequestProtocol.GRPC,
- )
- logger.info(
- f"Started HTTP server on port {self._http_port}"
- + (f" and gRPC server on port {self._grpc_port}" if grpc_enabled else "")
- )
- async def _on_initialized(self):
- await self._maybe_start_direct_ingress_servers()
- current_rank = ray.serve.context._get_internal_replica_context().rank
- self._set_internal_replica_context(
- servable_object=self._user_callable_wrapper.user_callable,
- rank=current_rank,
- )
- # Start the gRPC server for inter-deployment communication
- add_ASGIServiceServicer_to_server(self, self._server)
- self._internal_grpc_port = self._server.add_insecure_port("[::]:0")
- await self._server.start()
- logger.debug(
- f"Started inter-deployment gRPC server on port {self._internal_grpc_port}"
- )
- # Save the initialization latency if the replica is initializing
- # for the first time.
- if self._initialization_latency is None:
- self._initialization_latency = time.time() - self._initialization_start_time
- def _on_request_cancelled(
- self, metadata: RequestMetadata, e: asyncio.CancelledError
- ):
- """Recursively cancel child requests.
- This includes all requests that are pending assignment, and gRPC
- requests that have already been assigned.
- """
- # Cancel child requests pending assignment
- requests_pending_assignment = (
- ray.serve.context._get_requests_pending_assignment(
- metadata.internal_request_id
- )
- )
- for task in requests_pending_assignment.values():
- task.cancel()
- # Cancel child requests that have already been assigned.
- # This is for gRPC requests and direct ingress requests.
- in_flight_requests = _get_in_flight_requests(metadata.internal_request_id)
- for replica_result in in_flight_requests.values():
- replica_result.cancel()
- def _on_request_failed(self, request_metadata: RequestMetadata, e: Exception):
- if ray.util.pdb._is_ray_debugger_post_mortem_enabled():
- ray.util.pdb._post_mortem()
- def _can_accept_request(self, request_metadata: RequestMetadata):
- if request_metadata.is_direct_ingress:
- limit = self.max_queued_requests
- if limit != -1 and self._num_queued_requests >= limit:
- return False
- return True
- else:
- return super()._can_accept_request(request_metadata)
- @contextmanager
- def _wrap_request(
- self, request_metadata: RequestMetadata, ray_trace_ctx: Optional[Any] = None
- ) -> Generator[StatusCodeCallback, None, None]:
- """Context manager that wraps user method calls.
- 1) Sets the request context var with appropriate metadata.
- 2) Records the access log message (if not disabled).
- 3) Records per-request metrics via the metrics manager.
- """
- ray.serve.context._serve_request_context.set(
- ray.serve.context._RequestContext(
- route=request_metadata.route,
- request_id=request_metadata.request_id,
- _internal_request_id=request_metadata.internal_request_id,
- app_name=self._deployment_id.app_name,
- multiplexed_model_id=request_metadata.multiplexed_model_id,
- grpc_context=request_metadata.grpc_context,
- cancel_on_parent_request_cancel=self._ingress
- and RAY_SERVE_ENABLE_DIRECT_INGRESS,
- _ray_trace_ctx=ray_trace_ctx,
- )
- )
- with self._handle_errors_and_metrics(request_metadata) as status_code_callback:
- yield status_code_callback
- @_wrap_grpc_call
- async def HandleRequest(
- self,
- context: grpc.aio.ServicerContext,
- request_metadata: RequestMetadata,
- *request_args,
- **request_kwargs,
- ):
- result = await self.handle_request(
- request_metadata, *request_args, **request_kwargs
- )
- if request_metadata.is_grpc_request:
- result = (request_metadata.grpc_context, result.SerializeToString())
- return result
- @_wrap_grpc_call
- async def HandleRequestStreaming(
- self,
- context: grpc.aio.ServicerContext,
- request_metadata: RequestMetadata,
- *request_args,
- **request_kwargs,
- ):
- async for result in self.handle_request_streaming(
- request_metadata, *request_args, **request_kwargs
- ):
- if request_metadata.is_grpc_request:
- result = (request_metadata.grpc_context, result.SerializeToString())
- yield result
- @_wrap_grpc_call
- async def HandleRequestWithRejection(
- self,
- context: grpc.aio.ServicerContext,
- request_metadata: RequestMetadata,
- *request_args,
- **request_kwargs,
- ):
- """gRPC entrypoint for all unary requests with strict max_ongoing_requests enforcement
- This generator yields a system message indicating if the request was accepted,
- then the actual response.
- If an exception occurred while processing the request, whether it's a user
- exception or an error intentionally raised by Serve, it will be returned as
- a gRPC response instead of raised directly.
- """
- result_gen = self.handle_request_with_rejection(
- request_metadata, *request_args, **request_kwargs
- )
- queue_len_info: ReplicaQueueLengthInfo = await result_gen.__anext__()
- await context.send_initial_metadata(
- [
- ("accepted", str(int(queue_len_info.accepted))),
- ("num_ongoing_requests", str(queue_len_info.num_ongoing_requests)),
- ]
- )
- if not queue_len_info.accepted:
- # NOTE(edoakes): in gRPC, it's not guaranteed that the initial metadata sent
- # by the server will be delivered for a stream with no messages. Therefore,
- # we send a dummy message here to ensure it is populated in every case.
- return b""
- result = await result_gen.__anext__()
- # Consume the result generator to ensure all request operations are completed.
- async for _ in result_gen:
- pass
- if request_metadata.is_grpc_request:
- result = (request_metadata.grpc_context, result.SerializeToString())
- return result
- @_wrap_grpc_call
- async def HandleRequestWithRejectionStreaming(
- self,
- context: grpc.aio.ServicerContext,
- request_metadata: RequestMetadata,
- *request_args,
- **request_kwargs,
- ) -> AsyncGenerator[Any, None]:
- """gRPC entrypoint for all streaming requests with strict max_ongoing_requests enforcement
- This generator yields a system message indicating if the request was accepted,
- then the actual response(s).
- If an exception occurred while processing the request, whether it's a user
- exception or an error intentionally raised by Serve, it will be returned as
- a gRPC response instead of raised directly.
- """
- result_gen = self.handle_request_with_rejection(
- request_metadata, *request_args, **request_kwargs
- )
- queue_len_info: ReplicaQueueLengthInfo = await result_gen.__anext__()
- await context.send_initial_metadata(
- [
- ("accepted", str(int(queue_len_info.accepted))),
- ("num_ongoing_requests", str(queue_len_info.num_ongoing_requests)),
- ]
- )
- if not queue_len_info.accepted:
- # NOTE(edoakes): in gRPC, it's not guaranteed that the initial metadata sent
- # by the server will be delivered for a stream with no messages. Therefore,
- # we send a dummy message here to ensure it is populated in every case.
- yield b""
- return
- async for result in result_gen:
- if request_metadata.is_grpc_request:
- result = (request_metadata.grpc_context, result.SerializeToString())
- yield result
- async def _dataplane_health_check(self) -> Tuple[bool, str]:
- healthy, message = True, HEALTHY_MESSAGE
- if self._shutting_down:
- healthy = False
- message = "DRAINING"
- elif not self._healthy:
- healthy = False
- message = "UNHEALTHY"
- return healthy, message
- async def _direct_ingress_unary_unary(
- self,
- service_method: str,
- request_proto: Any,
- context: grpc._cython.cygrpc._ServicerContext,
- ) -> bytes:
- if service_method == "/ray.serve.RayServeAPIService/Healthz":
- healthy, message = await self._dataplane_health_check()
- context.set_code(
- grpc.StatusCode.OK if healthy else grpc.StatusCode.UNAVAILABLE
- )
- context.set_details(message)
- return HealthzResponse(message=message).SerializeToString()
- if service_method == "/ray.serve.RayServeAPIService/ListApplications":
- # NOTE(edoakes): ListApplications may be used for health checking.
- healthy, message = await self._dataplane_health_check()
- context.set_code(
- grpc.StatusCode.OK if healthy else grpc.StatusCode.UNAVAILABLE
- )
- context.set_details(message)
- # ListApplications returns only the app name the replica is serving.
- application_names = [self._deployment_id.app_name]
- return ListApplicationsResponse(
- application_names=application_names
- ).SerializeToString()
- request_id = generate_request_id()
- c = RayServegRPCContext(context)
- c.set_trailing_metadata([("request_id", request_id)])
- request_metadata = RequestMetadata(
- # TODO: pick up the request ID from gRPC initial metadata.
- request_id=request_id,
- internal_request_id=generate_request_id(),
- call_method=service_method.split("/")[-1],
- _request_protocol=RequestProtocol.GRPC,
- grpc_context=c,
- app_name=self._deployment_id.app_name,
- # TODO(edoakes): populate this.
- multiplexed_model_id="",
- route=self._deployment_id.app_name,
- tracing_context=None,
- is_streaming=False,
- is_direct_ingress=True,
- )
- if not self._can_accept_request(request_metadata):
- status = ResponseStatus(
- code=grpc.StatusCode.RESOURCE_EXHAUSTED,
- message="Request dropped due to backpressure",
- )
- set_grpc_code_and_details(context, status)
- return
- method_info = self._user_callable_wrapper.get_user_method_info(
- request_metadata.call_method
- )
- request_args = (request_proto,)
- request_kwargs = (
- {GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context}
- if method_info.takes_grpc_context_kwarg
- else {}
- )
- async def call_unary():
- yield await self._user_callable_wrapper.call_user_method(
- request_metadata, request_args, request_kwargs
- )
- with self._wrap_request(request_metadata):
- self._num_queued_requests += 1
- async with self._start_request(request_metadata):
- self._num_queued_requests -= 1
- # Use the generic disconnect detecting wrapper
- result_gen = call_unary()
- replica_response_generator = ReplicaResponseGenerator(
- result_gen,
- timeout_s=self._grpc_options.request_timeout_s,
- )
- try:
- result = await replica_response_generator.__anext__()
- c._set_on_grpc_context(context)
- status = ResponseStatus(code=grpc.StatusCode.OK)
- # NOTE(edoakes): we need to fully consume the generator otherwise the
- # finalizers that run after the `yield` statement won't run. There might
- # be a cleaner way to structure this.
- try:
- await replica_response_generator.__anext__()
- except StopAsyncIteration:
- pass
- except BaseException as e:
- # For gRPC requests, wrap exception with user-set status code
- e = self._maybe_wrap_grpc_exception(e, request_metadata)
- status = get_grpc_response_status(
- e,
- self._grpc_options.request_timeout_s,
- request_metadata.request_id,
- )
- return
- finally:
- set_grpc_code_and_details(context, status)
- return result.SerializeToString()
- async def _direct_ingress_unary_stream(
- self,
- service_method: str,
- request: Any,
- context: grpc._cython.cygrpc._ServicerContext,
- ):
- raise NotImplementedError("unary_stream not implemented.")
- def _direct_ingress_service_handler_factory(
- self, service_method: str, stream: bool
- ) -> Callable:
- if stream:
- async def handler(*args, **kwargs):
- return await self._direct_ingress_unary_stream(
- service_method, *args, **kwargs
- )
- else:
- async def handler(*args, **kwargs):
- return await self._direct_ingress_unary_unary(
- service_method, *args, **kwargs
- )
- return handler
- def _determine_http_route(self, scope: Scope) -> str:
- # Default to route prefix for consistency with non-DI mode
- route = self._route_prefix
- if self._user_callable_asgi_app is not None:
- try:
- matched_route = get_asgi_route_name(self._user_callable_asgi_app, scope)
- if matched_route is not None:
- route = matched_route
- except Exception:
- # If route matching fails, keep the route prefix
- pass
- return route
- def _parse_request_timeout(self, headers: Dict[str, str]) -> Optional[float]:
- """Gets the desired request timeout from the headers.
- If the header is missing or invalid, returns the default request timeout
- from HttpOptions. If the header is non-positive, timeout is disabled.
- """
- header_name = SERVE_HTTP_REQUEST_TIMEOUT_S_HEADER.encode("utf-8")
- if header_name not in headers:
- return self._http_options.request_timeout_s
- value = headers.get(header_name).decode("utf-8")
- try:
- timeout = float(value)
- if timeout > 0:
- return timeout
- return None
- except ValueError:
- return self._http_options.request_timeout_s
- async def _direct_ingress_asgi(
- self,
- scope: Scope,
- receive: Receive,
- send: Send,
- ):
- # NOTE(edoakes): it's important to only start the replica server after the
- # constructor runs because we are using SO_REUSEPORT. We don't want a new
- # replica to start handling connections until it's ready to serve traffic.
- #
- # This can be loosened to listen on the port but fail health checks once we no
- # longer rely on SO_REUSEPORT.
- assert (
- self._user_callable_initialized
- ), "Replica server should only be started *after* the replica is initialized."
- if self._route_prefix and self._route_prefix != "/":
- scope["root_path"] = self._route_prefix
- start_time = time.time()
- method = scope.get("method", "WS").upper()
- route = scope.get("path", "")
- # Handle health check or routes request.
- if route in ["/-/healthz", "/-/routes"]:
- healthy, message = await self._dataplane_health_check()
- status_code = 200 if healthy else 503
- if route == "/-/routes" and healthy:
- # routes endpoint returns only the route prefix andapp name the replica is serving.
- message = {
- self._route_prefix: self._deployment_id.app_name,
- }
- for msg in convert_object_to_asgi_messages(
- message,
- status_code=status_code,
- ):
- await send(msg)
- latency_ms = (time.time() - start_time) * 1000.0
- self._metrics_manager.record_ingress_request_metrics(
- protocol=RequestProtocol.HTTP,
- method=method,
- route=route,
- app_name=self._deployment_id.app_name,
- deployment_name=self._deployment_id.name,
- latency_ms=latency_ms,
- was_error=not healthy,
- status_code=str(status_code),
- )
- return
- # If the HTTP path does not match the deployment route prefix,
- # it is invalid and we should not serve it.
- if not route.startswith(self._route_prefix):
- for msg in convert_object_to_asgi_messages(
- f"Path '{route}' not found. "
- "Ping http://.../-/routes for available routes.",
- status_code=404,
- ):
- await send(msg)
- return
- headers = dict(scope["headers"])
- request_id = (
- headers.get(SERVE_HTTP_REQUEST_ID_HEADER.encode("utf-8")).decode("utf-8")
- or generate_request_id()
- )
- request_disconnect_disabled = (
- headers.get(
- SERVE_HTTP_REQUEST_DISCONNECT_DISABLED_HEADER.encode("utf-8"), b"?0"
- ).decode("utf-8")
- ) == "?1"
- request_timeout_s = self._parse_request_timeout(headers)
- request_metadata = RequestMetadata(
- request_id=request_id,
- internal_request_id=generate_request_id(),
- call_method="__call__",
- route=self._determine_http_route(scope),
- app_name=self._deployment_id.app_name,
- # TODO(edoakes): populate the multiplexed model ID.
- multiplexed_model_id="",
- is_streaming=True,
- _request_protocol=RequestProtocol.HTTP,
- tracing_context=None,
- _http_method=scope.get("method", "WS"),
- is_direct_ingress=True,
- )
- if not self._can_accept_request(request_metadata):
- # NOTE(abrar): its possible that we drop more requests than actual max_queued_requests
- # because between incrementing and decrementing the queued requests, we yield to the event loop.
- for msg in convert_object_to_asgi_messages(
- "Request dropped due to backpressure",
- status_code=503,
- ):
- await send(msg)
- return
- # Optimization: we can avoid creating an async receive task if the client
- # has disabled handling disconnects for this request.
- if request_disconnect_disabled:
- receive_proxy = receive
- receive_task = None
- else:
- receive_proxy = ASGIDIReceiveProxy(
- scope, receive, self._user_callable_wrapper.event_loop
- )
- receive_task = receive_proxy.fetch_until_disconnect_task()
- response_started = False
- response_finished = False
- first_message_peeked = False
- with self._wrap_request(request_metadata) as status_code_callback:
- self._num_queued_requests += 1
- async def send_user_message(msg: Dict):
- nonlocal response_started
- nonlocal response_finished
- nonlocal first_message_peeked
- if not first_message_peeked:
- first_message_peeked = True
- if msg["type"] == "http.response.start":
- status_code_callback(str(msg["status"]))
- await send(msg)
- response_started = True
- if msg.get("more_body") is False:
- response_finished = True
- async def call_asgi():
- async with self._start_request(request_metadata):
- self._num_queued_requests -= 1
- if (
- not self._user_callable_wrapper._run_user_code_in_separate_thread
- ):
- user_method_info = (
- self._user_callable_wrapper.get_user_method_info(
- request_metadata.call_method
- )
- )
- # `_call_http_entrypoint` will have already called
- # `send_user_message`, so the ASGI messages will have
- # already been sent back to the client.
- await self._user_callable_wrapper._call_http_entrypoint(
- user_method_info, scope, receive_proxy, send_user_message
- )
- else:
- async for asgi_messages in self._user_callable_wrapper.call_http_entrypoint(
- request_metadata, status_code_callback, scope, receive_proxy
- ):
- for message in asgi_messages:
- await send_user_message(message)
- # Optimization: if Serve doesn't need to handle disconnects and
- # timeouts for this request, we can avoid event loop overhead by
- # directly awaiting the user code.
- if receive_task is None and request_timeout_s is None:
- return await call_asgi()
- # Otherwise, we'd always need the call_asgi() task.
- request_task = asyncio.create_task(call_asgi())
- tasks = [request_task]
- if receive_task is not None:
- tasks.append(receive_task)
- done, _ = await asyncio.wait(
- tasks,
- timeout=request_timeout_s,
- return_when=asyncio.FIRST_COMPLETED,
- )
- # NOTE(zcin): it's possible that the request task has finished sending
- # all ASGI messages, but the task is suspended and before it can fully
- # complete, the client has sent a disconnect message after the request
- # is completed. That is why we check for `response_finished` here.
- if request_task in done or response_finished:
- if receive_task is not None:
- receive_task.cancel()
- await request_task
- elif receive_task in done:
- request_task.cancel()
- status_code_callback("499")
- if not response_started:
- msg = (
- f"Client for request {request_id} disconnected, "
- "cancelling request."
- )
- await send_http_response(msg, 499, send)
- raise asyncio.CancelledError
- else:
- request_task.cancel()
- status_code_callback("408")
- if not response_started:
- msg = (
- f"Request {request_id} timed out after "
- f"{self._http_options.request_timeout_s}s."
- )
- await send_http_response(msg, 408, send)
- raise asyncio.CancelledError
- async def perform_graceful_shutdown(self):
- if (
- RAY_SERVE_ENABLE_DIRECT_INGRESS
- and self._ingress
- and self._user_callable_initialized
- ):
- # In direct ingress mode, we need to wait at least
- # RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S to give external load
- # balancers (e.g., ALB) time to deregister the replica, in addition to
- # waiting for requests to drain.
- await asyncio.gather(
- asyncio.sleep(RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S),
- super().perform_graceful_shutdown(),
- )
- else:
- await super().perform_graceful_shutdown()
- # Cancel direct ingress HTTP/gRPC server tasks if they exist.
- if self._direct_ingress_http_server_task:
- self._direct_ingress_http_server_task.cancel()
- if self._direct_ingress_grpc_server_task:
- self._direct_ingress_grpc_server_task.cancel()
- class ReplicaActor:
- """Actor definition for replicas of Ray Serve deployments.
- This class defines the interface that the controller and deployment handles
- (i.e., from proxies and other replicas) use to interact with a replica.
- All interaction with the user-provided callable is done via the
- `UserCallableWrapper` class.
- """
- async def __init__(
- self,
- replica_id: ReplicaID,
- serialized_deployment_def: bytes,
- serialized_init_args: bytes,
- serialized_init_kwargs: bytes,
- deployment_config_proto_bytes: bytes,
- version: DeploymentVersion,
- ingress: bool,
- route_prefix: str,
- ):
- deployment_config = DeploymentConfig.from_proto_bytes(
- deployment_config_proto_bytes
- )
- deployment_def = cloudpickle.loads(serialized_deployment_def)
- if isinstance(deployment_def, str):
- deployment_def = _load_deployment_def_from_import_path(deployment_def)
- self._replica_impl: ReplicaBase = create_replica_impl(
- replica_id=replica_id,
- deployment_def=deployment_def,
- init_args=cloudpickle.loads(serialized_init_args),
- init_kwargs=cloudpickle.loads(serialized_init_kwargs),
- deployment_config=deployment_config,
- version=version,
- ingress=ingress,
- route_prefix=route_prefix,
- )
- def push_proxy_handle(self, handle: ActorHandle):
- # NOTE(edoakes): it's important to call a method on the proxy handle to
- # initialize its state in the C++ core worker.
- handle.pong.remote()
- def get_num_ongoing_requests(self) -> int:
- """Fetch the number of ongoing requests at this replica (queue length).
- This runs on a separate thread (using a Ray concurrency group) so it will
- not be blocked by user code.
- """
- return self._replica_impl.get_num_ongoing_requests()
- async def is_allocated(self) -> str:
- """poke the replica to check whether it's alive.
- When calling this method on an ActorHandle, it will complete as
- soon as the actor has started running. We use this mechanism to
- detect when a replica has been allocated a worker slot.
- At this time, the replica can transition from PENDING_ALLOCATION
- to PENDING_INITIALIZATION startup state.
- Returns:
- The PID, actor ID, node ID, node IP, and log filepath id of the replica.
- """
- return (
- os.getpid(),
- ray.get_runtime_context().get_actor_id(),
- ray.get_runtime_context().get_worker_id(),
- ray.get_runtime_context().get_node_id(),
- ray.util.get_node_ip_address(),
- ray.util.get_node_instance_id(),
- get_component_logger_file_path(),
- )
- def list_outbound_deployments(self) -> Optional[List[DeploymentID]]:
- return self._replica_impl.list_outbound_deployments()
- async def initialize_and_get_metadata(
- self, deployment_config: DeploymentConfig = None, rank: ReplicaRank = None
- ) -> ReplicaMetadata:
- """Handles initializing the replica.
- Returns: 5-tuple containing
- 1. DeploymentConfig of the replica
- 2. DeploymentVersion of the replica
- 3. Initialization duration in seconds
- 4. Port
- 5. FastAPI `docs_path`, if relevant (i.e. this is an ingress deployment integrated with FastAPI).
- """
- # Unused `_after` argument is for scheduling: passing an ObjectRef
- # allows delaying this call until after the `_after` call has returned.
- await self._replica_impl.initialize(deployment_config, rank)
- return self._replica_impl.get_metadata()
- async def check_health(self):
- await self._replica_impl.check_health()
- async def record_routing_stats(self) -> Dict[str, Any]:
- return await self._replica_impl.record_routing_stats()
- async def reconfigure(
- self, deployment_config, rank: ReplicaRank, route_prefix: Optional[str] = None
- ) -> ReplicaMetadata:
- await self._replica_impl.reconfigure(deployment_config, rank, route_prefix)
- return self._replica_impl.get_metadata()
- def _preprocess_request_args(
- self,
- pickled_request_metadata: bytes,
- request_args: Tuple[Any],
- ) -> Tuple[RequestMetadata, Tuple[Any]]:
- request_metadata = pickle.loads(pickled_request_metadata)
- if request_metadata.is_http_request or request_metadata.is_grpc_request:
- request_args = (pickle.loads(request_args[0]),)
- return request_metadata, request_args
- async def handle_request(
- self,
- pickled_request_metadata: bytes,
- *request_args,
- **request_kwargs,
- ) -> Tuple[bytes, Any]:
- """Entrypoint for `stream=False` calls."""
- request_metadata, request_args = self._preprocess_request_args(
- pickled_request_metadata, request_args
- )
- result = await self._replica_impl.handle_request(
- request_metadata, *request_args, **request_kwargs
- )
- if request_metadata.is_grpc_request:
- result = (request_metadata.grpc_context, result.SerializeToString())
- return result
- async def handle_request_streaming(
- self,
- pickled_request_metadata: bytes,
- *request_args,
- **request_kwargs,
- ) -> AsyncGenerator[Any, None]:
- """Generator that is the entrypoint for all `stream=True` handle calls."""
- request_metadata, request_args = self._preprocess_request_args(
- pickled_request_metadata, request_args
- )
- async for result in self._replica_impl.handle_request_streaming(
- request_metadata, *request_args, **request_kwargs
- ):
- if request_metadata.is_grpc_request:
- result = (request_metadata.grpc_context, result.SerializeToString())
- yield result
- async def handle_request_with_rejection(
- self,
- pickled_request_metadata: bytes,
- *request_args,
- **request_kwargs,
- ) -> AsyncGenerator[Any, None]:
- """Entrypoint for all requests with strict max_ongoing_requests enforcement.
- The first response from this generator is always a system message indicating
- if the request was accepted (the replica has capacity for the request) or
- rejected (the replica is already at max_ongoing_requests).
- For non-streaming requests, there will only be one more message, the unary
- result of the user request handler.
- For streaming requests, the subsequent messages will be the results of the
- user request handler (which must be a generator).
- """
- request_metadata, request_args = self._preprocess_request_args(
- pickled_request_metadata, request_args
- )
- async for result in self._replica_impl.handle_request_with_rejection(
- request_metadata, *request_args, **request_kwargs
- ):
- if isinstance(result, ReplicaQueueLengthInfo):
- yield pickle.dumps(result)
- else:
- if request_metadata.is_grpc_request:
- result = (request_metadata.grpc_context, result.SerializeToString())
- yield result
- async def handle_request_from_java(
- self,
- proto_request_metadata: bytes,
- *request_args,
- **request_kwargs,
- ) -> Any:
- from ray.serve.generated.serve_pb2 import (
- RequestMetadata as RequestMetadataProto,
- )
- proto = RequestMetadataProto.FromString(proto_request_metadata)
- request_metadata: RequestMetadata = RequestMetadata(
- request_id=proto.request_id,
- internal_request_id=proto.internal_request_id,
- call_method=proto.call_method,
- multiplexed_model_id=proto.multiplexed_model_id,
- route=proto.route,
- )
- return await self._replica_impl.handle_request(
- request_metadata, *request_args, **request_kwargs
- )
- async def perform_graceful_shutdown(self):
- await self._replica_impl.perform_graceful_shutdown()
- @dataclass
- class UserMethodInfo:
- """Wrapper for a user method and its relevant metadata."""
- callable: Callable
- name: str
- is_asgi_app: bool
- takes_any_args: bool
- takes_grpc_context_kwarg: bool
- @classmethod
- def from_callable(cls, c: Callable, *, is_asgi_app: bool) -> "UserMethodInfo":
- params = inspect.signature(c).parameters
- return cls(
- callable=c,
- name=c.__name__,
- is_asgi_app=is_asgi_app,
- takes_any_args=len(params) > 0,
- takes_grpc_context_kwarg=GRPC_CONTEXT_ARG_NAME in params,
- )
- class UserCallableWrapper:
- """Wraps a user-provided callable that is used to handle requests to a replica."""
- service_unavailable_exceptions = (BackPressureError, DeploymentUnavailableError)
- def __init__(
- self,
- deployment_def: Callable,
- init_args: Tuple,
- init_kwargs: Dict,
- *,
- deployment_id: DeploymentID,
- run_sync_methods_in_threadpool: bool,
- run_user_code_in_separate_thread: bool,
- local_testing_mode: bool,
- deployment_config: DeploymentConfig,
- actor_id: str,
- ray_actor_options: Optional[Dict] = None,
- ):
- if not (inspect.isfunction(deployment_def) or inspect.isclass(deployment_def)):
- raise TypeError(
- "deployment_def must be a function or class. Instead, its type was "
- f"{type(deployment_def)}."
- )
- self._deployment_def = deployment_def
- self._init_args = init_args
- self._init_kwargs = init_kwargs
- self._is_function = inspect.isfunction(deployment_def)
- self._deployment_id = deployment_id
- self._local_testing_mode = local_testing_mode
- self._destructor_called = False
- self._run_sync_methods_in_threadpool = run_sync_methods_in_threadpool
- self._run_user_code_in_separate_thread = run_user_code_in_separate_thread
- self._warned_about_sync_method_change = False
- self._cached_user_method_info: Dict[str, UserMethodInfo] = {}
- # This is for performance optimization https://docs.python.org/3/howto/logging.html#optimization
- self._is_enabled_for_debug = logger.isEnabledFor(logging.DEBUG)
- # Will be populated in `initialize_callable`.
- self._callable = None
- self._deployment_config = deployment_config
- self._ray_actor_options = ray_actor_options or {}
- self._user_code_threadpool: Optional[
- concurrent.futures.ThreadPoolExecutor
- ] = None
- if self._run_user_code_in_separate_thread:
- # All interactions with user code run on this loop to avoid blocking the
- # replica's main event loop.
- self._user_code_event_loop: asyncio.AbstractEventLoop = (
- asyncio.new_event_loop()
- )
- # Start event loop monitoring for the user code event loop.
- # We create the monitor here but start it inside the thread function
- # so the task is created on the correct thread.
- self._user_code_loop_monitor = EventLoopMonitor(
- component=EventLoopMonitor.COMPONENT_REPLICA,
- loop_type=EventLoopMonitor.LOOP_TYPE_USER_CODE,
- actor_id=actor_id,
- extra_tags={
- "deployment": self._deployment_id.name,
- "application": self._deployment_id.app_name,
- },
- )
- def _run_user_code_event_loop():
- # Required so that calls to get the current running event loop work
- # properly in user code.
- asyncio.set_event_loop(self._user_code_event_loop)
- self._configure_user_code_threadpool()
- # Start monitoring before run_forever so the task is scheduled.
- self._user_code_loop_monitor.start(self._user_code_event_loop)
- self._user_code_event_loop.run_forever()
- self._user_code_event_loop_thread = threading.Thread(
- daemon=True,
- target=_run_user_code_event_loop,
- )
- self._user_code_event_loop_thread.start()
- else:
- self._user_code_event_loop = asyncio.get_running_loop()
- self._user_code_loop_monitor = None
- self._configure_user_code_threadpool()
- @property
- def event_loop(self) -> asyncio.AbstractEventLoop:
- return self._user_code_event_loop
- def _get_user_code_threadpool_max_workers(self) -> Optional[int]:
- num_cpus = self._ray_actor_options.get("num_cpus")
- if num_cpus is None:
- return None
- # Mirror ThreadPoolExecutor default behavior while respecting num_cpus.
- return min(32, max(1, int(math.ceil(num_cpus))) + 4)
- def _configure_user_code_threadpool(self) -> None:
- max_workers = self._get_user_code_threadpool_max_workers()
- if max_workers is None:
- return
- self._user_code_threadpool = concurrent.futures.ThreadPoolExecutor(
- max_workers=max_workers
- )
- self._user_code_event_loop.set_default_executor(self._user_code_threadpool)
- def _run_user_code(f: Callable) -> Callable:
- """Decorator to run a coroutine method on the user code event loop.
- The method will be modified to be a sync function that returns a
- `asyncio.Future` if user code is running in a separate event loop.
- Otherwise, it will return the coroutine directly.
- """
- assert inspect.iscoroutinefunction(
- f
- ), "_run_user_code can only be used on coroutine functions."
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs) -> Any:
- coro = f(self, *args, **kwargs)
- if self._run_user_code_in_separate_thread:
- fut = asyncio.run_coroutine_threadsafe(coro, self._user_code_event_loop)
- if self._local_testing_mode:
- return fut
- return asyncio.wrap_future(fut)
- else:
- return coro
- return wrapper
- @_run_user_code
- async def set_sync_method_threadpool_limit(self, limit: int):
- # NOTE(edoakes): the limit is thread local, so this must
- # be run on the user code event loop.
- to_thread.current_default_thread_limiter().total_tokens = limit
- def get_user_method_info(self, method_name: str) -> UserMethodInfo:
- """Get UserMethodInfo for the provided call method name.
- This method is cached to avoid repeated expensive calls to `inspect.signature`.
- """
- if method_name in self._cached_user_method_info:
- return self._cached_user_method_info[method_name]
- if self._is_function:
- user_method = self._callable
- elif hasattr(self._callable, method_name):
- user_method = getattr(self._callable, method_name)
- else:
- # Filter to methods that don't start with '__' prefix.
- def callable_method_filter(attr):
- if attr.startswith("__"):
- return False
- elif not callable(getattr(self._callable, attr)):
- return False
- return True
- methods = list(filter(callable_method_filter, dir(self._callable)))
- raise RayServeException(
- f"Tried to call a method '{method_name}' "
- "that does not exist. Available methods: "
- f"{methods}."
- )
- info = UserMethodInfo.from_callable(
- user_method,
- is_asgi_app=isinstance(self._callable, ASGIAppReplicaWrapper),
- )
- self._cached_user_method_info[method_name] = info
- return info
- async def _send_user_result_over_asgi(
- self,
- result: Any,
- asgi_args: ASGIArgs,
- ):
- """Handle the result from user code and send it over the ASGI interface.
- If the result is already a Response type, it is sent directly. Otherwise, it
- is converted to a custom Response type that handles serialization for
- common Python objects.
- """
- scope, receive, send = asgi_args.to_args_tuple()
- if isinstance(result, starlette.responses.Response):
- await result(scope, receive, send)
- else:
- await Response(result).send(scope, receive, send)
- async def _call_func_or_gen(
- self,
- callable: Callable,
- *,
- args: Optional[Tuple[Any]] = None,
- kwargs: Optional[Dict[str, Any]] = None,
- is_streaming: bool = False,
- generator_result_callback: Optional[Callable] = None,
- run_sync_methods_in_threadpool_override: Optional[bool] = None,
- ) -> Tuple[Any, bool]:
- """Call the callable with the provided arguments.
- This is a convenience wrapper that will work for `def`, `async def`,
- generator, and async generator functions.
- Returns the result and a boolean indicating if the result was a sync generator
- that has already been consumed.
- """
- sync_gen_consumed = False
- args = args if args is not None else tuple()
- kwargs = kwargs if kwargs is not None else dict()
- run_sync_in_threadpool = (
- self._run_sync_methods_in_threadpool
- if run_sync_methods_in_threadpool_override is None
- else run_sync_methods_in_threadpool_override
- )
- is_sync_method = (
- inspect.isfunction(callable) or inspect.ismethod(callable)
- ) and not (
- inspect.iscoroutinefunction(callable)
- or inspect.isasyncgenfunction(callable)
- )
- if is_sync_method and run_sync_in_threadpool:
- is_generator = inspect.isgeneratorfunction(callable)
- if is_generator:
- sync_gen_consumed = True
- if not is_streaming:
- # TODO(edoakes): make this check less redundant with the one in
- # _handle_user_method_result.
- raise TypeError(
- f"Method '{callable.__name__}' returned a generator. "
- "You must use `handle.options(stream=True)` to call "
- "generators on a deployment."
- )
- def run_callable():
- result = callable(*args, **kwargs)
- if is_generator:
- for r in result:
- generator_result_callback(r)
- result = None
- return result
- # NOTE(edoakes): we use anyio.to_thread here because it's what Starlette
- # uses (and therefore FastAPI too). The max size of the threadpool is
- # set to max_ongoing_requests in the replica wrapper.
- # anyio.to_thread propagates ContextVars to the worker thread automatically.
- result = await to_thread.run_sync(run_callable)
- else:
- if (
- is_sync_method
- and not self._warned_about_sync_method_change
- and run_sync_methods_in_threadpool_override is None
- ):
- self._warned_about_sync_method_change = True
- warnings.warn(
- RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING.format(
- method_name=callable.__name__,
- )
- )
- result = callable(*args, **kwargs)
- if inspect.iscoroutine(result):
- result = await result
- return result, sync_gen_consumed
- @property
- def user_callable(self) -> Optional[Callable]:
- return self._callable
- async def _initialize_asgi_callable(self) -> None:
- self._callable: ASGIAppReplicaWrapper
- app: Starlette = self._callable.app
- # The reason we need to do this is because BackPressureError is a serve internal exception
- # and FastAPI doesn't know how to handle it, so it treats it as a 500 error.
- # With same reasoning, we are not handling TimeoutError because it's a generic exception
- # the FastAPI knows how to handle. See https://www.starlette.io/exceptions/
- def handle_exception(_: Request, exc: Exception):
- return self.handle_exception(exc)
- for exc in self.service_unavailable_exceptions:
- app.add_exception_handler(exc, handle_exception)
- await self._callable._run_asgi_lifespan_startup()
- @_run_user_code
- async def initialize_callable(self) -> Optional[ASGIApp]:
- """Initialize the user callable.
- If the callable is an ASGI app wrapper (e.g., using @serve.ingress), returns
- the ASGI app object, which may be used *read only* by the caller.
- """
- if self._callable is not None:
- raise RuntimeError("initialize_callable should only be called once.")
- # This closure initializes user code and finalizes replica
- # startup. By splitting the initialization step like this,
- # we can already access this actor before the user code
- # has finished initializing.
- # The supervising state manager can then wait
- # for allocation of this replica by using the `is_allocated`
- # method. After that, it calls `reconfigure` to trigger
- # user code initialization.
- logger.info(
- "Started initializing replica.",
- extra={"log_to_stderr": False},
- )
- if self._is_function:
- self._callable = self._deployment_def
- else:
- # This allows deployments to define an async __init__
- # method (mostly used for testing).
- self._callable = self._deployment_def.__new__(self._deployment_def)
- await self._call_func_or_gen(
- self._callable.__init__,
- args=self._init_args,
- kwargs=self._init_kwargs,
- # Always run the constructor on the main user code thread.
- run_sync_methods_in_threadpool_override=False,
- )
- if isinstance(self._callable, ASGIAppReplicaWrapper):
- await self._initialize_asgi_callable()
- if isinstance(self._callable, TaskConsumerWrapper):
- self._callable.initialize_callable(
- self._deployment_config.max_ongoing_requests
- )
- ServeUsageTag.NUM_REPLICAS_USING_ASYNCHRONOUS_INFERENCE.record("1")
- self._user_health_check = getattr(self._callable, HEALTH_CHECK_METHOD, None)
- self._user_record_routing_stats = getattr(
- self._callable, REQUEST_ROUTING_STATS_METHOD, None
- )
- self._user_autoscaling_stats = getattr(
- self._callable, "record_autoscaling_stats", None
- )
- logger.info(
- "Finished initializing replica.",
- extra={"log_to_stderr": False},
- )
- return (
- self._callable.app
- if isinstance(self._callable, ASGIAppReplicaWrapper)
- else None
- )
- def _raise_if_not_initialized(self, method_name: str):
- if self._callable is None:
- raise RuntimeError(
- f"`initialize_callable` must be called before `{method_name}`."
- )
- def call_user_health_check(self) -> Optional[concurrent.futures.Future]:
- self._raise_if_not_initialized("call_user_health_check")
- # If the user provided a health check, call it on the user code thread. If user
- # code blocks the event loop the health check may time out.
- #
- # To avoid this issue for basic cases without a user-defined health check, skip
- # interacting with the user callable entirely.
- if self._user_health_check is not None:
- return self._call_user_health_check()
- return None
- def call_user_record_routing_stats(self) -> Optional[concurrent.futures.Future]:
- self._raise_if_not_initialized("call_user_record_routing_stats")
- if self._user_record_routing_stats is not None:
- return self._call_user_record_routing_stats()
- return None
- def call_record_autoscaling_stats(self) -> Optional[concurrent.futures.Future]:
- self._raise_if_not_initialized("call_record_autoscaling_stats")
- if self._user_autoscaling_stats is not None:
- return self._call_user_autoscaling_stats()
- return None
- @_run_user_code
- async def _call_user_health_check(self):
- await self._call_func_or_gen(self._user_health_check)
- @_run_user_code
- async def _call_user_record_routing_stats(self) -> Dict[str, Any]:
- result, _ = await self._call_func_or_gen(self._user_record_routing_stats)
- return result
- @_run_user_code
- async def _call_user_autoscaling_stats(self) -> Dict[str, Union[int, float]]:
- result, _ = await self._call_func_or_gen(self._user_autoscaling_stats)
- return result
- @_run_user_code
- async def call_reconfigure(self, user_config: Optional[Any], rank: ReplicaRank):
- self._raise_if_not_initialized("call_reconfigure")
- # NOTE(edoakes): there is the possibility of a race condition in user code if
- # they don't have any form of concurrency control between `reconfigure` and
- # other methods. See https://github.com/ray-project/ray/pull/42159.
- # NOTE(abrar): The only way to subscribe to rank changes is to provide some user config.
- # We can relax this in the future as more use cases arise for rank. I am reluctant to
- # introduce behavior change for a feature we might not need.
- user_subscribed_to_rank = False
- if not self._is_function and hasattr(self._callable, RECONFIGURE_METHOD):
- reconfigure_method = getattr(self._callable, RECONFIGURE_METHOD)
- params = inspect.signature(reconfigure_method).parameters
- user_subscribed_to_rank = "rank" in params
- if user_config is not None or user_subscribed_to_rank:
- if self._is_function:
- raise ValueError(
- "deployment_def must be a class to use user_config or rank"
- )
- elif not hasattr(self._callable, RECONFIGURE_METHOD):
- raise RayServeException(
- "user_config or rank specified but deployment "
- + self._deployment_id
- + " missing "
- + RECONFIGURE_METHOD
- + " method"
- )
- kwargs = {}
- if user_subscribed_to_rank:
- # For backwards compatibility, only pass rank if it is an argument to the reconfigure method.
- kwargs["rank"] = rank
- await self._call_func_or_gen(
- getattr(self._callable, RECONFIGURE_METHOD),
- args=(user_config,),
- kwargs=kwargs,
- )
- async def _handle_user_method_result(
- self,
- result: Any,
- user_method_info: UserMethodInfo,
- *,
- is_streaming: bool,
- is_http_request: bool,
- sync_gen_consumed: bool,
- generator_result_callback: Optional[Callable],
- asgi_args: Optional[ASGIArgs],
- ) -> Any:
- """Postprocess the result of a user method.
- User methods can be regular unary functions or return a sync or async generator.
- This method will raise an exception if the result is not of the expected type
- (e.g., non-generator for streaming requests or generator for unary requests).
- Generator outputs will be written to the `generator_result_callback`.
- Note that HTTP requests are an exception: they are *always* streaming requests,
- but for ASGI apps (like FastAPI), the actual method will be a regular function
- implementing the ASGI `__call__` protocol.
- """
- result_is_gen = inspect.isgenerator(result)
- result_is_async_gen = inspect.isasyncgen(result)
- if is_streaming:
- if result_is_gen:
- for r in result:
- generator_result_callback(r)
- elif result_is_async_gen:
- async for r in result:
- generator_result_callback(r)
- elif is_http_request and not user_method_info.is_asgi_app:
- # For the FastAPI codepath, the response has already been sent over
- # ASGI, but for the vanilla deployment codepath we need to send it.
- await self._send_user_result_over_asgi(result, asgi_args)
- elif not is_http_request and not sync_gen_consumed:
- # If a unary method is called with stream=True for anything EXCEPT
- # an HTTP request, raise an error.
- # HTTP requests are always streaming regardless of if the method
- # returns a generator, because it's provided the result queue as its
- # ASGI `send` interface to stream back results.
- raise TypeError(
- f"Called method '{user_method_info.name}' with "
- "`handle.options(stream=True)` but it did not return a "
- "generator."
- )
- else:
- assert (
- not is_http_request
- ), "All HTTP requests go through the streaming codepath."
- if result_is_gen or result_is_async_gen:
- raise TypeError(
- f"Method '{user_method_info.name}' returned a generator. "
- "You must use `handle.options(stream=True)` to call "
- "generators on a deployment."
- )
- return result
- async def call_http_entrypoint(
- self,
- request_metadata: RequestMetadata,
- status_code_callback: StatusCodeCallback,
- scope: Scope,
- receive: Receive,
- ) -> Any:
- result_queue = MessageQueue()
- user_method_info = self.get_user_method_info(request_metadata.call_method)
- if self._run_user_code_in_separate_thread:
- # `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
- # used to interact with the result queue from the user callable thread.
- system_event_loop = asyncio.get_running_loop()
- async def enqueue(item: Any):
- system_event_loop.call_soon_threadsafe(result_queue.put_nowait, item)
- call_future = self._call_http_entrypoint(
- user_method_info, scope, receive, enqueue
- )
- else:
- async def enqueue(item: Any):
- result_queue.put_nowait(item)
- call_future = asyncio.create_task(
- self._call_http_entrypoint(user_method_info, scope, receive, enqueue)
- )
- first_message_peeked = False
- async for messages in result_queue.fetch_messages_from_queue(call_future):
- # HTTP (ASGI) messages are only consumed by the proxy so batch them
- # and use vanilla pickle (we know it's safe because these messages
- # only contain primitive Python types).
- # Peek the first ASGI message to determine the status code.
- if not first_message_peeked:
- msg = messages[0]
- first_message_peeked = True
- if msg["type"] == "http.response.start":
- # HTTP responses begin with exactly one
- # "http.response.start" message containing the "status"
- # field. Other response types like WebSockets may not.
- status_code_callback(str(msg["status"]))
- yield messages
- @_run_user_code
- async def _call_http_entrypoint(
- self,
- user_method_info: UserMethodInfo,
- scope: Scope,
- receive: Receive,
- send: Send,
- ) -> Any:
- """Call an HTTP entrypoint.
- `send` is used to communicate the results of streaming responses.
- Raises any exception raised by the user code so it can be propagated as a
- `RayTaskError`.
- """
- self._raise_if_not_initialized("_call_http_entrypoint")
- if self._is_enabled_for_debug:
- logger.debug(
- f"Started executing request to method '{user_method_info.name}'.",
- extra={"log_to_stderr": False, "serve_access_log": True},
- )
- if user_method_info.is_asgi_app:
- request_args = (scope, receive, send)
- elif not user_method_info.takes_any_args:
- # Edge case to support empty HTTP handlers: don't pass the Request
- # argument if the callable has no parameters.
- request_args = tuple()
- else:
- # Non-FastAPI HTTP handlers take only the starlette `Request`.
- request_args = (starlette.requests.Request(scope, receive, send),)
- receive_task = None
- try:
- if hasattr(receive, "fetch_until_disconnect"):
- receive_task = asyncio.create_task(receive.fetch_until_disconnect())
- result, sync_gen_consumed = await self._call_func_or_gen(
- user_method_info.callable,
- args=request_args,
- kwargs={},
- is_streaming=True,
- generator_result_callback=send,
- )
- final_result = await self._handle_user_method_result(
- result,
- user_method_info,
- is_streaming=True,
- is_http_request=True,
- sync_gen_consumed=sync_gen_consumed,
- generator_result_callback=send,
- asgi_args=ASGIArgs(scope, receive, send),
- )
- if receive_task is not None and not receive_task.done():
- receive_task.cancel()
- return final_result
- except Exception as e:
- if not user_method_info.is_asgi_app:
- response = self.handle_exception(e)
- await self._send_user_result_over_asgi(
- response, ASGIArgs(scope, receive, send)
- )
- if receive_task is not None and not receive_task.done():
- receive_task.cancel()
- raise
- except asyncio.CancelledError:
- if receive_task is not None and not receive_task.done():
- # Do NOT cancel the receive task if the request has been
- # cancelled, but the call is a batched call. This is
- # because we cannot guarantee cancelling the batched
- # call, so in the case that the call continues executing
- # we should continue fetching data from the client.
- if not hasattr(user_method_info.callable, "set_max_batch_size"):
- receive_task.cancel()
- raise
- async def call_user_generator(
- self,
- request_metadata: RequestMetadata,
- request_args: Tuple[Any],
- request_kwargs: Dict[str, Any],
- ) -> AsyncGenerator[Any, None]:
- """Calls a user method for a streaming call and yields its results.
- The user method is called in an asyncio `Task` and places its results on a
- `result_queue`. This method pulls and yields from the `result_queue`.
- """
- if not self._run_user_code_in_separate_thread:
- gen = await self._call_user_generator(
- request_metadata, request_args, request_kwargs
- )
- async for result in gen:
- yield result
- else:
- result_queue = MessageQueue()
- # `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
- # used to interact with the result queue from the user callable thread.
- system_event_loop = asyncio.get_running_loop()
- def _enqueue_thread_safe(item: Any):
- system_event_loop.call_soon_threadsafe(result_queue.put_nowait, item)
- call_future = self._call_user_generator(
- request_metadata,
- request_args,
- request_kwargs,
- enqueue=_enqueue_thread_safe,
- )
- async for messages in result_queue.fetch_messages_from_queue(call_future):
- for msg in messages:
- yield msg
- @_run_user_code
- async def _call_user_generator(
- self,
- request_metadata: RequestMetadata,
- request_args: Tuple[Any],
- request_kwargs: Dict[str, Any],
- *,
- enqueue: Optional[Callable] = None,
- ) -> Optional[AsyncGenerator[Any, None]]:
- """Call a user generator.
- The `generator_result_callback` is used to communicate the results of generator
- methods.
- Raises any exception raised by the user code so it can be propagated as a
- `RayTaskError`.
- """
- self._raise_if_not_initialized("_call_user_generator")
- request_args = request_args if request_args is not None else tuple()
- request_kwargs = request_kwargs if request_kwargs is not None else dict()
- user_method_info = self.get_user_method_info(request_metadata.call_method)
- callable = user_method_info.callable
- is_sync_method = (
- inspect.isfunction(callable) or inspect.ismethod(callable)
- ) and not (
- inspect.iscoroutinefunction(callable)
- or inspect.isasyncgenfunction(callable)
- )
- if self._is_enabled_for_debug:
- logger.debug(
- f"Started executing request to method '{user_method_info.name}'.",
- extra={"log_to_stderr": False, "serve_access_log": True},
- )
- async def _call_generator_async() -> AsyncGenerator[Any, None]:
- gen = callable(*request_args, **request_kwargs)
- if inspect.iscoroutine(gen):
- gen = await gen
- if inspect.isgenerator(gen):
- for result in gen:
- yield result
- elif inspect.isasyncgen(gen):
- async for result in gen:
- yield result
- else:
- raise TypeError(
- f"Called method '{user_method_info.name}' with "
- "`handle.options(stream=True)` but it did not return a generator."
- )
- def _call_generator_sync():
- gen = callable(*request_args, **request_kwargs)
- if inspect.isgenerator(gen):
- for result in gen:
- enqueue(result)
- else:
- raise TypeError(
- f"Called method '{user_method_info.name}' with "
- "`handle.options(stream=True)` but it did not return a generator."
- )
- if enqueue and is_sync_method and self._run_sync_methods_in_threadpool:
- await to_thread.run_sync(_call_generator_sync)
- elif enqueue:
- async def gen_coro_wrapper():
- async for result in _call_generator_async():
- enqueue(result)
- await gen_coro_wrapper()
- else:
- return _call_generator_async()
- @_run_user_code
- async def call_user_method(
- self,
- request_metadata: RequestMetadata,
- request_args: Tuple[Any],
- request_kwargs: Dict[str, Any],
- ) -> Any:
- """Call a (unary) user method.
- Raises any exception raised by the user code so it can be propagated as a
- `RayTaskError`.
- """
- self._raise_if_not_initialized("call_user_method")
- if self._is_enabled_for_debug:
- logger.debug(
- f"Started executing request to method '{request_metadata.call_method}'.",
- extra={"log_to_stderr": False, "serve_access_log": True},
- )
- user_method_info = self.get_user_method_info(request_metadata.call_method)
- result, _ = await self._call_func_or_gen(
- user_method_info.callable,
- args=request_args,
- kwargs=request_kwargs,
- is_streaming=False,
- )
- if inspect.isgenerator(result) or inspect.isasyncgen(result):
- raise TypeError(
- f"Method '{user_method_info.name}' returned a generator. "
- "You must use `handle.options(stream=True)` to call "
- "generators on a deployment."
- )
- return result
- def handle_exception(self, exc: Exception):
- if isinstance(exc, self.service_unavailable_exceptions):
- return starlette.responses.Response(exc.message, status_code=503)
- else:
- return starlette.responses.Response(
- "Internal Server Error", status_code=500
- )
- @_run_user_code
- async def call_destructor(self):
- """Explicitly call the `__del__` method of the user callable.
- Calling this multiple times has no effect; only the first call will
- actually call the destructor.
- """
- if self._callable is None:
- logger.debug(
- "This replica has not yet started running user code. "
- "Skipping __del__."
- )
- return
- # Only run the destructor once. This is safe because there is no `await` between
- # checking the flag here and flipping it to `True` below.
- if self._destructor_called:
- return
- self._destructor_called = True
- try:
- if hasattr(self._callable, "__del__"):
- # Make sure to accept `async def __del__(self)` as well.
- await self._call_func_or_gen(
- self._callable.__del__,
- # Always run the destructor on the main user callable thread.
- run_sync_methods_in_threadpool_override=False,
- )
- if hasattr(self._callable, "__serve_multiplex_wrapper"):
- await getattr(self._callable, "__serve_multiplex_wrapper").shutdown()
- except Exception as e:
- logger.exception(f"Exception during graceful shutdown of replica: {e}")
- finally:
- if self._user_code_threadpool is not None:
- self._user_code_threadpool.shutdown(wait=False)
|