# replica.py (134 KB)
  1. import asyncio
  2. import concurrent.futures
  3. import errno
  4. import functools
  5. import inspect
  6. import logging
  7. import math
  8. import os
  9. import pickle
  10. import threading
  11. import time
  12. import traceback
  13. import warnings
  14. from abc import ABC, abstractmethod
  15. from collections import defaultdict, deque
  16. from contextlib import asynccontextmanager, contextmanager
  17. from dataclasses import dataclass
  18. from functools import wraps
  19. from importlib import import_module
  20. from typing import (
  21. Any,
  22. AsyncGenerator,
  23. Callable,
  24. Dict,
  25. Generator,
  26. List,
  27. Optional,
  28. Set,
  29. Tuple,
  30. Union,
  31. )
  32. import grpc
  33. import starlette.responses
  34. from anyio import to_thread
  35. from fastapi import Request
  36. from starlette.applications import Starlette
  37. from starlette.types import ASGIApp, Receive, Scope, Send
  38. import ray
  39. from ray import cloudpickle
  40. from ray._common.filters import CoreContextFilter
  41. from ray._common.utils import get_or_create_event_loop
  42. from ray.actor import ActorClass, ActorHandle
  43. from ray.dag.py_obj_scanner import _PyObjScanner
  44. from ray.remote_function import RemoteFunction
  45. from ray.serve import metrics
  46. from ray.serve._private.common import (
  47. RUNNING_REQUESTS_KEY,
  48. DeploymentID,
  49. ReplicaID,
  50. ReplicaMetricReport,
  51. ReplicaQueueLengthInfo,
  52. RequestMetadata,
  53. RequestProtocol,
  54. ServeComponentType,
  55. StreamingHTTPRequest,
  56. gRPCRequest,
  57. )
  58. from ray.serve._private.config import DeploymentConfig
  59. from ray.serve._private.constants import (
  60. GRPC_CONTEXT_ARG_NAME,
  61. HEALTH_CHECK_METHOD,
  62. HEALTHY_MESSAGE,
  63. RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
  64. RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S,
  65. RAY_SERVE_DIRECT_INGRESS_PORT_RETRY_COUNT,
  66. RAY_SERVE_ENABLE_DIRECT_INGRESS,
  67. RAY_SERVE_METRICS_EXPORT_INTERVAL_MS,
  68. RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S,
  69. RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
  70. RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
  71. RAY_SERVE_REQUEST_PATH_LOG_BUFFER_SIZE,
  72. RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
  73. RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING,
  74. RAY_SERVE_RUN_USER_CODE_IN_SEPARATE_THREAD,
  75. RECONFIGURE_METHOD,
  76. REQUEST_LATENCY_BUCKETS_MS,
  77. REQUEST_ROUTING_STATS_METHOD,
  78. SERVE_CONTROLLER_NAME,
  79. SERVE_HTTP_REQUEST_DISCONNECT_DISABLED_HEADER,
  80. SERVE_HTTP_REQUEST_ID_HEADER,
  81. SERVE_HTTP_REQUEST_TIMEOUT_S_HEADER,
  82. SERVE_LOG_APPLICATION,
  83. SERVE_LOG_COMPONENT,
  84. SERVE_LOG_DEPLOYMENT,
  85. SERVE_LOG_REPLICA,
  86. SERVE_LOG_REQUEST_ID,
  87. SERVE_LOG_ROUTE,
  88. SERVE_LOGGER_NAME,
  89. SERVE_NAMESPACE,
  90. )
  91. from ray.serve._private.default_impl import (
  92. create_replica_impl,
  93. create_replica_metrics_manager,
  94. )
  95. from ray.serve._private.direct_ingress_http_util import ASGIDIReceiveProxy
  96. from ray.serve._private.event_loop_monitoring import EventLoopMonitor
  97. from ray.serve._private.grpc_util import (
  98. get_grpc_response_status,
  99. set_grpc_code_and_details,
  100. start_grpc_server,
  101. )
  102. from ray.serve._private.http_util import (
  103. ASGIAppReplicaWrapper,
  104. ASGIArgs,
  105. ASGIReceiveProxy,
  106. MessageQueue,
  107. Response,
  108. configure_http_middlewares,
  109. configure_http_options_with_defaults,
  110. convert_object_to_asgi_messages,
  111. start_asgi_http_server,
  112. )
  113. from ray.serve._private.logging_utils import (
  114. access_log_msg,
  115. configure_component_logger,
  116. configure_component_memory_profiler,
  117. get_component_logger_file_path,
  118. )
  119. from ray.serve._private.metrics_utils import InMemoryMetricsStore, MetricsPusher
  120. from ray.serve._private.proxy_request_response import ResponseStatus
  121. from ray.serve._private.replica_response_generator import ReplicaResponseGenerator
  122. from ray.serve._private.serialization import RPCSerializer
  123. from ray.serve._private.task_consumer import TaskConsumerWrapper
  124. from ray.serve._private.thirdparty.get_asgi_route_name import (
  125. extract_route_patterns,
  126. get_asgi_route_name,
  127. )
  128. from ray.serve._private.usage import ServeUsageTag
  129. from ray.serve._private.utils import (
  130. Semaphore,
  131. asyncio_grpc_exception_handler,
  132. generate_request_id,
  133. get_component_file_name, # noqa: F401
  134. is_grpc_enabled,
  135. parse_import_path,
  136. )
  137. from ray.serve._private.version import DeploymentVersion
  138. from ray.serve.config import AutoscalingConfig, HTTPOptions, gRPCOptions
  139. from ray.serve.context import _get_in_flight_requests
  140. from ray.serve.deployment import Deployment
  141. from ray.serve.exceptions import (
  142. BackPressureError,
  143. DeploymentUnavailableError,
  144. RayServeException,
  145. gRPCStatusError,
  146. )
  147. from ray.serve.generated.serve_pb2 import (
  148. ASGIRequest,
  149. ASGIResponse,
  150. HealthzResponse,
  151. ListApplicationsResponse,
  152. )
  153. from ray.serve.generated.serve_pb2_grpc import add_ASGIServiceServicer_to_server
  154. from ray.serve.grpc_util import RayServegRPCContext
  155. from ray.serve.handle import DeploymentHandle
  156. from ray.serve.schema import EncodingType, LoggingConfig, ReplicaRank
  157. from ray.util import metrics as ray_metrics
  158. logger = logging.getLogger(SERVE_LOGGER_NAME)
  159. def _wrap_grpc_call(f):
  160. """Decorator that processes grpc methods."""
  161. def serialize(result, metadata):
  162. if metadata.is_streaming and metadata.is_http_request:
  163. return result
  164. else:
  165. # Use cached serializer to avoid per-request instantiation overhead
  166. serializer = RPCSerializer.get_cached_serializer(
  167. metadata.request_serialization,
  168. metadata.response_serialization,
  169. )
  170. return serializer.dumps_response(result)
  171. @wraps(f)
  172. async def wrapper(
  173. self,
  174. request: ASGIRequest,
  175. context: grpc.aio.ServicerContext,
  176. ):
  177. request_metadata = pickle.loads(request.pickled_request_metadata)
  178. # Get cached serializer with options from metadata
  179. serializer = RPCSerializer.get_cached_serializer(
  180. request_metadata.request_serialization,
  181. request_metadata.response_serialization,
  182. )
  183. request_args = serializer.loads_request(request.request_args)
  184. request_kwargs = serializer.loads_request(request.request_kwargs)
  185. if request_metadata.is_http_request or request_metadata.is_grpc_request:
  186. request_args = (pickle.loads(request_args[0]),)
  187. try:
  188. result = await f(
  189. self, context, request_metadata, *request_args, **request_kwargs
  190. )
  191. return ASGIResponse(serialized_message=serialize(result, request_metadata))
  192. except (Exception, asyncio.CancelledError) as e:
  193. return ASGIResponse(
  194. serialized_message=serializer.dumps_response(e),
  195. is_error=True,
  196. )
  197. @wraps(f)
  198. async def gen_wrapper(
  199. self,
  200. request: ASGIRequest,
  201. context: grpc.aio.ServicerContext,
  202. ):
  203. request_metadata = pickle.loads(request.pickled_request_metadata)
  204. # Get cached serializer with options from metadata
  205. serializer = RPCSerializer.get_cached_serializer(
  206. request_metadata.request_serialization,
  207. request_metadata.response_serialization,
  208. )
  209. request_args = serializer.loads_request(request.request_args)
  210. request_kwargs = serializer.loads_request(request.request_kwargs)
  211. if request_metadata.is_http_request or request_metadata.is_grpc_request:
  212. request_args = (pickle.loads(request_args[0]),)
  213. try:
  214. async for result in f(
  215. self, context, request_metadata, *request_args, **request_kwargs
  216. ):
  217. yield ASGIResponse(
  218. serialized_message=serialize(result, request_metadata)
  219. )
  220. except (Exception, asyncio.CancelledError) as e:
  221. yield ASGIResponse(
  222. serialized_message=serializer.dumps_response(e),
  223. is_error=True,
  224. )
  225. if inspect.isasyncgenfunction(f):
  226. return gen_wrapper
  227. else:
  228. return wrapper
  229. ReplicaMetadata = Tuple[
  230. DeploymentConfig,
  231. DeploymentVersion,
  232. Optional[float],
  233. Optional[int],
  234. Optional[str],
  235. int,
  236. int,
  237. ReplicaRank, # rank
  238. Optional[List[str]], # route_patterns
  239. Optional[List[DeploymentID]], # outbound_deployments
  240. ]
  241. def _load_deployment_def_from_import_path(import_path: str) -> Callable:
  242. module_name, attr_name = parse_import_path(import_path)
  243. deployment_def = getattr(import_module(module_name), attr_name)
  244. # For ray or serve decorated class or function, strip to return
  245. # original body.
  246. if isinstance(deployment_def, RemoteFunction):
  247. deployment_def = deployment_def._function
  248. elif isinstance(deployment_def, ActorClass):
  249. deployment_def = deployment_def.__ray_metadata__.modified_class
  250. elif isinstance(deployment_def, Deployment):
  251. logger.warning(
  252. f'The import path "{import_path}" contains a '
  253. "decorated Serve deployment. The decorator's settings "
  254. "are ignored when deploying via import path."
  255. )
  256. deployment_def = deployment_def.func_or_class
  257. return deployment_def
  258. class ReplicaMetricsManager:
  259. """Manages metrics for the replica.
  260. A variety of metrics are managed:
  261. - Fine-grained metrics are set for every request.
  262. - Autoscaling statistics are periodically pushed to the controller.
  263. - Queue length metrics are periodically recorded as user-facing gauges.
  264. """
  265. PUSH_METRICS_TO_CONTROLLER_TASK_NAME = "push_metrics_to_controller"
  266. RECORD_METRICS_TASK_NAME = "record_metrics"
  267. SET_REPLICA_REQUEST_METRIC_GAUGE_TASK_NAME = "set_replica_request_metric_gauge"
  268. def __init__(
  269. self,
  270. replica_id: ReplicaID,
  271. event_loop: asyncio.BaseEventLoop,
  272. autoscaling_config: Optional[AutoscalingConfig],
  273. ingress: bool,
  274. ):
  275. self._replica_id = replica_id
  276. self._deployment_id = replica_id.deployment_id
  277. self._metrics_pusher = MetricsPusher()
  278. self._metrics_store = InMemoryMetricsStore()
  279. self._ingress = ingress
  280. self._controller_handle = ray.get_actor(
  281. SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE
  282. )
  283. self._num_ongoing_requests = 0
  284. # Store event loop for scheduling async tasks from sync context
  285. self._event_loop = event_loop or asyncio.get_event_loop()
  286. # Cache user_callable_wrapper initialization state to avoid repeated runtime checks
  287. self._custom_metrics_enabled = False
  288. # On first call to _fetch_custom_autoscaling_metrics. Failing validation disables _custom_metrics_enabled
  289. self._checked_custom_metrics = False
  290. self._record_autoscaling_stats_fn = None
  291. # If the interval is set to 0, eagerly sets all metrics.
  292. self._cached_metrics_enabled = RAY_SERVE_METRICS_EXPORT_INTERVAL_MS != 0
  293. self._cached_metrics_interval_s = RAY_SERVE_METRICS_EXPORT_INTERVAL_MS / 1000
  294. # Request counter (only set on replica startup).
  295. self._restart_counter = metrics.Counter(
  296. "serve_deployment_replica_starts",
  297. description=(
  298. "The number of times this replica has been restarted due to failure."
  299. ),
  300. )
  301. self._restart_counter.inc()
  302. # Per-request metrics.
  303. self._request_counter = metrics.Counter(
  304. "serve_deployment_request_counter",
  305. description=(
  306. "The number of queries that have been processed in this replica."
  307. ),
  308. tag_keys=("route",),
  309. )
  310. if self._cached_metrics_enabled:
  311. self._cached_request_counter = defaultdict(int)
  312. self._error_counter = metrics.Counter(
  313. "serve_deployment_error_counter",
  314. description=(
  315. "The number of exceptions that have occurred in this replica."
  316. ),
  317. tag_keys=("route",),
  318. )
  319. if self._cached_metrics_enabled:
  320. self._cached_error_counter = defaultdict(int)
  321. # log REQUEST_LATENCY_BUCKET_MS
  322. logger.debug(f"REQUEST_LATENCY_BUCKETS_MS: {REQUEST_LATENCY_BUCKETS_MS}")
  323. self._processing_latency_tracker = metrics.Histogram(
  324. "serve_deployment_processing_latency_ms",
  325. description="The latency for queries to be processed.",
  326. boundaries=REQUEST_LATENCY_BUCKETS_MS,
  327. tag_keys=("route",),
  328. )
  329. if self._cached_metrics_enabled:
  330. self._cached_latencies = defaultdict(deque)
  331. self._event_loop.create_task(self._report_cached_metrics_forever())
  332. self._num_ongoing_requests_gauge = metrics.Gauge(
  333. "serve_replica_processing_queries",
  334. description="The current number of queries being processed.",
  335. )
  336. self.record_autoscaling_stats_failed_counter = metrics.Counter(
  337. "serve_record_autoscaling_stats_failed",
  338. tag_keys=("exception_name",),
  339. description="The number of errored record_autoscaling_stats invocations.",
  340. )
  341. self.user_autoscaling_stats_latency_tracker = metrics.Histogram(
  342. "serve_user_autoscaling_stats_latency_ms",
  343. description=(
  344. "Time taken to execute the user-defined autoscaling stats function "
  345. "in milliseconds."
  346. ),
  347. boundaries=REQUEST_LATENCY_BUCKETS_MS,
  348. )
  349. self.set_autoscaling_config(autoscaling_config)
  350. if self._is_direct_ingress:
  351. # TODO(alexyang): De-duplicate these metrics from those collected by
  352. # the proxy.
  353. self.ingress_http_request_counter = self._init_ingress_request_counter(
  354. "HTTP"
  355. )
  356. self.ingress_http_request_error_counter = (
  357. self._init_ingress_request_error_counter("HTTP")
  358. )
  359. self.deployment_http_request_error_counter = (
  360. self._init_deployment_request_error_counter("HTTP")
  361. )
  362. logger.debug(f"REQUEST_LATENCY_BUCKETS_MS: {REQUEST_LATENCY_BUCKETS_MS}")
  363. self.ingress_http_processing_latency_tracker = (
  364. self._init_ingress_processing_latency_tracker("HTTP")
  365. )
  366. node_id = ray.get_runtime_context().get_node_id()
  367. node_ip_address = ray.util.get_node_ip_address()
  368. self.ingress_num_ongoing_http_requests_gauge = (
  369. self._init_ingress_num_ongoing_requests_gauge(
  370. "HTTP", node_id, node_ip_address
  371. )
  372. )
  373. self._ingress_ongoing_http_requests = 0
  374. if self._cached_metrics_enabled:
  375. self._cached_ingress_request_counter = defaultdict(
  376. lambda: defaultdict(int)
  377. )
  378. self._cached_ingress_request_error_counter = defaultdict(
  379. lambda: defaultdict(int)
  380. )
  381. self._cached_deployment_request_error_counter = defaultdict(
  382. lambda: defaultdict(int)
  383. )
  384. self._cached_ingress_processing_latencies = defaultdict(
  385. lambda: defaultdict(deque)
  386. )
  387. @property
  388. def _is_direct_ingress(self) -> bool:
  389. return self._ingress and RAY_SERVE_ENABLE_DIRECT_INGRESS
  390. def _init_ingress_request_counter(self, protocol: str):
  391. return ray_metrics.Counter(
  392. f"serve_num_{protocol.lower()}_requests",
  393. description=f"The number of {protocol} requests processed.",
  394. tag_keys=("route", "method", "application", "status_code"),
  395. )
  396. def _init_ingress_request_error_counter(self, protocol: str):
  397. return ray_metrics.Counter(
  398. f"serve_num_{protocol.lower()}_error_requests",
  399. description=(f"The number of errored {protocol} responses."),
  400. tag_keys=(
  401. "route",
  402. "error_code",
  403. "method",
  404. "application",
  405. ),
  406. )
  407. def _init_deployment_request_error_counter(self, protocol: str):
  408. return ray_metrics.Counter(
  409. f"serve_num_deployment_{protocol.lower()}_error_requests",
  410. description=(
  411. f"The number of errored {protocol} responses returned by each deployment."
  412. ),
  413. tag_keys=(
  414. "deployment",
  415. "error_code",
  416. "method",
  417. "route",
  418. "application",
  419. ),
  420. )
  421. def _init_ingress_processing_latency_tracker(self, protocol: str):
  422. return ray_metrics.Histogram(
  423. f"serve_{protocol.lower()}_request_latency_ms",
  424. description=(
  425. f"The end-to-end latency of {protocol} requests "
  426. f"(measured from the Serve ingress)."
  427. ),
  428. boundaries=REQUEST_LATENCY_BUCKETS_MS,
  429. tag_keys=(
  430. "method",
  431. "route",
  432. "application",
  433. "status_code",
  434. ),
  435. )
  436. def _init_ingress_num_ongoing_requests_gauge(
  437. self, protocol: str, node_id: str, node_ip_address: str
  438. ):
  439. return ray_metrics.Gauge(
  440. name=f"serve_num_ongoing_{protocol.lower()}_requests",
  441. description=f"The number of ongoing requests in this {protocol} ingress.",
  442. tag_keys=("node_id", "node_ip_address"),
  443. ).set_default_tags(
  444. {
  445. "node_id": node_id,
  446. "node_ip_address": node_ip_address,
  447. }
  448. )
  449. def _report_cached_metrics(self):
  450. for route, count in self._cached_request_counter.items():
  451. self._request_counter.inc(count, tags={"route": route})
  452. self._cached_request_counter.clear()
  453. for route, count in self._cached_error_counter.items():
  454. self._error_counter.inc(count, tags={"route": route})
  455. self._cached_error_counter.clear()
  456. for route, latencies in self._cached_latencies.items():
  457. for latency_ms in latencies:
  458. self._processing_latency_tracker.observe(
  459. latency_ms, tags={"route": route}
  460. )
  461. self._cached_latencies.clear()
  462. self._num_ongoing_requests_gauge.set(self._num_ongoing_requests)
  463. if not self._is_direct_ingress:
  464. return
  465. for protocol in [RequestProtocol.HTTP]:
  466. if protocol == RequestProtocol.HTTP:
  467. ingress_request_counter = self.ingress_http_request_counter
  468. ingress_request_error_counter = self.ingress_http_request_error_counter
  469. deployment_request_error_counter = (
  470. self.deployment_http_request_error_counter
  471. )
  472. ingress_processing_latencies = (
  473. self.ingress_http_processing_latency_tracker
  474. )
  475. self.ingress_num_ongoing_http_requests_gauge.set(
  476. self._ingress_ongoing_http_requests
  477. )
  478. else:
  479. # TODO(alexyang): Add metrics for gRPC.
  480. continue
  481. for request_tags, count in self._cached_ingress_request_counter[
  482. protocol
  483. ].items():
  484. ingress_request_counter.inc(count, tags=dict(request_tags))
  485. for request_tags, count in self._cached_ingress_request_error_counter[
  486. protocol
  487. ].items():
  488. ingress_request_error_counter.inc(count, tags=dict(request_tags))
  489. for request_tags, count in self._cached_deployment_request_error_counter[
  490. protocol
  491. ].items():
  492. deployment_request_error_counter.inc(count, tags=dict(request_tags))
  493. for latency_tags, latencies in self._cached_ingress_processing_latencies[
  494. protocol
  495. ].items():
  496. for latency_ms in latencies:
  497. ingress_processing_latencies.observe(
  498. latency_ms, tags=dict(latency_tags)
  499. )
  500. self._cached_ingress_request_counter.clear()
  501. self._cached_ingress_request_error_counter.clear()
  502. self._cached_deployment_request_error_counter.clear()
  503. self._cached_ingress_processing_latencies.clear()
  504. async def _report_cached_metrics_forever(self):
  505. assert self._cached_metrics_interval_s > 0
  506. consecutive_errors = 0
  507. while True:
  508. try:
  509. await asyncio.sleep(self._cached_metrics_interval_s)
  510. self._report_cached_metrics()
  511. consecutive_errors = 0
  512. except Exception:
  513. logger.exception("Unexpected error reporting metrics.")
  514. # Exponential backoff starting at 1s and capping at 10s.
  515. backoff_time_s = min(10, 2**consecutive_errors)
  516. consecutive_errors += 1
  517. await asyncio.sleep(backoff_time_s)
  async def shutdown(self):
      """Gracefully shut down the metrics pusher and its periodic tasks."""
      pusher = self._metrics_pusher
      await pusher.graceful_shutdown()
  def start_metrics_pusher(self):
      """Start the metrics pusher and register (or update) the autoscaling tasks.

      Two periodic tasks are registered:
        * pushing autoscaling metrics to the controller every
          ``metrics_interval_s`` seconds;
        * recording a local autoscaling metrics point every
          ``min(RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
          metrics_interval_s)`` seconds.
      """
      pusher = self._metrics_pusher
      pusher.start()

      push_interval_s = self._autoscaling_config.metrics_interval_s
      pusher.register_or_update_task(
          self.PUSH_METRICS_TO_CONTROLLER_TASK_NAME,
          self._push_autoscaling_metrics,
          push_interval_s,
      )

      # Record locally at least as often as we push.
      record_interval_s = min(
          RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
          push_interval_s,
      )
      pusher.register_or_update_task(
          self.RECORD_METRICS_TASK_NAME,
          self._add_autoscaling_metrics_point_async,
          record_interval_s,
      )
  def should_collect_ongoing_requests(self) -> bool:
      """Return whether this replica should record its own ongoing requests.

      Replicas report ongoing-request metrics in two situations:

      1. Handle-side collection is disabled
         (``RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE`` is false). In that
         case handles track requests still queued on the caller side, each
         replica tracks the requests it is running, and the controller receives
         both:

             App --> Handle --+--> Replica 1 --+
                    (queued)  |    (running)   |
                       |      +--> Replica 2 --+
                       |           (running)   |
                       v                       v
                 +---------------------------------+
                 | Controller                      |
                 |  - queued metrics (handle)      |
                 |  - running metrics (replica 1)  |
                 |  - running metrics (replica 2)  |
                 +---------------------------------+

      2. This is a direct-ingress deployment with an autoscaling config.
         Direct-ingress traffic goes straight to the replicas and bypasses
         deployment handles, so the replica must report regardless of whether
         handles also collect autoscaling metrics.
      """
      bypasses_handles = bool(self._is_direct_ingress and self._autoscaling_config)
      return bypasses_handles or not RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE
  def set_autoscaling_config(self, autoscaling_config: Optional[AutoscalingConfig]):
      """Replace the autoscaling config at runtime.

      The metrics pusher is (re)started only when a config is present and this
      replica is responsible for reporting its ongoing requests.
      """
      self._autoscaling_config = autoscaling_config
      if not self._autoscaling_config:
          return
      if self.should_collect_ongoing_requests():
          self.start_metrics_pusher()
  def enable_custom_autoscaling_metrics(
      self,
      custom_metrics_enabled: bool,
      record_autoscaling_stats_fn: Callable[[], Optional[concurrent.futures.Future]],
  ):
      """Turn on collection of user-defined autoscaling metrics.

      Called after the user callable wrapper is initialized. If
      ``custom_metrics_enabled`` is falsy this does nothing. Otherwise the
      stats function is stored and the metrics pusher is started so the
      custom metrics get recorded and pushed.
      """
      if not custom_metrics_enabled:
          return
      self._custom_metrics_enabled = custom_metrics_enabled
      self._record_autoscaling_stats_fn = record_autoscaling_stats_fn
      self.start_metrics_pusher()
  def inc_num_ongoing_requests(self, request_metadata: RequestMetadata) -> int:
      """Record that a request has started executing on this replica.

      Increments the replica-wide ongoing-request count. For a direct-ingress
      HTTP request on a direct-ingress replica, also increments the ingress
      HTTP ongoing count. Gauges are updated immediately unless cached metrics
      are enabled; in that case the periodic cached-metrics report sets them.

      Args:
          request_metadata: Metadata of the request that started.

      Returns:
          The updated number of ongoing requests on this replica.
      """
      self._num_ongoing_requests += 1
      # Only HTTP traffic may contribute to the *HTTP* ingress counter. Other
      # direct-ingress protocols (e.g. gRPC) must not inflate it.
      is_ingress_http = (
          self._is_direct_ingress
          and request_metadata.is_direct_ingress
          and request_metadata.is_http_request
      )
      if is_ingress_http:
          self._ingress_ongoing_http_requests += 1
      if not self._cached_metrics_enabled:
          self._num_ongoing_requests_gauge.set(self._num_ongoing_requests)
          if is_ingress_http:
              self.ingress_num_ongoing_http_requests_gauge.set(
                  self._ingress_ongoing_http_requests
              )
      return self._num_ongoing_requests
  def dec_num_ongoing_requests(self, request_metadata: RequestMetadata) -> int:
      """Record that a request has finished executing on this replica.

      Mirror of ``inc_num_ongoing_requests``. Decrements the replica-wide
      ongoing-request count. For a direct-ingress HTTP request on a
      direct-ingress replica, also decrements the ingress HTTP ongoing count.
      Gauges are updated immediately unless cached metrics are enabled.

      Args:
          request_metadata: Metadata of the request that finished.

      Returns:
          The updated number of ongoing requests on this replica.
      """
      self._num_ongoing_requests -= 1
      # Must use the same condition as the increment path so the HTTP ingress
      # counter stays balanced and excludes non-HTTP (e.g. gRPC) requests.
      is_ingress_http = (
          self._is_direct_ingress
          and request_metadata.is_direct_ingress
          and request_metadata.is_http_request
      )
      if is_ingress_http:
          self._ingress_ongoing_http_requests -= 1
      if not self._cached_metrics_enabled:
          self._num_ongoing_requests_gauge.set(self._num_ongoing_requests)
          if is_ingress_http:
              self.ingress_num_ongoing_http_requests_gauge.set(
                  self._ingress_ongoing_http_requests
              )
      return self._num_ongoing_requests
  def get_num_ongoing_requests(self) -> int:
      """Return how many requests are currently in flight on this replica."""
      ongoing = self._num_ongoing_requests
      return ongoing
  def record_request_metrics(self, *, route: str, latency_ms: float, was_error: bool):
      """Record latency plus a success or error count for one request.

      Every request contributes its latency. Each request increments exactly
      one counter: the error counter if ``was_error``, otherwise the request
      counter. With cached metrics enabled, the values are accumulated per
      route and flushed later by the periodic report. Otherwise they are
      emitted immediately, tagged with the route.
      """
      if self._cached_metrics_enabled:
          self._cached_latencies[route].append(latency_ms)
          cached_counts = (
              self._cached_error_counter if was_error else self._cached_request_counter
          )
          cached_counts[route] += 1
          return

      self._processing_latency_tracker.observe(latency_ms, tags={"route": route})
      counter = self._error_counter if was_error else self._request_counter
      counter.inc(tags={"route": route})
  def record_ingress_request_metrics(
      self,
      *,
      protocol: RequestProtocol,
      method: str,
      route: str,
      app_name: str,
      deployment_name: str,
      latency_ms: float,
      was_error: bool,
      status_code: str,
  ):
      """Record ingress-level metrics for one direct-ingress request.

      Does nothing unless this replica is a direct-ingress replica and
      ``protocol`` is HTTP (gRPC ingress metrics are not implemented yet).
      Every request increments the ingress request counter and records its
      latency. Both are tagged with route, method, application and status
      code. Errored requests additionally increment the ingress error counter
      and the per-deployment error counter. With cached metrics enabled, the
      values are accumulated locally (keyed by a frozenset of their tags) and
      flushed by the periodic cached-metrics report.
      """
      if not self._is_direct_ingress:
          return
      # TODO(alexyang): Add metrics for gRPC.
      if protocol != RequestProtocol.HTTP:
          return

      latency_tracker = self.ingress_http_processing_latency_tracker
      request_error_counter = self.ingress_http_request_error_counter
      deployment_error_counter = self.deployment_http_request_error_counter
      request_counter = self.ingress_http_request_counter

      common_tags = {"route": route, "method": method, "application": app_name}
      # The latency histogram shares the request counter's tags.
      request_tags = {**common_tags, "status_code": status_code}
      request_error_tags = {**common_tags, "error_code": status_code}
      deployment_error_tags = {**request_error_tags, "deployment": deployment_name}

      if not self._cached_metrics_enabled:
          request_counter.inc(tags=request_tags)
          latency_tracker.observe(latency_ms, tags=request_tags)
          if was_error:
              request_error_counter.inc(tags=request_error_tags)
              deployment_error_counter.inc(tags=deployment_error_tags)
          return

      request_key = frozenset(request_tags.items())
      self._cached_ingress_request_counter[protocol][request_key] += 1
      self._cached_ingress_processing_latencies[protocol][request_key].append(
          latency_ms
      )
      if was_error:
          self._cached_ingress_request_error_counter[protocol][
              frozenset(request_error_tags.items())
          ] += 1
          self._cached_deployment_request_error_counter[protocol][
              frozenset(deployment_error_tags.items())
          ] += 1
  def _push_autoscaling_metrics(self) -> None:
      """Send this replica's recent autoscaling metrics to the controller.

      Prunes samples older than ``look_back_period_s`` from the local metrics
      store, then sends a ``ReplicaMetricReport`` to the controller. The report
      contains a shallow copy of the store's data and, when this replica reports
      its own ongoing requests, the look-back average of
      ``RUNNING_REQUESTS_KEY`` (0.0 when there are no samples). The remote call
      is not awaited.
      """
      look_back_period = self._autoscaling_config.look_back_period_s
      # One timestamp for both the pruning cutoff and the report, so they
      # describe the same instant.
      now = time.time()
      self._metrics_store.prune_keys_and_compact_data(now - look_back_period)

      new_aggregated_metrics = {}
      new_metrics = {**self._metrics_store.data}
      if self.should_collect_ongoing_requests():
          # Keep the legacy window_avg ongoing requests in the merged metrics dict
          window_avg = (
              self._metrics_store.aggregate_avg([RUNNING_REQUESTS_KEY])[0] or 0.0
          )
          new_aggregated_metrics[RUNNING_REQUESTS_KEY] = window_avg

      replica_metric_report = ReplicaMetricReport(
          replica_id=self._replica_id,
          timestamp=now,
          aggregated_metrics=new_aggregated_metrics,
          metrics=new_metrics,
      )
      self._controller_handle.record_autoscaling_metrics_from_replica.remote(
          replica_metric_report
      )
  async def _fetch_custom_autoscaling_metrics(
      self,
  ) -> Optional[Dict[str, Union[int, float]]]:
      """Call the user's autoscaling-stats method and return its metrics.

      The call is bounded by ``RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S``,
      and the latency of each successful call is recorded. Results are
      validated until the first well-formed one: the result must be a dict
      whose values are all int or float. On a validation failure an error is
      logged and ``_custom_metrics_enabled`` is set to False. Timeouts and other
      exceptions are logged and counted in
      ``record_autoscaling_stats_failed_counter``.

      Returns:
          The user-provided metrics dict, or None if the call timed out,
          raised, or returned an invalid value.
      """
      try:
          # Use a monotonic clock: wall-clock (time.time) adjustments could
          # otherwise skew, or even negate, the recorded latency.
          start_s = time.perf_counter()
          res = await asyncio.wait_for(
              self._record_autoscaling_stats_fn(),
              timeout=RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S,
          )
          latency_ms = (time.perf_counter() - start_s) * 1000
          self.user_autoscaling_stats_latency_tracker.observe(latency_ms)

          # Validate only until the first well-formed result; later results
          # are not re-checked.
          if not self._checked_custom_metrics:
              # Enforce return type to be Dict[str, Union[int, float]]
              if not isinstance(res, dict):
                  logger.error(
                      f"User autoscaling stats method returned {type(res).__name__}, "
                      f"expected Dict[str, Union[int, float]]. Disabling autoscaling stats."
                  )
                  self._custom_metrics_enabled = False
                  return None
              for key, value in res.items():
                  if not isinstance(value, (int, float)):
                      logger.error(
                          f"User autoscaling stats method returned invalid value type "
                          f"{type(value).__name__} for key '{key}', expected int or float. "
                          f"Disabling autoscaling stats."
                      )
                      self._custom_metrics_enabled = False
                      return None
              self._checked_custom_metrics = True
          return res
      except asyncio.TimeoutError as e:
          logger.error(
              f"Replica autoscaling stats timed out after {RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S}s."
          )
          self.record_autoscaling_stats_failed_counter.inc(
              tags={"exception_name": e.__class__.__name__}
          )
      except Exception as e:
          logger.error(f"Replica autoscaling stats failed. {e}")
          self.record_autoscaling_stats_failed_counter.inc(
              tags={"exception_name": e.__class__.__name__}
          )
      return None
  761. async def _add_autoscaling_metrics_point_async(self) -> None:
  762. metrics_dict = {}
  763. if self.should_collect_ongoing_requests():
  764. metrics_dict = {RUNNING_REQUESTS_KEY: self._num_ongoing_requests}
  765. # Use cached availability flag to avoid repeated runtime checks
  766. if self._custom_metrics_enabled:
  767. custom_metrics = await self._fetch_custom_autoscaling_metrics()
  768. if custom_metrics:
  769. metrics_dict.update(custom_metrics)
  770. self._metrics_store.add_metrics_point(
  771. metrics_dict,
  772. time.time(),
  773. )
  # Type of the callback yielded by `_handle_errors_and_metrics`. It is passed
  # to the HTTP entrypoint, which reports the response status code as a string
  # (e.g. "200"); that status is then used in the access log and ingress metrics.
  StatusCodeCallback = Callable[[str], None]
  775. class ReplicaBase(ABC):
  def __init__(
      self,
      replica_id: ReplicaID,
      deployment_def: Callable,
      init_args: Tuple,
      init_kwargs: Dict,
      deployment_config: DeploymentConfig,
      version: DeploymentVersion,
      ingress: bool,
      route_prefix: str,
  ):
      """Set up replica state, the user-callable wrapper, metrics and gRPC server.

      The user's callable is NOT constructed here. That happens later in
      ``initialize``, guarded by ``_user_callable_initialized``.

      Args:
          replica_id: ID of this replica. Its deployment ID and unique ID are
              used to name the logging component.
          deployment_def: The user's deployment class or function.
          init_args: Positional args for the user callable's constructor.
          init_kwargs: Keyword args for the user callable's constructor.
          deployment_config: Initial deployment config.
          version: Deployment version; supplies ``ray_actor_options``.
          ingress: Whether this deployment is an ingress. It is passed to the
              metrics manager and chooses the route shown in access logs.
          route_prefix: Route prefix; used as the ``route`` tag for
              direct-ingress metrics.
      """
      self._version = version
      self._replica_id = replica_id
      self._deployment_id = replica_id.deployment_id
      self._deployment_config = deployment_config
      self._ingress = ingress
      self._route_prefix = route_prefix
      # Logging component name: "<app>_<deployment>", or just "<deployment>"
      # when there is no app name.
      self._component_name = f"{self._deployment_id.name}"
      if self._deployment_id.app_name:
          self._component_name = (
              f"{self._deployment_id.app_name}_" + self._component_name
          )
      self._component_id = self._replica_id.unique_id
      # Must run after the component name/ID are set; it also builds the
      # access-log context dict.
      self._configure_logger_and_profilers(self._deployment_config.logging_config)
      self._event_loop = get_or_create_event_loop()
      actor_id = ray.get_runtime_context().get_actor_id()
      self._user_callable_wrapper = UserCallableWrapper(
          deployment_def,
          init_args,
          init_kwargs,
          deployment_id=self._deployment_id,
          run_sync_methods_in_threadpool=RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
          run_user_code_in_separate_thread=RAY_SERVE_RUN_USER_CODE_IN_SEPARATE_THREAD,
          local_testing_mode=False,
          deployment_config=deployment_config,
          actor_id=actor_id,
          ray_actor_options=self._version.ray_actor_options,
      )
      # Caps concurrency at max_ongoing_requests. The limit is passed as a
      # callable, presumably so config updates apply without rebuilding the
      # semaphore -- TODO confirm against Semaphore's implementation.
      self._semaphore = Semaphore(lambda: self.max_ongoing_requests)
      # Guards against calling the user's callable constructor multiple times.
      self._user_callable_initialized = False
      self._user_callable_initialized_lock = asyncio.Lock()
      self._initialization_latency: Optional[float] = None
      # Track deployment handles created dynamically via get_deployment_handle()
      self._dynamically_created_handles: Set[DeploymentID] = set()
      # Flipped to `True` when health checks pass and `False` when they fail. May be
      # used by replica subclass implementations.
      self._healthy = False
      # Flipped to `True` once graceful shutdown is initiated. May be used by replica
      # subclass implementations.
      self._shutting_down = False
      # Will be populated with the wrapped ASGI app if the user callable is an
      # `ASGIAppReplicaWrapper` (i.e., they are using the FastAPI integration).
      self._user_callable_asgi_app: Optional[ASGIApp] = None
      # Set metadata for logs and metrics.
      # servable_object is populated later by `initialize` (when called with a
      # rank). This must run after `_dynamically_created_handles` exists, since
      # the registration callback it installs writes to that set.
      self._set_internal_replica_context(servable_object=None, rank=None)
      self._metrics_manager = create_replica_metrics_manager(
          replica_id=replica_id,
          event_loop=self._event_loop,
          autoscaling_config=self._deployment_config.autoscaling_config,
          ingress=ingress,
      )
      # Start event loop monitoring for the replica's main event loop.
      self._main_loop_monitor = EventLoopMonitor(
          component=EventLoopMonitor.COMPONENT_REPLICA,
          loop_type=EventLoopMonitor.LOOP_TYPE_MAIN,
          actor_id=actor_id,
          extra_tags={
              "deployment": self._deployment_id.name,
              "application": self._deployment_id.app_name,
          },
      )
      self._main_loop_monitor.start(self._event_loop)
      # Ports/paths/rank reported via `get_metadata`; filled in later.
      self._internal_grpc_port: Optional[int] = None
      self._docs_path: Optional[str] = None
      self._http_port: Optional[int] = None
      self._grpc_port: Optional[int] = None
      self._rank: Optional[ReplicaRank] = None
      # gRPC server for inter-deployment communication
      self._server = grpc.aio.server(
          options=[
              (
                  "grpc.max_receive_message_length",
                  RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
              )
          ]
      )
      # Silence spammy false positive errors from gRPC Python
      self._event_loop.set_exception_handler(asyncio_grpc_exception_handler)
  @property
  def max_ongoing_requests(self) -> int:
      """Maximum number of requests this replica may handle concurrently.

      Read from the current deployment config on every access, so it reflects
      whatever config ``reconfigure`` last stored.
      """
      current_config = self._deployment_config
      return current_config.max_ongoing_requests
  def get_num_ongoing_requests(self) -> int:
      """Return the number of requests currently ongoing on this replica.

      The count is owned by the replica metrics manager; this just delegates.
      """
      metrics_manager = self._metrics_manager
      return metrics_manager.get_num_ongoing_requests()
  def get_metadata(self) -> ReplicaMetadata:
      """Return this replica's metadata as a tuple.

      Fields, in order: deployment config, version, initialization latency,
      internal gRPC port, docs path, HTTP port, gRPC port, current rank (read
      from the internal replica context), route patterns of the user's ASGI
      app (None unless that app has ``routes``), and outbound deployment IDs.
      """
      context_rank = ray.serve.context._get_internal_replica_context().rank

      # `_user_callable_asgi_app` is the user's actual ASGI app (e.g. FastAPI or
      # Starlette), set when `initialize_callable()` returned one. Only apps
      # that expose `routes` can report route patterns.
      asgi_app = self._user_callable_asgi_app
      route_patterns = None
      if asgi_app is not None and hasattr(asgi_app, "routes"):
          route_patterns = extract_route_patterns(asgi_app)

      return (
          self._version.deployment_config,
          self._version,
          self._initialization_latency,
          self._internal_grpc_port,
          self._docs_path,
          self._http_port,
          self._grpc_port,
          context_rank,
          route_patterns,
          self.list_outbound_deployments(),
      )
  def get_dynamically_created_handles(self) -> Set[DeploymentID]:
      """Return the deployment IDs registered through the handle callback.

      This is the live internal set, not a copy.
      """
      registered = self._dynamically_created_handles
      return registered
  def list_outbound_deployments(self) -> List[DeploymentID]:
      """List all outbound deployment IDs this replica calls into.

      This includes:
      - Handles created via get_deployment_handle()
      - Handles passed as init args/kwargs to the deployment constructor

      This is used to determine which deployments are reachable from this replica.
      The list of DeploymentIDs can change over time as new handles can be created
      at runtime. Also, it's not guaranteed that the list is identical across
      replicas, because it depends on user code.

      Returns:
          A list of unique DeploymentIDs that this replica calls into (order
          unspecified).
      """
      # Snapshot the dynamically created handles with a single set() copy
      # instead of a Python-level loop. Registrations come from user code,
      # which may run on a separate thread (run_user_code_in_separate_thread).
      # Looping over the live set could then raise "Set changed size during
      # iteration".
      seen_deployment_ids: Set[DeploymentID] = set(
          self.get_dynamically_created_handles()
      )

      # Handles might also be passed as init args/kwargs; find them with
      # _PyObjScanner.
      init_args = self._user_callable_wrapper._init_args
      init_kwargs = self._user_callable_wrapper._init_kwargs
      scanner = _PyObjScanner(source_type=DeploymentHandle)
      try:
          handles = scanner.find_nodes((init_args, init_kwargs))
          seen_deployment_ids.update(handle.deployment_id for handle in handles)
      finally:
          # Always release the scanner's state, even if scanning fails.
          scanner.clear()

      return list(seen_deployment_ids)
  def _set_internal_replica_context(
      self, *, servable_object: Callable = None, rank: ReplicaRank = None
  ):
      """Publish this replica's identity and config to Serve's replica context.

      Also installs a handle-registration callback that adds each registered
      deployment ID to ``_dynamically_created_handles``.

      Args:
          servable_object: The constructed user callable, or None if it does
              not exist yet.
          rank: This replica's rank, if one has been assigned.
      """
      # world_size is derived from the deployment config rather than stored.
      replica_count = self._deployment_config.num_replicas

      def _record_created_handle(deployment_id: DeploymentID) -> None:
          self._dynamically_created_handles.add(deployment_id)

      ray.serve.context._set_internal_replica_context(
          replica_id=self._replica_id,
          servable_object=servable_object,
          _deployment_config=self._deployment_config,
          rank=rank,
          world_size=replica_count,
          handle_registration_callback=_record_created_handle,
      )
  def _configure_logger_and_profilers(
      self, logging_config: Union[None, Dict, LoggingConfig]
  ):
      """Configure the replica logger and memory profiler, and build the
      context dict attached to every access-log record.

      Args:
          logging_config: A ``LoggingConfig``, a dict of ``LoggingConfig``
              fields, or None to use the defaults.
      """
      if logging_config is None:
          logging_config = {}
      if isinstance(logging_config, dict):
          logging_config = LoggingConfig(**logging_config)

      component_kwargs = dict(
          component_type=ServeComponentType.REPLICA,
          component_name=self._component_name,
          component_id=self._component_id,
      )
      configure_component_logger(
          **component_kwargs,
          logging_config=logging_config,
          buffer_size=RAY_SERVE_REQUEST_PATH_LOG_BUFFER_SIZE,
      )
      configure_component_memory_profiler(**component_kwargs)

      access_log_flags = {
          "skip_context_filter": True,
          "serve_access_log": True,
      }
      if logging_config.encoding != EncodingType.JSON:
          self._access_log_context = access_log_flags
          return

      # Performance optimization: logging_utils could add the Ray core and
      # Serve context to each access log automatically, but evaluating that
      # context is expensive. It is computed once here and reused for every
      # access log entry.
      ray_core_logging_context = CoreContextFilter.get_ray_core_logging_context()
      # Task-level keys are dropped. Computed here they would reference the
      # current task (__init__), not the task that later emits the log, so
      # they are sacrificed for the optimization.
      for key in CoreContextFilter.TASK_LEVEL_LOG_KEYS:
          ray_core_logging_context.pop(key, None)
      self._access_log_context = {
          **ray_core_logging_context,
          SERVE_LOG_DEPLOYMENT: self._component_name,
          SERVE_LOG_REPLICA: self._component_id,
          SERVE_LOG_COMPONENT: ServeComponentType.REPLICA,
          SERVE_LOG_APPLICATION: self._deployment_id.app_name,
          **access_log_flags,
      }
  985. def _can_accept_request(self, request_metadata: RequestMetadata) -> bool:
  986. # This replica gates concurrent request handling with an asyncio.Semaphore.
  987. # Each in-flight request acquires the semaphore. When the number of ongoing
  988. # requests reaches max_ongoing_requests, the semaphore becomes locked.
  989. # A new request can be accepted if the semaphore is currently unlocked.
  990. return not self._semaphore.locked()
  @contextmanager
  def _handle_errors_and_metrics(
      self, request_metadata: RequestMetadata
  ) -> Generator[StatusCodeCallback, None, None]:
      """Time a request and record its outcome when the wrapped body exits.

      Yields a callback the body can use to report the response status code.
      On exit, whether by success, cancellation or failure, the latency is
      measured and passed with the outcome to ``_record_errors_and_metrics``.
      Any caught exception is then re-raised.

      Args:
          request_metadata: Metadata of the request being handled.

      Yields:
          A ``StatusCodeCallback`` that stores the most recently reported
          status code.

      Raises:
          asyncio.CancelledError: Re-raised if the request was cancelled.
          Exception: Any exception raised by the wrapped body, re-raised.
      """
      start_time = time.time()
      user_exception = None
      # Set through the callback below; stays None if never reported.
      status_code = None

      def _status_code_callback(s: str):
          nonlocal status_code
          status_code = s

      try:
          yield _status_code_callback
      except asyncio.CancelledError as e:
          # On Python 3.8+ CancelledError is not an Exception subclass, so it
          # needs its own clause. It is recorded as a cancellation and is not
          # logged as a failure.
          user_exception = e
          self._on_request_cancelled(request_metadata, e)
      except Exception as e:
          user_exception = e
          logger.exception("Request failed.")
          self._on_request_failed(request_metadata, e)

      # Recorded for every outcome, after the cancel/failure hooks have run.
      latency_ms = (time.time() - start_time) * 1000
      self._record_errors_and_metrics(
          user_exception, status_code, latency_ms, request_metadata
      )

      if user_exception is not None:
          # NOTE(review): `from None` sets __suppress_context__ and resets
          # __cause__ to None on the exception. That also drops a cause attached
          # upstream via `raise ... from e` -- confirm this is intended.
          raise user_exception from None
  def _record_errors_and_metrics(
      self,
      user_exception: Optional[BaseException],
      status_code: Optional[str],
      latency_ms: float,
      request_metadata: RequestMetadata,
  ):
      """Write the access-log entry and record metrics for a finished request.

      Args:
          user_exception: Exception raised while handling the request, or None
              on success. ``asyncio.CancelledError`` is logged as CANCELLED and
              any other exception as ERROR.
          status_code: Status code reported through the status-code callback,
              if any. It is preferred over OK/CANCELLED/ERROR in the access log
              and is required for direct-ingress metrics.
          latency_ms: Handling latency in milliseconds.
          request_metadata: Metadata of the finished request.
      """
      http_method = request_metadata._http_method
      http_route = request_metadata.route
      call_method = request_metadata.call_method

      if user_exception is None:
          fallback_status = "OK"
      elif isinstance(user_exception, asyncio.CancelledError):
          fallback_status = "CANCELLED"
      else:
          fallback_status = "ERROR"

      # The context dict is deliberately reused and mutated in place, because
      # allocating a new dict per request is comparatively expensive. That is
      # not thread safe; it relies on always being called from the same thread.
      log_context = self._access_log_context
      log_context[SERVE_LOG_ROUTE] = http_route
      log_context[SERVE_LOG_REQUEST_ID] = request_metadata.request_id
      # Ingress deployments log the HTTP route (when present); all other
      # requests log the called method name.
      logged_route = http_route if self._ingress and http_route else call_method
      logger.info(
          access_log_msg(
              method=http_method or "CALL",
              route=logged_route,
              status=status_code or fallback_status,
              latency_ms=latency_ms,
          ),
          extra=log_context,
      )

      self._metrics_manager.record_request_metrics(
          route=http_route,
          latency_ms=latency_ms,
          was_error=user_exception is not None,
      )

      # Ingress metrics (recorded as HTTP) only apply to direct-ingress
      # requests that reported a status code.
      if not request_metadata.is_direct_ingress or status_code is None:
          return
      self._metrics_manager.record_ingress_request_metrics(
          protocol=RequestProtocol.HTTP,
          method=request_metadata._http_method,
          route=self._route_prefix,
          app_name=self._deployment_id.app_name,
          deployment_name=self._deployment_id.name,
          latency_ms=latency_ms,
          # 4xx and 5xx status codes count as errors.
          was_error=status_code.startswith(("4", "5")),
          status_code=status_code,
      )
  1064. def _unpack_proxy_args(
  1065. self,
  1066. request_metadata: RequestMetadata,
  1067. request_args: Tuple[Any],
  1068. request_kwargs: Dict[str, Any],
  1069. ) -> Tuple[Tuple[Any], Dict[str, Any], Any]:
  1070. # Extract _ray_trace_ctx from kwargs at the entry point.
  1071. #
  1072. # Context: When tracing is enabled, Ray's tracing decorators inject
  1073. # _ray_trace_ctx into ServeReplica actor method calls. The ServeReplica
  1074. # actor methods properly handle this, but we
  1075. # need to extract it before calling user-defined deployment methods.
  1076. #
  1077. # Design: We return it so it can be passed to _wrap_request() which
  1078. # stores it in _RequestContext. Users can then access it via serve.context
  1079. # if needed (advanced use case), while keeping it out of their method signatures.
  1080. ray_trace_ctx = request_kwargs.pop("_ray_trace_ctx", None)
  1081. if request_metadata.is_http_request:
  1082. assert len(request_args) == 1 and isinstance(
  1083. request_args[0], StreamingHTTPRequest
  1084. )
  1085. request: StreamingHTTPRequest = request_args[0]
  1086. scope = request.asgi_scope
  1087. receive = ASGIReceiveProxy(
  1088. scope, request_metadata, request.receive_asgi_messages
  1089. )
  1090. request_metadata._http_method = scope.get("method", "WS")
  1091. request_args = (scope, receive)
  1092. elif request_metadata.is_grpc_request:
  1093. assert len(request_args) == 1 and isinstance(request_args[0], gRPCRequest)
  1094. request: gRPCRequest = request_args[0]
  1095. method_info = self._user_callable_wrapper.get_user_method_info(
  1096. request_metadata.call_method
  1097. )
  1098. request_args = (request.user_request_proto,)
  1099. request_kwargs = (
  1100. {GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context}
  1101. if method_info.takes_grpc_context_kwarg
  1102. else {}
  1103. )
  1104. return request_args, request_kwargs, ray_trace_ctx
  async def handle_request(
      self, request_metadata: RequestMetadata, *request_args, **request_kwargs
  ) -> Tuple[bytes, Any]:
      """Entrypoint for unary (non-streaming) handle calls.

      Unpacks proxy arguments, runs the user method inside the request
      wrappers and returns its result. Exceptions are re-raised, wrapped in
      ``gRPCStatusError`` for gRPC requests whose context has a status code.
      """
      args, kwargs, trace_ctx = self._unpack_proxy_args(
          request_metadata, request_args, request_kwargs
      )
      with self._wrap_request(request_metadata, trace_ctx):
          async with self._start_request(request_metadata):
              try:
                  result = await self._user_callable_wrapper.call_user_method(
                      request_metadata, args, kwargs
                  )
              except Exception as e:
                  # For gRPC requests, wrap exception with user-set status code
                  raise self._maybe_wrap_grpc_exception(e, request_metadata) from e
              return result
  async def handle_request_streaming(
      self, request_metadata: RequestMetadata, *request_args, **request_kwargs
  ) -> AsyncGenerator[Any, None]:
      """Generator that is the entrypoint for all `stream=True` handle calls.

      HTTP requests stream pickled ASGI message batches from the HTTP
      entrypoint. Other requests stream the items produced by the user's
      generator. Exceptions are re-raised, wrapped in ``gRPCStatusError`` for
      gRPC requests whose context has a status code.
      """
      args, kwargs, trace_ctx = self._unpack_proxy_args(
          request_metadata, request_args, request_kwargs
      )
      with self._wrap_request(request_metadata, trace_ctx) as status_code_callback:
          async with self._start_request(request_metadata):
              try:
                  if request_metadata.is_http_request:
                      scope, receive = args
                      message_stream = self._user_callable_wrapper.call_http_entrypoint(
                          request_metadata,
                          status_code_callback,
                          scope,
                          receive,
                      )
                      async for messages in message_stream:
                          yield pickle.dumps(messages)
                  else:
                      item_stream = self._user_callable_wrapper.call_user_generator(
                          request_metadata,
                          args,
                          kwargs,
                      )
                      async for item in item_stream:
                          yield item
              except Exception as e:
                  # For gRPC requests, wrap exception with user-set status code
                  raise self._maybe_wrap_grpc_exception(e, request_metadata) from e
  1151. def _maybe_wrap_grpc_exception(
  1152. self, e: BaseException, request_metadata: RequestMetadata
  1153. ) -> BaseException:
  1154. """Wrap exception with gRPCStatusError if user set a status code.
  1155. For gRPC requests, if the user set a status code on the grpc_context before
  1156. raising an exception, we wrap the exception with gRPCStatusError to preserve
  1157. the user's intended status code through the error handling path.
  1158. """
  1159. if request_metadata.is_grpc_request:
  1160. grpc_context = request_metadata.grpc_context
  1161. if grpc_context and grpc_context.code():
  1162. return gRPCStatusError(
  1163. original_exception=e,
  1164. code=grpc_context.code(),
  1165. details=grpc_context.details(),
  1166. )
  1167. return e
  async def handle_request_with_rejection(
      self, request_metadata: RequestMetadata, *request_args, **request_kwargs
  ):
      """Entrypoint for requests that may be rejected when the replica is full.

      The first yielded value is always a ``ReplicaQueueLengthInfo``:
        * at capacity: ``accepted=False``, and nothing further is yielded;
        * otherwise: ``accepted=True`` with the current ongoing-request count,
          followed by the response. That is pickled ASGI message batches for
          HTTP, individual items for streaming calls, or a single result for
          unary calls.
      Exceptions are re-raised, wrapped in ``gRPCStatusError`` for gRPC
      requests whose context has a status code.
      """
      if not self._can_accept_request(request_metadata):
          capacity = self.max_ongoing_requests
          logger.warning(
              f"Replica at capacity of max_ongoing_requests={capacity}, "
              f"rejecting request {request_metadata.request_id}.",
              extra={"log_to_stderr": False},
          )
          yield ReplicaQueueLengthInfo(False, self.get_num_ongoing_requests())
          return

      args, kwargs, trace_ctx = self._unpack_proxy_args(
          request_metadata, request_args, request_kwargs
      )
      with self._wrap_request(request_metadata, trace_ctx) as status_code_callback:
          async with self._start_request(request_metadata):
              # NOTE(edoakes): `_wrap_request` has already counted this request
              # as ongoing, so the count is re-read here to include it.
              yield ReplicaQueueLengthInfo(
                  accepted=True,
                  num_ongoing_requests=self.get_num_ongoing_requests(),
              )
              try:
                  if request_metadata.is_http_request:
                      scope, receive = args
                      message_stream = self._user_callable_wrapper.call_http_entrypoint(
                          request_metadata,
                          status_code_callback,
                          scope,
                          receive,
                      )
                      async for messages in message_stream:
                          yield pickle.dumps(messages)
                  elif request_metadata.is_streaming:
                      item_stream = self._user_callable_wrapper.call_user_generator(
                          request_metadata,
                          args,
                          kwargs,
                      )
                      async for item in item_stream:
                          yield item
                  else:
                      yield await self._user_callable_wrapper.call_user_method(
                          request_metadata, args, kwargs
                      )
              except Exception as e:
                  # For gRPC requests, wrap exception with user-set status code
                  raise self._maybe_wrap_grpc_exception(e, request_metadata) from e
  1218. @abstractmethod
  1219. async def _on_initialized(self):
  1220. raise NotImplementedError
  async def initialize(
      self, deployment_config: Optional[DeploymentConfig], rank: Optional[ReplicaRank]
  ):
      """Construct the user callable (once) and apply the initial configuration.

      This may be called more than once; the controller calls it again after
      it restarts. The user constructor and ``_on_initialized`` run only on the
      first call, guarded by a lock and ``_user_callable_initialized``. On
      every call, custom autoscaling metrics are (re)enabled, the deployment
      config (if given) is applied, and a health check is run.

      Raises:
          RuntimeError: Wrapping the formatted traceback of any failure,
              including a failed initial health check.
      """
      if rank is not None:
          self._rank = rank
          self._set_internal_replica_context(
              servable_object=self._user_callable_wrapper.user_callable, rank=rank
          )
      try:
          # Ensure the user constructor runs at most once.
          async with self._user_callable_initialized_lock:
              self._initialization_start_time = time.time()
              if not self._user_callable_initialized:
                  asgi_app = await self._user_callable_wrapper.initialize_callable()
                  self._user_callable_asgi_app = asgi_app
                  if asgi_app:
                      self._docs_path = self._user_callable_wrapper._callable.docs_path
                  await self._on_initialized()
                  self._user_callable_initialized = True

              if self._user_callable_wrapper is not None:
                  has_custom_stats = (
                      getattr(
                          self._user_callable_wrapper, "_user_autoscaling_stats", None
                      )
                      is not None
                  )
                  self._metrics_manager.enable_custom_autoscaling_metrics(
                      custom_metrics_enabled=has_custom_stats,
                      record_autoscaling_stats_fn=self._user_callable_wrapper.call_record_autoscaling_stats,
                  )

              if deployment_config is not None:
                  await self._user_callable_wrapper.set_sync_method_threadpool_limit(
                      deployment_config.max_ongoing_requests
                  )
                  context_rank = ray.serve.context._get_internal_replica_context().rank
                  await self._user_callable_wrapper.call_reconfigure(
                      deployment_config.user_config,
                      rank=context_rank,
                  )

          # A new replica is not considered healthy until it passes an initial
          # health check; failing it counts as an initialization failure.
          await self.check_health()
      except Exception:
          raise RuntimeError(traceback.format_exc()) from None
    async def reconfigure(
        self,
        deployment_config: DeploymentConfig,
        rank: ReplicaRank,
        route_prefix: Optional[str] = None,
    ):
        """Apply a new deployment config and rank to a running replica.

        The new values are compared with the current ones *before* the current
        ones are overwritten. This is so that:
          - the logger/profilers are only reconfigured when `logging_config`
            changed, and
          - the user's reconfigure is only called when `user_config` or the
            rank changed.

        The version, autoscaling config, sync-method threadpool limit, internal
        replica context and route prefix are updated unconditionally.

        Raises:
            RuntimeError: wrapping the formatted traceback of any exception.
        """
        try:
            # Diff against the *old* state; the assignments below overwrite it.
            user_config_changed = (
                deployment_config.user_config != self._deployment_config.user_config
            )
            rank_changed = rank != self._rank
            self._rank = rank
            logging_config_changed = (
                deployment_config.logging_config
                != self._deployment_config.logging_config
            )
            self._deployment_config = deployment_config
            self._version = DeploymentVersion.from_deployment_version(
                self._version, deployment_config, route_prefix
            )
            self._metrics_manager.set_autoscaling_config(
                deployment_config.autoscaling_config
            )
            if logging_config_changed:
                self._configure_logger_and_profilers(deployment_config.logging_config)
            await self._user_callable_wrapper.set_sync_method_threadpool_limit(
                deployment_config.max_ongoing_requests
            )
            if user_config_changed or rank_changed:
                await self._user_callable_wrapper.call_reconfigure(
                    deployment_config.user_config,
                    rank=rank,
                )
            # We need to update internal replica context to reflect the new
            # deployment_config and rank.
            self._set_internal_replica_context(
                servable_object=self._user_callable_wrapper.user_callable,
                rank=rank,
            )
            self._route_prefix = self._version.route_prefix
        except Exception:
            raise RuntimeError(traceback.format_exc()) from None
  1313. @abstractmethod
  1314. def _on_request_cancelled(
  1315. self, request_metadata: RequestMetadata, e: asyncio.CancelledError
  1316. ):
  1317. pass
  1318. @abstractmethod
  1319. def _on_request_failed(self, request_metadata: RequestMetadata, e: Exception):
  1320. pass
  1321. @abstractmethod
  1322. @contextmanager
  1323. def _wrap_request(
  1324. self, request_metadata: RequestMetadata
  1325. ) -> Generator[StatusCodeCallback, None, None]:
  1326. pass
    @asynccontextmanager
    async def _start_request(self, request_metadata: RequestMetadata):
        """Admit a request and count it as ongoing for the caller's `async with` body.

        `self._semaphore` is acquired first, so callers may wait here until a
        slot frees up. Once it is acquired, the ongoing-request count is
        incremented. It is decremented in `finally` when the caller's body
        exits, whether normally, by exception, or by cancellation.
        """
        async with self._semaphore:
            try:
                self._metrics_manager.inc_num_ongoing_requests(request_metadata)
                yield
            finally:
                self._metrics_manager.dec_num_ongoing_requests(request_metadata)
  1335. async def _drain_ongoing_requests(self):
  1336. """Wait for any ongoing requests to finish.
  1337. Sleep for a grace period before the first time we check the number of ongoing
  1338. requests to allow the notification to remove this replica to propagate to
  1339. callers first.
  1340. """
  1341. wait_loop_period_s = self._deployment_config.graceful_shutdown_wait_loop_s
  1342. while True:
  1343. await asyncio.sleep(wait_loop_period_s)
  1344. num_ongoing_requests = self._metrics_manager.get_num_ongoing_requests()
  1345. if num_ongoing_requests > 0:
  1346. logger.info(
  1347. f"Waiting for an additional {wait_loop_period_s}s to shut down "
  1348. f"because there are {num_ongoing_requests} ongoing requests."
  1349. )
  1350. else:
  1351. logger.info(
  1352. "Graceful shutdown complete; replica exiting.",
  1353. extra={"log_to_stderr": False},
  1354. )
  1355. break
  1356. async def shutdown(self):
  1357. try:
  1358. await self._user_callable_wrapper.call_destructor()
  1359. except: # noqa: E722
  1360. # We catch a blanket exception since the constructor may still be
  1361. # running, so instance variables used by the destructor may not exist.
  1362. if self._user_callable_initialized:
  1363. logger.exception(
  1364. "__del__ ran before replica finished initializing, and "
  1365. "raised an exception."
  1366. )
  1367. else:
  1368. logger.exception("__del__ raised an exception.")
  1369. await self._metrics_manager.shutdown()
  1370. async def perform_graceful_shutdown(self):
  1371. self._shutting_down = True
  1372. # If the replica was never initialized it never served traffic, so we
  1373. # can skip the wait period.
  1374. if self._user_callable_initialized:
  1375. await self._drain_ongoing_requests()
  1376. await self.shutdown()
  1377. async def check_health(self):
  1378. try:
  1379. # If there's no user-defined health check, nothing runs on the user code event
  1380. # loop and no future is returned.
  1381. f = self._user_callable_wrapper.call_user_health_check()
  1382. if f is not None:
  1383. await f
  1384. self._healthy = True
  1385. except Exception as e:
  1386. logger.warning("Replica health check failed.")
  1387. self._healthy = False
  1388. raise e from None
  1389. async def record_routing_stats(self) -> Dict[str, Any]:
  1390. try:
  1391. f = self._user_callable_wrapper.call_user_record_routing_stats()
  1392. if f is not None:
  1393. return await f
  1394. return {}
  1395. except Exception as e:
  1396. logger.warning("Replica record routing stats failed.")
  1397. raise e from None
  1398. async def send_http_response(message, status_code, send):
  1399. for msg in convert_object_to_asgi_messages(
  1400. message,
  1401. status_code=status_code,
  1402. ):
  1403. await send(msg)
  1404. class Replica(ReplicaBase):
  1405. def __init__(self, **kwargs):
  1406. super().__init__(**kwargs)
  1407. self._controller_handle = ray.get_actor(
  1408. SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE
  1409. )
  1410. # get node ID
  1411. self._node_id = ray.get_runtime_context().get_node_id()
  1412. self._http_options: Optional[HTTPOptions] = None
  1413. self._grpc_options: Optional[gRPCOptions] = None
  1414. self._direct_ingress_http_server_task: Optional[asyncio.Task] = None
  1415. self._direct_ingress_grpc_server_task: Optional[asyncio.Task] = None
  1416. self._num_queued_requests = 0
  1417. @property
  1418. def max_queued_requests(self) -> int:
  1419. return self._deployment_config.max_queued_requests
  1420. async def _maybe_start_direct_ingress_servers(self):
  1421. if not RAY_SERVE_ENABLE_DIRECT_INGRESS:
  1422. return
  1423. if not self._ingress:
  1424. return
  1425. async def allocate_and_start_server(start_server_fn, protocol):
  1426. """Attempt to allocate a port and start the server with retries."""
  1427. is_port_in_use = False
  1428. for _ in range(RAY_SERVE_DIRECT_INGRESS_PORT_RETRY_COUNT):
  1429. port = await self._controller_handle.allocate_replica_port.remote(
  1430. self._node_id, self._replica_id.unique_id, protocol
  1431. )
  1432. logger.info(f"Allocated port {port} for {protocol}")
  1433. try:
  1434. server_task = await start_server_fn(port)
  1435. logger.info(
  1436. f"Successfully started {protocol} server on port {port}"
  1437. )
  1438. return port, server_task
  1439. except RuntimeError as e:
  1440. logger.warning(
  1441. f"Failed to start {protocol} server on port {port}: {e}. Retrying..."
  1442. )
  1443. # `start_asgi_http_server` raises a RuntimeError with the original OSError as the cause.
  1444. if isinstance(e.__cause__, OSError) and e.__cause__.errno in (
  1445. errno.EADDRINUSE,
  1446. errno.EADDRNOTAVAIL,
  1447. ):
  1448. is_port_in_use = True
  1449. else:
  1450. is_port_in_use = False
  1451. # setting block_port to True because we are concluding that the port is
  1452. # in use by another service on the same node. Blocking port here is a small
  1453. # optimization to avoid trying to start the server on a the same port
  1454. # multiple times by other replicas.
  1455. await self._controller_handle.release_replica_port.remote(
  1456. self._node_id,
  1457. self._replica_id.unique_id,
  1458. port,
  1459. protocol,
  1460. block_port=True,
  1461. )
  1462. err_msg = f"Failed to allocate and start {protocol} server after retries"
  1463. if is_port_in_use:
  1464. err_msg = f"""
  1465. Failed to start {protocol} server: port already in use. Suggestion: Ensure that the Ray Serve direct ingress port ranges do not overlap with the Ray worker port range (min_worker_port to max_worker_port).
  1466. """
  1467. raise RuntimeError(err_msg)
  1468. # Fetch configs
  1469. self._http_options, self._grpc_options = ray.get(
  1470. [
  1471. self._controller_handle.get_http_config.remote(),
  1472. self._controller_handle.get_grpc_config.remote(),
  1473. ]
  1474. )
  1475. grpc_enabled = is_grpc_enabled(self._grpc_options)
  1476. # Allocate and start HTTP server
  1477. async def start_http_server(port):
  1478. options = configure_http_middlewares(
  1479. configure_http_options_with_defaults(
  1480. HTTPOptions(**{**self._http_options.dict(), "port": port})
  1481. )
  1482. )
  1483. return await start_asgi_http_server(
  1484. self._direct_ingress_asgi,
  1485. options,
  1486. event_loop=self._event_loop,
  1487. enable_so_reuseport=False,
  1488. )
  1489. (
  1490. self._http_port,
  1491. self._direct_ingress_http_server_task,
  1492. ) = await allocate_and_start_server(
  1493. start_server_fn=start_http_server,
  1494. protocol=RequestProtocol.HTTP,
  1495. )
  1496. # Allocate and start gRPC server if enabled
  1497. if grpc_enabled:
  1498. async def start_grpc_server_fn(port):
  1499. options = gRPCOptions(**{**self._grpc_options.dict(), "port": port})
  1500. return await start_grpc_server(
  1501. self._direct_ingress_service_handler_factory,
  1502. options,
  1503. event_loop=self._event_loop,
  1504. enable_so_reuseport=False,
  1505. )
  1506. (
  1507. self._grpc_port,
  1508. self._direct_ingress_grpc_server_task,
  1509. ) = await allocate_and_start_server(
  1510. start_server_fn=start_grpc_server_fn,
  1511. protocol=RequestProtocol.GRPC,
  1512. )
  1513. logger.info(
  1514. f"Started HTTP server on port {self._http_port}"
  1515. + (f" and gRPC server on port {self._grpc_port}" if grpc_enabled else "")
  1516. )
  1517. async def _on_initialized(self):
  1518. await self._maybe_start_direct_ingress_servers()
  1519. current_rank = ray.serve.context._get_internal_replica_context().rank
  1520. self._set_internal_replica_context(
  1521. servable_object=self._user_callable_wrapper.user_callable,
  1522. rank=current_rank,
  1523. )
  1524. # Start the gRPC server for inter-deployment communication
  1525. add_ASGIServiceServicer_to_server(self, self._server)
  1526. self._internal_grpc_port = self._server.add_insecure_port("[::]:0")
  1527. await self._server.start()
  1528. logger.debug(
  1529. f"Started inter-deployment gRPC server on port {self._internal_grpc_port}"
  1530. )
  1531. # Save the initialization latency if the replica is initializing
  1532. # for the first time.
  1533. if self._initialization_latency is None:
  1534. self._initialization_latency = time.time() - self._initialization_start_time
  1535. def _on_request_cancelled(
  1536. self, metadata: RequestMetadata, e: asyncio.CancelledError
  1537. ):
  1538. """Recursively cancel child requests.
  1539. This includes all requests that are pending assignment, and gRPC
  1540. requests that have already been assigned.
  1541. """
  1542. # Cancel child requests pending assignment
  1543. requests_pending_assignment = (
  1544. ray.serve.context._get_requests_pending_assignment(
  1545. metadata.internal_request_id
  1546. )
  1547. )
  1548. for task in requests_pending_assignment.values():
  1549. task.cancel()
  1550. # Cancel child requests that have already been assigned.
  1551. # This is for gRPC requests and direct ingress requests.
  1552. in_flight_requests = _get_in_flight_requests(metadata.internal_request_id)
  1553. for replica_result in in_flight_requests.values():
  1554. replica_result.cancel()
  1555. def _on_request_failed(self, request_metadata: RequestMetadata, e: Exception):
  1556. if ray.util.pdb._is_ray_debugger_post_mortem_enabled():
  1557. ray.util.pdb._post_mortem()
  1558. def _can_accept_request(self, request_metadata: RequestMetadata):
  1559. if request_metadata.is_direct_ingress:
  1560. limit = self.max_queued_requests
  1561. if limit != -1 and self._num_queued_requests >= limit:
  1562. return False
  1563. return True
  1564. else:
  1565. return super()._can_accept_request(request_metadata)
  1566. @contextmanager
  1567. def _wrap_request(
  1568. self, request_metadata: RequestMetadata, ray_trace_ctx: Optional[Any] = None
  1569. ) -> Generator[StatusCodeCallback, None, None]:
  1570. """Context manager that wraps user method calls.
  1571. 1) Sets the request context var with appropriate metadata.
  1572. 2) Records the access log message (if not disabled).
  1573. 3) Records per-request metrics via the metrics manager.
  1574. """
  1575. ray.serve.context._serve_request_context.set(
  1576. ray.serve.context._RequestContext(
  1577. route=request_metadata.route,
  1578. request_id=request_metadata.request_id,
  1579. _internal_request_id=request_metadata.internal_request_id,
  1580. app_name=self._deployment_id.app_name,
  1581. multiplexed_model_id=request_metadata.multiplexed_model_id,
  1582. grpc_context=request_metadata.grpc_context,
  1583. cancel_on_parent_request_cancel=self._ingress
  1584. and RAY_SERVE_ENABLE_DIRECT_INGRESS,
  1585. _ray_trace_ctx=ray_trace_ctx,
  1586. )
  1587. )
  1588. with self._handle_errors_and_metrics(request_metadata) as status_code_callback:
  1589. yield status_code_callback
  1590. @_wrap_grpc_call
  1591. async def HandleRequest(
  1592. self,
  1593. context: grpc.aio.ServicerContext,
  1594. request_metadata: RequestMetadata,
  1595. *request_args,
  1596. **request_kwargs,
  1597. ):
  1598. result = await self.handle_request(
  1599. request_metadata, *request_args, **request_kwargs
  1600. )
  1601. if request_metadata.is_grpc_request:
  1602. result = (request_metadata.grpc_context, result.SerializeToString())
  1603. return result
  1604. @_wrap_grpc_call
  1605. async def HandleRequestStreaming(
  1606. self,
  1607. context: grpc.aio.ServicerContext,
  1608. request_metadata: RequestMetadata,
  1609. *request_args,
  1610. **request_kwargs,
  1611. ):
  1612. async for result in self.handle_request_streaming(
  1613. request_metadata, *request_args, **request_kwargs
  1614. ):
  1615. if request_metadata.is_grpc_request:
  1616. result = (request_metadata.grpc_context, result.SerializeToString())
  1617. yield result
  1618. @_wrap_grpc_call
  1619. async def HandleRequestWithRejection(
  1620. self,
  1621. context: grpc.aio.ServicerContext,
  1622. request_metadata: RequestMetadata,
  1623. *request_args,
  1624. **request_kwargs,
  1625. ):
  1626. """gRPC entrypoint for all unary requests with strict max_ongoing_requests enforcement
  1627. This generator yields a system message indicating if the request was accepted,
  1628. then the actual response.
  1629. If an exception occurred while processing the request, whether it's a user
  1630. exception or an error intentionally raised by Serve, it will be returned as
  1631. a gRPC response instead of raised directly.
  1632. """
  1633. result_gen = self.handle_request_with_rejection(
  1634. request_metadata, *request_args, **request_kwargs
  1635. )
  1636. queue_len_info: ReplicaQueueLengthInfo = await result_gen.__anext__()
  1637. await context.send_initial_metadata(
  1638. [
  1639. ("accepted", str(int(queue_len_info.accepted))),
  1640. ("num_ongoing_requests", str(queue_len_info.num_ongoing_requests)),
  1641. ]
  1642. )
  1643. if not queue_len_info.accepted:
  1644. # NOTE(edoakes): in gRPC, it's not guaranteed that the initial metadata sent
  1645. # by the server will be delivered for a stream with no messages. Therefore,
  1646. # we send a dummy message here to ensure it is populated in every case.
  1647. return b""
  1648. result = await result_gen.__anext__()
  1649. # Consume the result generator to ensure all request operations are completed.
  1650. async for _ in result_gen:
  1651. pass
  1652. if request_metadata.is_grpc_request:
  1653. result = (request_metadata.grpc_context, result.SerializeToString())
  1654. return result
  1655. @_wrap_grpc_call
  1656. async def HandleRequestWithRejectionStreaming(
  1657. self,
  1658. context: grpc.aio.ServicerContext,
  1659. request_metadata: RequestMetadata,
  1660. *request_args,
  1661. **request_kwargs,
  1662. ) -> AsyncGenerator[Any, None]:
  1663. """gRPC entrypoint for all streaming requests with strict max_ongoing_requests enforcement
  1664. This generator yields a system message indicating if the request was accepted,
  1665. then the actual response(s).
  1666. If an exception occurred while processing the request, whether it's a user
  1667. exception or an error intentionally raised by Serve, it will be returned as
  1668. a gRPC response instead of raised directly.
  1669. """
  1670. result_gen = self.handle_request_with_rejection(
  1671. request_metadata, *request_args, **request_kwargs
  1672. )
  1673. queue_len_info: ReplicaQueueLengthInfo = await result_gen.__anext__()
  1674. await context.send_initial_metadata(
  1675. [
  1676. ("accepted", str(int(queue_len_info.accepted))),
  1677. ("num_ongoing_requests", str(queue_len_info.num_ongoing_requests)),
  1678. ]
  1679. )
  1680. if not queue_len_info.accepted:
  1681. # NOTE(edoakes): in gRPC, it's not guaranteed that the initial metadata sent
  1682. # by the server will be delivered for a stream with no messages. Therefore,
  1683. # we send a dummy message here to ensure it is populated in every case.
  1684. yield b""
  1685. return
  1686. async for result in result_gen:
  1687. if request_metadata.is_grpc_request:
  1688. result = (request_metadata.grpc_context, result.SerializeToString())
  1689. yield result
  1690. async def _dataplane_health_check(self) -> Tuple[bool, str]:
  1691. healthy, message = True, HEALTHY_MESSAGE
  1692. if self._shutting_down:
  1693. healthy = False
  1694. message = "DRAINING"
  1695. elif not self._healthy:
  1696. healthy = False
  1697. message = "UNHEALTHY"
  1698. return healthy, message
  1699. async def _direct_ingress_unary_unary(
  1700. self,
  1701. service_method: str,
  1702. request_proto: Any,
  1703. context: grpc._cython.cygrpc._ServicerContext,
  1704. ) -> bytes:
  1705. if service_method == "/ray.serve.RayServeAPIService/Healthz":
  1706. healthy, message = await self._dataplane_health_check()
  1707. context.set_code(
  1708. grpc.StatusCode.OK if healthy else grpc.StatusCode.UNAVAILABLE
  1709. )
  1710. context.set_details(message)
  1711. return HealthzResponse(message=message).SerializeToString()
  1712. if service_method == "/ray.serve.RayServeAPIService/ListApplications":
  1713. # NOTE(edoakes): ListApplications may be used for health checking.
  1714. healthy, message = await self._dataplane_health_check()
  1715. context.set_code(
  1716. grpc.StatusCode.OK if healthy else grpc.StatusCode.UNAVAILABLE
  1717. )
  1718. context.set_details(message)
  1719. # ListApplications returns only the app name the replica is serving.
  1720. application_names = [self._deployment_id.app_name]
  1721. return ListApplicationsResponse(
  1722. application_names=application_names
  1723. ).SerializeToString()
  1724. request_id = generate_request_id()
  1725. c = RayServegRPCContext(context)
  1726. c.set_trailing_metadata([("request_id", request_id)])
  1727. request_metadata = RequestMetadata(
  1728. # TODO: pick up the request ID from gRPC initial metadata.
  1729. request_id=request_id,
  1730. internal_request_id=generate_request_id(),
  1731. call_method=service_method.split("/")[-1],
  1732. _request_protocol=RequestProtocol.GRPC,
  1733. grpc_context=c,
  1734. app_name=self._deployment_id.app_name,
  1735. # TODO(edoakes): populate this.
  1736. multiplexed_model_id="",
  1737. route=self._deployment_id.app_name,
  1738. tracing_context=None,
  1739. is_streaming=False,
  1740. is_direct_ingress=True,
  1741. )
  1742. if not self._can_accept_request(request_metadata):
  1743. status = ResponseStatus(
  1744. code=grpc.StatusCode.RESOURCE_EXHAUSTED,
  1745. message="Request dropped due to backpressure",
  1746. )
  1747. set_grpc_code_and_details(context, status)
  1748. return
  1749. method_info = self._user_callable_wrapper.get_user_method_info(
  1750. request_metadata.call_method
  1751. )
  1752. request_args = (request_proto,)
  1753. request_kwargs = (
  1754. {GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context}
  1755. if method_info.takes_grpc_context_kwarg
  1756. else {}
  1757. )
  1758. async def call_unary():
  1759. yield await self._user_callable_wrapper.call_user_method(
  1760. request_metadata, request_args, request_kwargs
  1761. )
  1762. with self._wrap_request(request_metadata):
  1763. self._num_queued_requests += 1
  1764. async with self._start_request(request_metadata):
  1765. self._num_queued_requests -= 1
  1766. # Use the generic disconnect detecting wrapper
  1767. result_gen = call_unary()
  1768. replica_response_generator = ReplicaResponseGenerator(
  1769. result_gen,
  1770. timeout_s=self._grpc_options.request_timeout_s,
  1771. )
  1772. try:
  1773. result = await replica_response_generator.__anext__()
  1774. c._set_on_grpc_context(context)
  1775. status = ResponseStatus(code=grpc.StatusCode.OK)
  1776. # NOTE(edoakes): we need to fully consume the generator otherwise the
  1777. # finalizers that run after the `yield` statement won't run. There might
  1778. # be a cleaner way to structure this.
  1779. try:
  1780. await replica_response_generator.__anext__()
  1781. except StopAsyncIteration:
  1782. pass
  1783. except BaseException as e:
  1784. # For gRPC requests, wrap exception with user-set status code
  1785. e = self._maybe_wrap_grpc_exception(e, request_metadata)
  1786. status = get_grpc_response_status(
  1787. e,
  1788. self._grpc_options.request_timeout_s,
  1789. request_metadata.request_id,
  1790. )
  1791. return
  1792. finally:
  1793. set_grpc_code_and_details(context, status)
  1794. return result.SerializeToString()
  1795. async def _direct_ingress_unary_stream(
  1796. self,
  1797. service_method: str,
  1798. request: Any,
  1799. context: grpc._cython.cygrpc._ServicerContext,
  1800. ):
  1801. raise NotImplementedError("unary_stream not implemented.")
  1802. def _direct_ingress_service_handler_factory(
  1803. self, service_method: str, stream: bool
  1804. ) -> Callable:
  1805. if stream:
  1806. async def handler(*args, **kwargs):
  1807. return await self._direct_ingress_unary_stream(
  1808. service_method, *args, **kwargs
  1809. )
  1810. else:
  1811. async def handler(*args, **kwargs):
  1812. return await self._direct_ingress_unary_unary(
  1813. service_method, *args, **kwargs
  1814. )
  1815. return handler
  1816. def _determine_http_route(self, scope: Scope) -> str:
  1817. # Default to route prefix for consistency with non-DI mode
  1818. route = self._route_prefix
  1819. if self._user_callable_asgi_app is not None:
  1820. try:
  1821. matched_route = get_asgi_route_name(self._user_callable_asgi_app, scope)
  1822. if matched_route is not None:
  1823. route = matched_route
  1824. except Exception:
  1825. # If route matching fails, keep the route prefix
  1826. pass
  1827. return route
  1828. def _parse_request_timeout(self, headers: Dict[str, str]) -> Optional[float]:
  1829. """Gets the desired request timeout from the headers.
  1830. If the header is missing or invalid, returns the default request timeout
  1831. from HttpOptions. If the header is non-positive, timeout is disabled.
  1832. """
  1833. header_name = SERVE_HTTP_REQUEST_TIMEOUT_S_HEADER.encode("utf-8")
  1834. if header_name not in headers:
  1835. return self._http_options.request_timeout_s
  1836. value = headers.get(header_name).decode("utf-8")
  1837. try:
  1838. timeout = float(value)
  1839. if timeout > 0:
  1840. return timeout
  1841. return None
  1842. except ValueError:
  1843. return self._http_options.request_timeout_s
  1844. async def _direct_ingress_asgi(
  1845. self,
  1846. scope: Scope,
  1847. receive: Receive,
  1848. send: Send,
  1849. ):
  1850. # NOTE(edoakes): it's important to only start the replica server after the
  1851. # constructor runs because we are using SO_REUSEPORT. We don't want a new
  1852. # replica to start handling connections until it's ready to serve traffic.
  1853. #
  1854. # This can be loosened to listen on the port but fail health checks once we no
  1855. # longer rely on SO_REUSEPORT.
  1856. assert (
  1857. self._user_callable_initialized
  1858. ), "Replica server should only be started *after* the replica is initialized."
  1859. if self._route_prefix and self._route_prefix != "/":
  1860. scope["root_path"] = self._route_prefix
  1861. start_time = time.time()
  1862. method = scope.get("method", "WS").upper()
  1863. route = scope.get("path", "")
  1864. # Handle health check or routes request.
  1865. if route in ["/-/healthz", "/-/routes"]:
  1866. healthy, message = await self._dataplane_health_check()
  1867. status_code = 200 if healthy else 503
  1868. if route == "/-/routes" and healthy:
  1869. # routes endpoint returns only the route prefix andapp name the replica is serving.
  1870. message = {
  1871. self._route_prefix: self._deployment_id.app_name,
  1872. }
  1873. for msg in convert_object_to_asgi_messages(
  1874. message,
  1875. status_code=status_code,
  1876. ):
  1877. await send(msg)
  1878. latency_ms = (time.time() - start_time) * 1000.0
  1879. self._metrics_manager.record_ingress_request_metrics(
  1880. protocol=RequestProtocol.HTTP,
  1881. method=method,
  1882. route=route,
  1883. app_name=self._deployment_id.app_name,
  1884. deployment_name=self._deployment_id.name,
  1885. latency_ms=latency_ms,
  1886. was_error=not healthy,
  1887. status_code=str(status_code),
  1888. )
  1889. return
  1890. # If the HTTP path does not match the deployment route prefix,
  1891. # it is invalid and we should not serve it.
  1892. if not route.startswith(self._route_prefix):
  1893. for msg in convert_object_to_asgi_messages(
  1894. f"Path '{route}' not found. "
  1895. "Ping http://.../-/routes for available routes.",
  1896. status_code=404,
  1897. ):
  1898. await send(msg)
  1899. return
  1900. headers = dict(scope["headers"])
  1901. request_id = (
  1902. headers.get(SERVE_HTTP_REQUEST_ID_HEADER.encode("utf-8")).decode("utf-8")
  1903. or generate_request_id()
  1904. )
  1905. request_disconnect_disabled = (
  1906. headers.get(
  1907. SERVE_HTTP_REQUEST_DISCONNECT_DISABLED_HEADER.encode("utf-8"), b"?0"
  1908. ).decode("utf-8")
  1909. ) == "?1"
  1910. request_timeout_s = self._parse_request_timeout(headers)
  1911. request_metadata = RequestMetadata(
  1912. request_id=request_id,
  1913. internal_request_id=generate_request_id(),
  1914. call_method="__call__",
  1915. route=self._determine_http_route(scope),
  1916. app_name=self._deployment_id.app_name,
  1917. # TODO(edoakes): populate the multiplexed model ID.
  1918. multiplexed_model_id="",
  1919. is_streaming=True,
  1920. _request_protocol=RequestProtocol.HTTP,
  1921. tracing_context=None,
  1922. _http_method=scope.get("method", "WS"),
  1923. is_direct_ingress=True,
  1924. )
  1925. if not self._can_accept_request(request_metadata):
  1926. # NOTE(abrar): its possible that we drop more requests than actual max_queued_requests
  1927. # because between incrementing and decrementing the queued requests, we yield to the event loop.
  1928. for msg in convert_object_to_asgi_messages(
  1929. "Request dropped due to backpressure",
  1930. status_code=503,
  1931. ):
  1932. await send(msg)
  1933. return
  1934. # Optimization: we can avoid creating an async receive task if the client
  1935. # has disabled handling disconnects for this request.
  1936. if request_disconnect_disabled:
  1937. receive_proxy = receive
  1938. receive_task = None
  1939. else:
  1940. receive_proxy = ASGIDIReceiveProxy(
  1941. scope, receive, self._user_callable_wrapper.event_loop
  1942. )
  1943. receive_task = receive_proxy.fetch_until_disconnect_task()
  1944. response_started = False
  1945. response_finished = False
  1946. first_message_peeked = False
  1947. with self._wrap_request(request_metadata) as status_code_callback:
  1948. self._num_queued_requests += 1
  1949. async def send_user_message(msg: Dict):
  1950. nonlocal response_started
  1951. nonlocal response_finished
  1952. nonlocal first_message_peeked
  1953. if not first_message_peeked:
  1954. first_message_peeked = True
  1955. if msg["type"] == "http.response.start":
  1956. status_code_callback(str(msg["status"]))
  1957. await send(msg)
  1958. response_started = True
  1959. if msg.get("more_body") is False:
  1960. response_finished = True
  1961. async def call_asgi():
  1962. async with self._start_request(request_metadata):
  1963. self._num_queued_requests -= 1
  1964. if (
  1965. not self._user_callable_wrapper._run_user_code_in_separate_thread
  1966. ):
  1967. user_method_info = (
  1968. self._user_callable_wrapper.get_user_method_info(
  1969. request_metadata.call_method
  1970. )
  1971. )
  1972. # `_call_http_entrypoint` will have already called
  1973. # `send_user_message`, so the ASGI messages will have
  1974. # already been sent back to the client.
  1975. await self._user_callable_wrapper._call_http_entrypoint(
  1976. user_method_info, scope, receive_proxy, send_user_message
  1977. )
  1978. else:
  1979. async for asgi_messages in self._user_callable_wrapper.call_http_entrypoint(
  1980. request_metadata, status_code_callback, scope, receive_proxy
  1981. ):
  1982. for message in asgi_messages:
  1983. await send_user_message(message)
  1984. # Optimization: if Serve doesn't need to handle disconnects and
  1985. # timeouts for this request, we can avoid event loop overhead by
  1986. # directly awaiting the user code.
  1987. if receive_task is None and request_timeout_s is None:
  1988. return await call_asgi()
  1989. # Otherwise, we'd always need the call_asgi() task.
  1990. request_task = asyncio.create_task(call_asgi())
  1991. tasks = [request_task]
  1992. if receive_task is not None:
  1993. tasks.append(receive_task)
  1994. done, _ = await asyncio.wait(
  1995. tasks,
  1996. timeout=request_timeout_s,
  1997. return_when=asyncio.FIRST_COMPLETED,
  1998. )
  1999. # NOTE(zcin): it's possible that the request task has finished sending
  2000. # all ASGI messages, but the task is suspended and before it can fully
  2001. # complete, the client has sent a disconnect message after the request
  2002. # is completed. That is why we check for `response_finished` here.
  2003. if request_task in done or response_finished:
  2004. if receive_task is not None:
  2005. receive_task.cancel()
  2006. await request_task
  2007. elif receive_task in done:
  2008. request_task.cancel()
  2009. status_code_callback("499")
  2010. if not response_started:
  2011. msg = (
  2012. f"Client for request {request_id} disconnected, "
  2013. "cancelling request."
  2014. )
  2015. await send_http_response(msg, 499, send)
  2016. raise asyncio.CancelledError
  2017. else:
  2018. request_task.cancel()
  2019. status_code_callback("408")
  2020. if not response_started:
  2021. msg = (
  2022. f"Request {request_id} timed out after "
  2023. f"{self._http_options.request_timeout_s}s."
  2024. )
  2025. await send_http_response(msg, 408, send)
  2026. raise asyncio.CancelledError
  2027. async def perform_graceful_shutdown(self):
  2028. if (
  2029. RAY_SERVE_ENABLE_DIRECT_INGRESS
  2030. and self._ingress
  2031. and self._user_callable_initialized
  2032. ):
  2033. # In direct ingress mode, we need to wait at least
  2034. # RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S to give external load
  2035. # balancers (e.g., ALB) time to deregister the replica, in addition to
  2036. # waiting for requests to drain.
  2037. await asyncio.gather(
  2038. asyncio.sleep(RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S),
  2039. super().perform_graceful_shutdown(),
  2040. )
  2041. else:
  2042. await super().perform_graceful_shutdown()
  2043. # Cancel direct ingress HTTP/gRPC server tasks if they exist.
  2044. if self._direct_ingress_http_server_task:
  2045. self._direct_ingress_http_server_task.cancel()
  2046. if self._direct_ingress_grpc_server_task:
  2047. self._direct_ingress_grpc_server_task.cancel()
  2048. class ReplicaActor:
  2049. """Actor definition for replicas of Ray Serve deployments.
  2050. This class defines the interface that the controller and deployment handles
  2051. (i.e., from proxies and other replicas) use to interact with a replica.
  2052. All interaction with the user-provided callable is done via the
  2053. `UserCallableWrapper` class.
  2054. """
  2055. async def __init__(
  2056. self,
  2057. replica_id: ReplicaID,
  2058. serialized_deployment_def: bytes,
  2059. serialized_init_args: bytes,
  2060. serialized_init_kwargs: bytes,
  2061. deployment_config_proto_bytes: bytes,
  2062. version: DeploymentVersion,
  2063. ingress: bool,
  2064. route_prefix: str,
  2065. ):
  2066. deployment_config = DeploymentConfig.from_proto_bytes(
  2067. deployment_config_proto_bytes
  2068. )
  2069. deployment_def = cloudpickle.loads(serialized_deployment_def)
  2070. if isinstance(deployment_def, str):
  2071. deployment_def = _load_deployment_def_from_import_path(deployment_def)
  2072. self._replica_impl: ReplicaBase = create_replica_impl(
  2073. replica_id=replica_id,
  2074. deployment_def=deployment_def,
  2075. init_args=cloudpickle.loads(serialized_init_args),
  2076. init_kwargs=cloudpickle.loads(serialized_init_kwargs),
  2077. deployment_config=deployment_config,
  2078. version=version,
  2079. ingress=ingress,
  2080. route_prefix=route_prefix,
  2081. )
  2082. def push_proxy_handle(self, handle: ActorHandle):
  2083. # NOTE(edoakes): it's important to call a method on the proxy handle to
  2084. # initialize its state in the C++ core worker.
  2085. handle.pong.remote()
  2086. def get_num_ongoing_requests(self) -> int:
  2087. """Fetch the number of ongoing requests at this replica (queue length).
  2088. This runs on a separate thread (using a Ray concurrency group) so it will
  2089. not be blocked by user code.
  2090. """
  2091. return self._replica_impl.get_num_ongoing_requests()
  2092. async def is_allocated(self) -> str:
  2093. """poke the replica to check whether it's alive.
  2094. When calling this method on an ActorHandle, it will complete as
  2095. soon as the actor has started running. We use this mechanism to
  2096. detect when a replica has been allocated a worker slot.
  2097. At this time, the replica can transition from PENDING_ALLOCATION
  2098. to PENDING_INITIALIZATION startup state.
  2099. Returns:
  2100. The PID, actor ID, node ID, node IP, and log filepath id of the replica.
  2101. """
  2102. return (
  2103. os.getpid(),
  2104. ray.get_runtime_context().get_actor_id(),
  2105. ray.get_runtime_context().get_worker_id(),
  2106. ray.get_runtime_context().get_node_id(),
  2107. ray.util.get_node_ip_address(),
  2108. ray.util.get_node_instance_id(),
  2109. get_component_logger_file_path(),
  2110. )
  2111. def list_outbound_deployments(self) -> Optional[List[DeploymentID]]:
  2112. return self._replica_impl.list_outbound_deployments()
  2113. async def initialize_and_get_metadata(
  2114. self, deployment_config: DeploymentConfig = None, rank: ReplicaRank = None
  2115. ) -> ReplicaMetadata:
  2116. """Handles initializing the replica.
  2117. Returns: 5-tuple containing
  2118. 1. DeploymentConfig of the replica
  2119. 2. DeploymentVersion of the replica
  2120. 3. Initialization duration in seconds
  2121. 4. Port
  2122. 5. FastAPI `docs_path`, if relevant (i.e. this is an ingress deployment integrated with FastAPI).
  2123. """
  2124. # Unused `_after` argument is for scheduling: passing an ObjectRef
  2125. # allows delaying this call until after the `_after` call has returned.
  2126. await self._replica_impl.initialize(deployment_config, rank)
  2127. return self._replica_impl.get_metadata()
  2128. async def check_health(self):
  2129. await self._replica_impl.check_health()
  2130. async def record_routing_stats(self) -> Dict[str, Any]:
  2131. return await self._replica_impl.record_routing_stats()
  2132. async def reconfigure(
  2133. self, deployment_config, rank: ReplicaRank, route_prefix: Optional[str] = None
  2134. ) -> ReplicaMetadata:
  2135. await self._replica_impl.reconfigure(deployment_config, rank, route_prefix)
  2136. return self._replica_impl.get_metadata()
  2137. def _preprocess_request_args(
  2138. self,
  2139. pickled_request_metadata: bytes,
  2140. request_args: Tuple[Any],
  2141. ) -> Tuple[RequestMetadata, Tuple[Any]]:
  2142. request_metadata = pickle.loads(pickled_request_metadata)
  2143. if request_metadata.is_http_request or request_metadata.is_grpc_request:
  2144. request_args = (pickle.loads(request_args[0]),)
  2145. return request_metadata, request_args
  2146. async def handle_request(
  2147. self,
  2148. pickled_request_metadata: bytes,
  2149. *request_args,
  2150. **request_kwargs,
  2151. ) -> Tuple[bytes, Any]:
  2152. """Entrypoint for `stream=False` calls."""
  2153. request_metadata, request_args = self._preprocess_request_args(
  2154. pickled_request_metadata, request_args
  2155. )
  2156. result = await self._replica_impl.handle_request(
  2157. request_metadata, *request_args, **request_kwargs
  2158. )
  2159. if request_metadata.is_grpc_request:
  2160. result = (request_metadata.grpc_context, result.SerializeToString())
  2161. return result
  2162. async def handle_request_streaming(
  2163. self,
  2164. pickled_request_metadata: bytes,
  2165. *request_args,
  2166. **request_kwargs,
  2167. ) -> AsyncGenerator[Any, None]:
  2168. """Generator that is the entrypoint for all `stream=True` handle calls."""
  2169. request_metadata, request_args = self._preprocess_request_args(
  2170. pickled_request_metadata, request_args
  2171. )
  2172. async for result in self._replica_impl.handle_request_streaming(
  2173. request_metadata, *request_args, **request_kwargs
  2174. ):
  2175. if request_metadata.is_grpc_request:
  2176. result = (request_metadata.grpc_context, result.SerializeToString())
  2177. yield result
  2178. async def handle_request_with_rejection(
  2179. self,
  2180. pickled_request_metadata: bytes,
  2181. *request_args,
  2182. **request_kwargs,
  2183. ) -> AsyncGenerator[Any, None]:
  2184. """Entrypoint for all requests with strict max_ongoing_requests enforcement.
  2185. The first response from this generator is always a system message indicating
  2186. if the request was accepted (the replica has capacity for the request) or
  2187. rejected (the replica is already at max_ongoing_requests).
  2188. For non-streaming requests, there will only be one more message, the unary
  2189. result of the user request handler.
  2190. For streaming requests, the subsequent messages will be the results of the
  2191. user request handler (which must be a generator).
  2192. """
  2193. request_metadata, request_args = self._preprocess_request_args(
  2194. pickled_request_metadata, request_args
  2195. )
  2196. async for result in self._replica_impl.handle_request_with_rejection(
  2197. request_metadata, *request_args, **request_kwargs
  2198. ):
  2199. if isinstance(result, ReplicaQueueLengthInfo):
  2200. yield pickle.dumps(result)
  2201. else:
  2202. if request_metadata.is_grpc_request:
  2203. result = (request_metadata.grpc_context, result.SerializeToString())
  2204. yield result
  2205. async def handle_request_from_java(
  2206. self,
  2207. proto_request_metadata: bytes,
  2208. *request_args,
  2209. **request_kwargs,
  2210. ) -> Any:
  2211. from ray.serve.generated.serve_pb2 import (
  2212. RequestMetadata as RequestMetadataProto,
  2213. )
  2214. proto = RequestMetadataProto.FromString(proto_request_metadata)
  2215. request_metadata: RequestMetadata = RequestMetadata(
  2216. request_id=proto.request_id,
  2217. internal_request_id=proto.internal_request_id,
  2218. call_method=proto.call_method,
  2219. multiplexed_model_id=proto.multiplexed_model_id,
  2220. route=proto.route,
  2221. )
  2222. return await self._replica_impl.handle_request(
  2223. request_metadata, *request_args, **request_kwargs
  2224. )
  2225. async def perform_graceful_shutdown(self):
  2226. await self._replica_impl.perform_graceful_shutdown()
  2227. @dataclass
  2228. class UserMethodInfo:
  2229. """Wrapper for a user method and its relevant metadata."""
  2230. callable: Callable
  2231. name: str
  2232. is_asgi_app: bool
  2233. takes_any_args: bool
  2234. takes_grpc_context_kwarg: bool
  2235. @classmethod
  2236. def from_callable(cls, c: Callable, *, is_asgi_app: bool) -> "UserMethodInfo":
  2237. params = inspect.signature(c).parameters
  2238. return cls(
  2239. callable=c,
  2240. name=c.__name__,
  2241. is_asgi_app=is_asgi_app,
  2242. takes_any_args=len(params) > 0,
  2243. takes_grpc_context_kwarg=GRPC_CONTEXT_ARG_NAME in params,
  2244. )
  2245. class UserCallableWrapper:
  2246. """Wraps a user-provided callable that is used to handle requests to a replica."""
  2247. service_unavailable_exceptions = (BackPressureError, DeploymentUnavailableError)
  2248. def __init__(
  2249. self,
  2250. deployment_def: Callable,
  2251. init_args: Tuple,
  2252. init_kwargs: Dict,
  2253. *,
  2254. deployment_id: DeploymentID,
  2255. run_sync_methods_in_threadpool: bool,
  2256. run_user_code_in_separate_thread: bool,
  2257. local_testing_mode: bool,
  2258. deployment_config: DeploymentConfig,
  2259. actor_id: str,
  2260. ray_actor_options: Optional[Dict] = None,
  2261. ):
  2262. if not (inspect.isfunction(deployment_def) or inspect.isclass(deployment_def)):
  2263. raise TypeError(
  2264. "deployment_def must be a function or class. Instead, its type was "
  2265. f"{type(deployment_def)}."
  2266. )
  2267. self._deployment_def = deployment_def
  2268. self._init_args = init_args
  2269. self._init_kwargs = init_kwargs
  2270. self._is_function = inspect.isfunction(deployment_def)
  2271. self._deployment_id = deployment_id
  2272. self._local_testing_mode = local_testing_mode
  2273. self._destructor_called = False
  2274. self._run_sync_methods_in_threadpool = run_sync_methods_in_threadpool
  2275. self._run_user_code_in_separate_thread = run_user_code_in_separate_thread
  2276. self._warned_about_sync_method_change = False
  2277. self._cached_user_method_info: Dict[str, UserMethodInfo] = {}
  2278. # This is for performance optimization https://docs.python.org/3/howto/logging.html#optimization
  2279. self._is_enabled_for_debug = logger.isEnabledFor(logging.DEBUG)
  2280. # Will be populated in `initialize_callable`.
  2281. self._callable = None
  2282. self._deployment_config = deployment_config
  2283. self._ray_actor_options = ray_actor_options or {}
  2284. self._user_code_threadpool: Optional[
  2285. concurrent.futures.ThreadPoolExecutor
  2286. ] = None
  2287. if self._run_user_code_in_separate_thread:
  2288. # All interactions with user code run on this loop to avoid blocking the
  2289. # replica's main event loop.
  2290. self._user_code_event_loop: asyncio.AbstractEventLoop = (
  2291. asyncio.new_event_loop()
  2292. )
  2293. # Start event loop monitoring for the user code event loop.
  2294. # We create the monitor here but start it inside the thread function
  2295. # so the task is created on the correct thread.
  2296. self._user_code_loop_monitor = EventLoopMonitor(
  2297. component=EventLoopMonitor.COMPONENT_REPLICA,
  2298. loop_type=EventLoopMonitor.LOOP_TYPE_USER_CODE,
  2299. actor_id=actor_id,
  2300. extra_tags={
  2301. "deployment": self._deployment_id.name,
  2302. "application": self._deployment_id.app_name,
  2303. },
  2304. )
  2305. def _run_user_code_event_loop():
  2306. # Required so that calls to get the current running event loop work
  2307. # properly in user code.
  2308. asyncio.set_event_loop(self._user_code_event_loop)
  2309. self._configure_user_code_threadpool()
  2310. # Start monitoring before run_forever so the task is scheduled.
  2311. self._user_code_loop_monitor.start(self._user_code_event_loop)
  2312. self._user_code_event_loop.run_forever()
  2313. self._user_code_event_loop_thread = threading.Thread(
  2314. daemon=True,
  2315. target=_run_user_code_event_loop,
  2316. )
  2317. self._user_code_event_loop_thread.start()
  2318. else:
  2319. self._user_code_event_loop = asyncio.get_running_loop()
  2320. self._user_code_loop_monitor = None
  2321. self._configure_user_code_threadpool()
  2322. @property
  2323. def event_loop(self) -> asyncio.AbstractEventLoop:
  2324. return self._user_code_event_loop
  2325. def _get_user_code_threadpool_max_workers(self) -> Optional[int]:
  2326. num_cpus = self._ray_actor_options.get("num_cpus")
  2327. if num_cpus is None:
  2328. return None
  2329. # Mirror ThreadPoolExecutor default behavior while respecting num_cpus.
  2330. return min(32, max(1, int(math.ceil(num_cpus))) + 4)
  2331. def _configure_user_code_threadpool(self) -> None:
  2332. max_workers = self._get_user_code_threadpool_max_workers()
  2333. if max_workers is None:
  2334. return
  2335. self._user_code_threadpool = concurrent.futures.ThreadPoolExecutor(
  2336. max_workers=max_workers
  2337. )
  2338. self._user_code_event_loop.set_default_executor(self._user_code_threadpool)
  2339. def _run_user_code(f: Callable) -> Callable:
  2340. """Decorator to run a coroutine method on the user code event loop.
  2341. The method will be modified to be a sync function that returns a
  2342. `asyncio.Future` if user code is running in a separate event loop.
  2343. Otherwise, it will return the coroutine directly.
  2344. """
  2345. assert inspect.iscoroutinefunction(
  2346. f
  2347. ), "_run_user_code can only be used on coroutine functions."
  2348. @functools.wraps(f)
  2349. def wrapper(self, *args, **kwargs) -> Any:
  2350. coro = f(self, *args, **kwargs)
  2351. if self._run_user_code_in_separate_thread:
  2352. fut = asyncio.run_coroutine_threadsafe(coro, self._user_code_event_loop)
  2353. if self._local_testing_mode:
  2354. return fut
  2355. return asyncio.wrap_future(fut)
  2356. else:
  2357. return coro
  2358. return wrapper
  2359. @_run_user_code
  2360. async def set_sync_method_threadpool_limit(self, limit: int):
  2361. # NOTE(edoakes): the limit is thread local, so this must
  2362. # be run on the user code event loop.
  2363. to_thread.current_default_thread_limiter().total_tokens = limit
  2364. def get_user_method_info(self, method_name: str) -> UserMethodInfo:
  2365. """Get UserMethodInfo for the provided call method name.
  2366. This method is cached to avoid repeated expensive calls to `inspect.signature`.
  2367. """
  2368. if method_name in self._cached_user_method_info:
  2369. return self._cached_user_method_info[method_name]
  2370. if self._is_function:
  2371. user_method = self._callable
  2372. elif hasattr(self._callable, method_name):
  2373. user_method = getattr(self._callable, method_name)
  2374. else:
  2375. # Filter to methods that don't start with '__' prefix.
  2376. def callable_method_filter(attr):
  2377. if attr.startswith("__"):
  2378. return False
  2379. elif not callable(getattr(self._callable, attr)):
  2380. return False
  2381. return True
  2382. methods = list(filter(callable_method_filter, dir(self._callable)))
  2383. raise RayServeException(
  2384. f"Tried to call a method '{method_name}' "
  2385. "that does not exist. Available methods: "
  2386. f"{methods}."
  2387. )
  2388. info = UserMethodInfo.from_callable(
  2389. user_method,
  2390. is_asgi_app=isinstance(self._callable, ASGIAppReplicaWrapper),
  2391. )
  2392. self._cached_user_method_info[method_name] = info
  2393. return info
  2394. async def _send_user_result_over_asgi(
  2395. self,
  2396. result: Any,
  2397. asgi_args: ASGIArgs,
  2398. ):
  2399. """Handle the result from user code and send it over the ASGI interface.
  2400. If the result is already a Response type, it is sent directly. Otherwise, it
  2401. is converted to a custom Response type that handles serialization for
  2402. common Python objects.
  2403. """
  2404. scope, receive, send = asgi_args.to_args_tuple()
  2405. if isinstance(result, starlette.responses.Response):
  2406. await result(scope, receive, send)
  2407. else:
  2408. await Response(result).send(scope, receive, send)
  2409. async def _call_func_or_gen(
  2410. self,
  2411. callable: Callable,
  2412. *,
  2413. args: Optional[Tuple[Any]] = None,
  2414. kwargs: Optional[Dict[str, Any]] = None,
  2415. is_streaming: bool = False,
  2416. generator_result_callback: Optional[Callable] = None,
  2417. run_sync_methods_in_threadpool_override: Optional[bool] = None,
  2418. ) -> Tuple[Any, bool]:
  2419. """Call the callable with the provided arguments.
  2420. This is a convenience wrapper that will work for `def`, `async def`,
  2421. generator, and async generator functions.
  2422. Returns the result and a boolean indicating if the result was a sync generator
  2423. that has already been consumed.
  2424. """
  2425. sync_gen_consumed = False
  2426. args = args if args is not None else tuple()
  2427. kwargs = kwargs if kwargs is not None else dict()
  2428. run_sync_in_threadpool = (
  2429. self._run_sync_methods_in_threadpool
  2430. if run_sync_methods_in_threadpool_override is None
  2431. else run_sync_methods_in_threadpool_override
  2432. )
  2433. is_sync_method = (
  2434. inspect.isfunction(callable) or inspect.ismethod(callable)
  2435. ) and not (
  2436. inspect.iscoroutinefunction(callable)
  2437. or inspect.isasyncgenfunction(callable)
  2438. )
  2439. if is_sync_method and run_sync_in_threadpool:
  2440. is_generator = inspect.isgeneratorfunction(callable)
  2441. if is_generator:
  2442. sync_gen_consumed = True
  2443. if not is_streaming:
  2444. # TODO(edoakes): make this check less redundant with the one in
  2445. # _handle_user_method_result.
  2446. raise TypeError(
  2447. f"Method '{callable.__name__}' returned a generator. "
  2448. "You must use `handle.options(stream=True)` to call "
  2449. "generators on a deployment."
  2450. )
  2451. def run_callable():
  2452. result = callable(*args, **kwargs)
  2453. if is_generator:
  2454. for r in result:
  2455. generator_result_callback(r)
  2456. result = None
  2457. return result
  2458. # NOTE(edoakes): we use anyio.to_thread here because it's what Starlette
  2459. # uses (and therefore FastAPI too). The max size of the threadpool is
  2460. # set to max_ongoing_requests in the replica wrapper.
  2461. # anyio.to_thread propagates ContextVars to the worker thread automatically.
  2462. result = await to_thread.run_sync(run_callable)
  2463. else:
  2464. if (
  2465. is_sync_method
  2466. and not self._warned_about_sync_method_change
  2467. and run_sync_methods_in_threadpool_override is None
  2468. ):
  2469. self._warned_about_sync_method_change = True
  2470. warnings.warn(
  2471. RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING.format(
  2472. method_name=callable.__name__,
  2473. )
  2474. )
  2475. result = callable(*args, **kwargs)
  2476. if inspect.iscoroutine(result):
  2477. result = await result
  2478. return result, sync_gen_consumed
  2479. @property
  2480. def user_callable(self) -> Optional[Callable]:
  2481. return self._callable
  2482. async def _initialize_asgi_callable(self) -> None:
  2483. self._callable: ASGIAppReplicaWrapper
  2484. app: Starlette = self._callable.app
  2485. # The reason we need to do this is because BackPressureError is a serve internal exception
  2486. # and FastAPI doesn't know how to handle it, so it treats it as a 500 error.
  2487. # With same reasoning, we are not handling TimeoutError because it's a generic exception
  2488. # the FastAPI knows how to handle. See https://www.starlette.io/exceptions/
  2489. def handle_exception(_: Request, exc: Exception):
  2490. return self.handle_exception(exc)
  2491. for exc in self.service_unavailable_exceptions:
  2492. app.add_exception_handler(exc, handle_exception)
  2493. await self._callable._run_asgi_lifespan_startup()
  2494. @_run_user_code
  2495. async def initialize_callable(self) -> Optional[ASGIApp]:
  2496. """Initialize the user callable.
  2497. If the callable is an ASGI app wrapper (e.g., using @serve.ingress), returns
  2498. the ASGI app object, which may be used *read only* by the caller.
  2499. """
  2500. if self._callable is not None:
  2501. raise RuntimeError("initialize_callable should only be called once.")
  2502. # This closure initializes user code and finalizes replica
  2503. # startup. By splitting the initialization step like this,
  2504. # we can already access this actor before the user code
  2505. # has finished initializing.
  2506. # The supervising state manager can then wait
  2507. # for allocation of this replica by using the `is_allocated`
  2508. # method. After that, it calls `reconfigure` to trigger
  2509. # user code initialization.
  2510. logger.info(
  2511. "Started initializing replica.",
  2512. extra={"log_to_stderr": False},
  2513. )
  2514. if self._is_function:
  2515. self._callable = self._deployment_def
  2516. else:
  2517. # This allows deployments to define an async __init__
  2518. # method (mostly used for testing).
  2519. self._callable = self._deployment_def.__new__(self._deployment_def)
  2520. await self._call_func_or_gen(
  2521. self._callable.__init__,
  2522. args=self._init_args,
  2523. kwargs=self._init_kwargs,
  2524. # Always run the constructor on the main user code thread.
  2525. run_sync_methods_in_threadpool_override=False,
  2526. )
  2527. if isinstance(self._callable, ASGIAppReplicaWrapper):
  2528. await self._initialize_asgi_callable()
  2529. if isinstance(self._callable, TaskConsumerWrapper):
  2530. self._callable.initialize_callable(
  2531. self._deployment_config.max_ongoing_requests
  2532. )
  2533. ServeUsageTag.NUM_REPLICAS_USING_ASYNCHRONOUS_INFERENCE.record("1")
  2534. self._user_health_check = getattr(self._callable, HEALTH_CHECK_METHOD, None)
  2535. self._user_record_routing_stats = getattr(
  2536. self._callable, REQUEST_ROUTING_STATS_METHOD, None
  2537. )
  2538. self._user_autoscaling_stats = getattr(
  2539. self._callable, "record_autoscaling_stats", None
  2540. )
  2541. logger.info(
  2542. "Finished initializing replica.",
  2543. extra={"log_to_stderr": False},
  2544. )
  2545. return (
  2546. self._callable.app
  2547. if isinstance(self._callable, ASGIAppReplicaWrapper)
  2548. else None
  2549. )
  2550. def _raise_if_not_initialized(self, method_name: str):
  2551. if self._callable is None:
  2552. raise RuntimeError(
  2553. f"`initialize_callable` must be called before `{method_name}`."
  2554. )
  2555. def call_user_health_check(self) -> Optional[concurrent.futures.Future]:
  2556. self._raise_if_not_initialized("call_user_health_check")
  2557. # If the user provided a health check, call it on the user code thread. If user
  2558. # code blocks the event loop the health check may time out.
  2559. #
  2560. # To avoid this issue for basic cases without a user-defined health check, skip
  2561. # interacting with the user callable entirely.
  2562. if self._user_health_check is not None:
  2563. return self._call_user_health_check()
  2564. return None
  2565. def call_user_record_routing_stats(self) -> Optional[concurrent.futures.Future]:
  2566. self._raise_if_not_initialized("call_user_record_routing_stats")
  2567. if self._user_record_routing_stats is not None:
  2568. return self._call_user_record_routing_stats()
  2569. return None
  2570. def call_record_autoscaling_stats(self) -> Optional[concurrent.futures.Future]:
  2571. self._raise_if_not_initialized("call_record_autoscaling_stats")
  2572. if self._user_autoscaling_stats is not None:
  2573. return self._call_user_autoscaling_stats()
  2574. return None
  2575. @_run_user_code
  2576. async def _call_user_health_check(self):
  2577. await self._call_func_or_gen(self._user_health_check)
  2578. @_run_user_code
  2579. async def _call_user_record_routing_stats(self) -> Dict[str, Any]:
  2580. result, _ = await self._call_func_or_gen(self._user_record_routing_stats)
  2581. return result
  2582. @_run_user_code
  2583. async def _call_user_autoscaling_stats(self) -> Dict[str, Union[int, float]]:
  2584. result, _ = await self._call_func_or_gen(self._user_autoscaling_stats)
  2585. return result
  2586. @_run_user_code
  2587. async def call_reconfigure(self, user_config: Optional[Any], rank: ReplicaRank):
  2588. self._raise_if_not_initialized("call_reconfigure")
  2589. # NOTE(edoakes): there is the possibility of a race condition in user code if
  2590. # they don't have any form of concurrency control between `reconfigure` and
  2591. # other methods. See https://github.com/ray-project/ray/pull/42159.
  2592. # NOTE(abrar): The only way to subscribe to rank changes is to provide some user config.
  2593. # We can relax this in the future as more use cases arise for rank. I am reluctant to
  2594. # introduce behavior change for a feature we might not need.
  2595. user_subscribed_to_rank = False
  2596. if not self._is_function and hasattr(self._callable, RECONFIGURE_METHOD):
  2597. reconfigure_method = getattr(self._callable, RECONFIGURE_METHOD)
  2598. params = inspect.signature(reconfigure_method).parameters
  2599. user_subscribed_to_rank = "rank" in params
  2600. if user_config is not None or user_subscribed_to_rank:
  2601. if self._is_function:
  2602. raise ValueError(
  2603. "deployment_def must be a class to use user_config or rank"
  2604. )
  2605. elif not hasattr(self._callable, RECONFIGURE_METHOD):
  2606. raise RayServeException(
  2607. "user_config or rank specified but deployment "
  2608. + self._deployment_id
  2609. + " missing "
  2610. + RECONFIGURE_METHOD
  2611. + " method"
  2612. )
  2613. kwargs = {}
  2614. if user_subscribed_to_rank:
  2615. # For backwards compatibility, only pass rank if it is an argument to the reconfigure method.
  2616. kwargs["rank"] = rank
  2617. await self._call_func_or_gen(
  2618. getattr(self._callable, RECONFIGURE_METHOD),
  2619. args=(user_config,),
  2620. kwargs=kwargs,
  2621. )
  2622. async def _handle_user_method_result(
  2623. self,
  2624. result: Any,
  2625. user_method_info: UserMethodInfo,
  2626. *,
  2627. is_streaming: bool,
  2628. is_http_request: bool,
  2629. sync_gen_consumed: bool,
  2630. generator_result_callback: Optional[Callable],
  2631. asgi_args: Optional[ASGIArgs],
  2632. ) -> Any:
  2633. """Postprocess the result of a user method.
  2634. User methods can be regular unary functions or return a sync or async generator.
  2635. This method will raise an exception if the result is not of the expected type
  2636. (e.g., non-generator for streaming requests or generator for unary requests).
  2637. Generator outputs will be written to the `generator_result_callback`.
  2638. Note that HTTP requests are an exception: they are *always* streaming requests,
  2639. but for ASGI apps (like FastAPI), the actual method will be a regular function
  2640. implementing the ASGI `__call__` protocol.
  2641. """
  2642. result_is_gen = inspect.isgenerator(result)
  2643. result_is_async_gen = inspect.isasyncgen(result)
  2644. if is_streaming:
  2645. if result_is_gen:
  2646. for r in result:
  2647. generator_result_callback(r)
  2648. elif result_is_async_gen:
  2649. async for r in result:
  2650. generator_result_callback(r)
  2651. elif is_http_request and not user_method_info.is_asgi_app:
  2652. # For the FastAPI codepath, the response has already been sent over
  2653. # ASGI, but for the vanilla deployment codepath we need to send it.
  2654. await self._send_user_result_over_asgi(result, asgi_args)
  2655. elif not is_http_request and not sync_gen_consumed:
  2656. # If a unary method is called with stream=True for anything EXCEPT
  2657. # an HTTP request, raise an error.
  2658. # HTTP requests are always streaming regardless of if the method
  2659. # returns a generator, because it's provided the result queue as its
  2660. # ASGI `send` interface to stream back results.
  2661. raise TypeError(
  2662. f"Called method '{user_method_info.name}' with "
  2663. "`handle.options(stream=True)` but it did not return a "
  2664. "generator."
  2665. )
  2666. else:
  2667. assert (
  2668. not is_http_request
  2669. ), "All HTTP requests go through the streaming codepath."
  2670. if result_is_gen or result_is_async_gen:
  2671. raise TypeError(
  2672. f"Method '{user_method_info.name}' returned a generator. "
  2673. "You must use `handle.options(stream=True)` to call "
  2674. "generators on a deployment."
  2675. )
  2676. return result
  2677. async def call_http_entrypoint(
  2678. self,
  2679. request_metadata: RequestMetadata,
  2680. status_code_callback: StatusCodeCallback,
  2681. scope: Scope,
  2682. receive: Receive,
  2683. ) -> Any:
  2684. result_queue = MessageQueue()
  2685. user_method_info = self.get_user_method_info(request_metadata.call_method)
  2686. if self._run_user_code_in_separate_thread:
  2687. # `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
  2688. # used to interact with the result queue from the user callable thread.
  2689. system_event_loop = asyncio.get_running_loop()
  2690. async def enqueue(item: Any):
  2691. system_event_loop.call_soon_threadsafe(result_queue.put_nowait, item)
  2692. call_future = self._call_http_entrypoint(
  2693. user_method_info, scope, receive, enqueue
  2694. )
  2695. else:
  2696. async def enqueue(item: Any):
  2697. result_queue.put_nowait(item)
  2698. call_future = asyncio.create_task(
  2699. self._call_http_entrypoint(user_method_info, scope, receive, enqueue)
  2700. )
  2701. first_message_peeked = False
  2702. async for messages in result_queue.fetch_messages_from_queue(call_future):
  2703. # HTTP (ASGI) messages are only consumed by the proxy so batch them
  2704. # and use vanilla pickle (we know it's safe because these messages
  2705. # only contain primitive Python types).
  2706. # Peek the first ASGI message to determine the status code.
  2707. if not first_message_peeked:
  2708. msg = messages[0]
  2709. first_message_peeked = True
  2710. if msg["type"] == "http.response.start":
  2711. # HTTP responses begin with exactly one
  2712. # "http.response.start" message containing the "status"
  2713. # field. Other response types like WebSockets may not.
  2714. status_code_callback(str(msg["status"]))
  2715. yield messages
  2716. @_run_user_code
  2717. async def _call_http_entrypoint(
  2718. self,
  2719. user_method_info: UserMethodInfo,
  2720. scope: Scope,
  2721. receive: Receive,
  2722. send: Send,
  2723. ) -> Any:
  2724. """Call an HTTP entrypoint.
  2725. `send` is used to communicate the results of streaming responses.
  2726. Raises any exception raised by the user code so it can be propagated as a
  2727. `RayTaskError`.
  2728. """
  2729. self._raise_if_not_initialized("_call_http_entrypoint")
  2730. if self._is_enabled_for_debug:
  2731. logger.debug(
  2732. f"Started executing request to method '{user_method_info.name}'.",
  2733. extra={"log_to_stderr": False, "serve_access_log": True},
  2734. )
  2735. if user_method_info.is_asgi_app:
  2736. request_args = (scope, receive, send)
  2737. elif not user_method_info.takes_any_args:
  2738. # Edge case to support empty HTTP handlers: don't pass the Request
  2739. # argument if the callable has no parameters.
  2740. request_args = tuple()
  2741. else:
  2742. # Non-FastAPI HTTP handlers take only the starlette `Request`.
  2743. request_args = (starlette.requests.Request(scope, receive, send),)
  2744. receive_task = None
  2745. try:
  2746. if hasattr(receive, "fetch_until_disconnect"):
  2747. receive_task = asyncio.create_task(receive.fetch_until_disconnect())
  2748. result, sync_gen_consumed = await self._call_func_or_gen(
  2749. user_method_info.callable,
  2750. args=request_args,
  2751. kwargs={},
  2752. is_streaming=True,
  2753. generator_result_callback=send,
  2754. )
  2755. final_result = await self._handle_user_method_result(
  2756. result,
  2757. user_method_info,
  2758. is_streaming=True,
  2759. is_http_request=True,
  2760. sync_gen_consumed=sync_gen_consumed,
  2761. generator_result_callback=send,
  2762. asgi_args=ASGIArgs(scope, receive, send),
  2763. )
  2764. if receive_task is not None and not receive_task.done():
  2765. receive_task.cancel()
  2766. return final_result
  2767. except Exception as e:
  2768. if not user_method_info.is_asgi_app:
  2769. response = self.handle_exception(e)
  2770. await self._send_user_result_over_asgi(
  2771. response, ASGIArgs(scope, receive, send)
  2772. )
  2773. if receive_task is not None and not receive_task.done():
  2774. receive_task.cancel()
  2775. raise
  2776. except asyncio.CancelledError:
  2777. if receive_task is not None and not receive_task.done():
  2778. # Do NOT cancel the receive task if the request has been
  2779. # cancelled, but the call is a batched call. This is
  2780. # because we cannot guarantee cancelling the batched
  2781. # call, so in the case that the call continues executing
  2782. # we should continue fetching data from the client.
  2783. if not hasattr(user_method_info.callable, "set_max_batch_size"):
  2784. receive_task.cancel()
  2785. raise
  2786. async def call_user_generator(
  2787. self,
  2788. request_metadata: RequestMetadata,
  2789. request_args: Tuple[Any],
  2790. request_kwargs: Dict[str, Any],
  2791. ) -> AsyncGenerator[Any, None]:
  2792. """Calls a user method for a streaming call and yields its results.
  2793. The user method is called in an asyncio `Task` and places its results on a
  2794. `result_queue`. This method pulls and yields from the `result_queue`.
  2795. """
  2796. if not self._run_user_code_in_separate_thread:
  2797. gen = await self._call_user_generator(
  2798. request_metadata, request_args, request_kwargs
  2799. )
  2800. async for result in gen:
  2801. yield result
  2802. else:
  2803. result_queue = MessageQueue()
  2804. # `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
  2805. # used to interact with the result queue from the user callable thread.
  2806. system_event_loop = asyncio.get_running_loop()
  2807. def _enqueue_thread_safe(item: Any):
  2808. system_event_loop.call_soon_threadsafe(result_queue.put_nowait, item)
  2809. call_future = self._call_user_generator(
  2810. request_metadata,
  2811. request_args,
  2812. request_kwargs,
  2813. enqueue=_enqueue_thread_safe,
  2814. )
  2815. async for messages in result_queue.fetch_messages_from_queue(call_future):
  2816. for msg in messages:
  2817. yield msg
  2818. @_run_user_code
  2819. async def _call_user_generator(
  2820. self,
  2821. request_metadata: RequestMetadata,
  2822. request_args: Tuple[Any],
  2823. request_kwargs: Dict[str, Any],
  2824. *,
  2825. enqueue: Optional[Callable] = None,
  2826. ) -> Optional[AsyncGenerator[Any, None]]:
  2827. """Call a user generator.
  2828. The `generator_result_callback` is used to communicate the results of generator
  2829. methods.
  2830. Raises any exception raised by the user code so it can be propagated as a
  2831. `RayTaskError`.
  2832. """
  2833. self._raise_if_not_initialized("_call_user_generator")
  2834. request_args = request_args if request_args is not None else tuple()
  2835. request_kwargs = request_kwargs if request_kwargs is not None else dict()
  2836. user_method_info = self.get_user_method_info(request_metadata.call_method)
  2837. callable = user_method_info.callable
  2838. is_sync_method = (
  2839. inspect.isfunction(callable) or inspect.ismethod(callable)
  2840. ) and not (
  2841. inspect.iscoroutinefunction(callable)
  2842. or inspect.isasyncgenfunction(callable)
  2843. )
  2844. if self._is_enabled_for_debug:
  2845. logger.debug(
  2846. f"Started executing request to method '{user_method_info.name}'.",
  2847. extra={"log_to_stderr": False, "serve_access_log": True},
  2848. )
  2849. async def _call_generator_async() -> AsyncGenerator[Any, None]:
  2850. gen = callable(*request_args, **request_kwargs)
  2851. if inspect.iscoroutine(gen):
  2852. gen = await gen
  2853. if inspect.isgenerator(gen):
  2854. for result in gen:
  2855. yield result
  2856. elif inspect.isasyncgen(gen):
  2857. async for result in gen:
  2858. yield result
  2859. else:
  2860. raise TypeError(
  2861. f"Called method '{user_method_info.name}' with "
  2862. "`handle.options(stream=True)` but it did not return a generator."
  2863. )
  2864. def _call_generator_sync():
  2865. gen = callable(*request_args, **request_kwargs)
  2866. if inspect.isgenerator(gen):
  2867. for result in gen:
  2868. enqueue(result)
  2869. else:
  2870. raise TypeError(
  2871. f"Called method '{user_method_info.name}' with "
  2872. "`handle.options(stream=True)` but it did not return a generator."
  2873. )
  2874. if enqueue and is_sync_method and self._run_sync_methods_in_threadpool:
  2875. await to_thread.run_sync(_call_generator_sync)
  2876. elif enqueue:
  2877. async def gen_coro_wrapper():
  2878. async for result in _call_generator_async():
  2879. enqueue(result)
  2880. await gen_coro_wrapper()
  2881. else:
  2882. return _call_generator_async()
  2883. @_run_user_code
  2884. async def call_user_method(
  2885. self,
  2886. request_metadata: RequestMetadata,
  2887. request_args: Tuple[Any],
  2888. request_kwargs: Dict[str, Any],
  2889. ) -> Any:
  2890. """Call a (unary) user method.
  2891. Raises any exception raised by the user code so it can be propagated as a
  2892. `RayTaskError`.
  2893. """
  2894. self._raise_if_not_initialized("call_user_method")
  2895. if self._is_enabled_for_debug:
  2896. logger.debug(
  2897. f"Started executing request to method '{request_metadata.call_method}'.",
  2898. extra={"log_to_stderr": False, "serve_access_log": True},
  2899. )
  2900. user_method_info = self.get_user_method_info(request_metadata.call_method)
  2901. result, _ = await self._call_func_or_gen(
  2902. user_method_info.callable,
  2903. args=request_args,
  2904. kwargs=request_kwargs,
  2905. is_streaming=False,
  2906. )
  2907. if inspect.isgenerator(result) or inspect.isasyncgen(result):
  2908. raise TypeError(
  2909. f"Method '{user_method_info.name}' returned a generator. "
  2910. "You must use `handle.options(stream=True)` to call "
  2911. "generators on a deployment."
  2912. )
  2913. return result
  2914. def handle_exception(self, exc: Exception):
  2915. if isinstance(exc, self.service_unavailable_exceptions):
  2916. return starlette.responses.Response(exc.message, status_code=503)
  2917. else:
  2918. return starlette.responses.Response(
  2919. "Internal Server Error", status_code=500
  2920. )
  2921. @_run_user_code
  2922. async def call_destructor(self):
  2923. """Explicitly call the `__del__` method of the user callable.
  2924. Calling this multiple times has no effect; only the first call will
  2925. actually call the destructor.
  2926. """
  2927. if self._callable is None:
  2928. logger.debug(
  2929. "This replica has not yet started running user code. "
  2930. "Skipping __del__."
  2931. )
  2932. return
  2933. # Only run the destructor once. This is safe because there is no `await` between
  2934. # checking the flag here and flipping it to `True` below.
  2935. if self._destructor_called:
  2936. return
  2937. self._destructor_called = True
  2938. try:
  2939. if hasattr(self._callable, "__del__"):
  2940. # Make sure to accept `async def __del__(self)` as well.
  2941. await self._call_func_or_gen(
  2942. self._callable.__del__,
  2943. # Always run the destructor on the main user callable thread.
  2944. run_sync_methods_in_threadpool_override=False,
  2945. )
  2946. if hasattr(self._callable, "__serve_multiplex_wrapper"):
  2947. await getattr(self._callable, "__serve_multiplex_wrapper").shutdown()
  2948. except Exception as e:
  2949. logger.exception(f"Exception during graceful shutdown of replica: {e}")
  2950. finally:
  2951. if self._user_code_threadpool is not None:
  2952. self._user_code_threadpool.shutdown(wait=False)