context.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. """
  2. This file stores global state for a Serve application. Deployment replicas
  3. can use this state to access metadata or the Serve controller.
  4. """
  5. import asyncio
  6. import contextvars
  7. import logging
  8. from collections import defaultdict
  9. from dataclasses import dataclass
  10. from typing import Callable, Dict, List, Optional
  11. import ray
  12. from ray.exceptions import RayActorError
  13. from ray.serve._private.client import ServeControllerClient
  14. from ray.serve._private.common import DeploymentID, ReplicaID
  15. from ray.serve._private.config import DeploymentConfig
  16. from ray.serve._private.constants import (
  17. SERVE_CONTROLLER_NAME,
  18. SERVE_LOGGER_NAME,
  19. SERVE_NAMESPACE,
  20. )
  21. from ray.serve._private.replica_result import ReplicaResult
  22. from ray.serve.exceptions import RayServeException
  23. from ray.serve.grpc_util import RayServegRPCContext
  24. from ray.serve.schema import ReplicaRank
  25. from ray.util.annotations import DeveloperAPI
  26. logger = logging.getLogger(SERVE_LOGGER_NAME)
  27. _INTERNAL_REPLICA_CONTEXT: "ReplicaContext" = None
  28. _global_client: ServeControllerClient = None
  29. @DeveloperAPI
  30. @dataclass
  31. class ReplicaContext:
  32. """Stores runtime context info for replicas.
  33. Fields:
  34. - app_name: name of the application the replica is a part of.
  35. - deployment: name of the deployment the replica is a part of.
  36. - replica_tag: unique ID for the replica.
  37. - servable_object: instance of the user class/function this replica is running.
  38. - rank: the rank of the replica.
  39. - world_size: the number of replicas in the deployment.
  40. """
  41. replica_id: ReplicaID
  42. servable_object: Callable
  43. _deployment_config: DeploymentConfig
  44. rank: ReplicaRank
  45. world_size: int
  46. _handle_registration_callback: Optional[Callable[[DeploymentID], None]] = None
  47. @property
  48. def app_name(self) -> str:
  49. return self.replica_id.deployment_id.app_name
  50. @property
  51. def deployment(self) -> str:
  52. return self.replica_id.deployment_id.name
  53. @property
  54. def replica_tag(self) -> str:
  55. return self.replica_id.unique_id
  56. def _get_global_client(
  57. _health_check_controller: bool = False, raise_if_no_controller_running: bool = True
  58. ) -> Optional[ServeControllerClient]:
  59. """Gets the global client, which stores the controller's handle.
  60. Args:
  61. _health_check_controller: If True, run a health check on the
  62. cached controller if it exists. If the check fails, try reconnecting
  63. to the controller.
  64. raise_if_no_controller_running: Whether to raise an exception if
  65. there is no currently running Serve controller.
  66. Returns:
  67. ServeControllerClient to the running Serve controller. If there
  68. is no running controller and raise_if_no_controller_running is
  69. set to False, returns None.
  70. Raises:
  71. RayServeException: If there is no running Serve controller actor
  72. and raise_if_no_controller_running is set to True.
  73. """
  74. try:
  75. if _global_client is not None:
  76. if _health_check_controller:
  77. ray.get(_global_client._controller.check_alive.remote())
  78. return _global_client
  79. except RayActorError:
  80. logger.info("The cached controller has died. Reconnecting.")
  81. _set_global_client(None)
  82. return _connect(raise_if_no_controller_running)
  83. def _set_global_client(client):
  84. global _global_client
  85. _global_client = client
  86. def _get_internal_replica_context():
  87. return _INTERNAL_REPLICA_CONTEXT
  88. def _set_internal_replica_context(
  89. *,
  90. replica_id: ReplicaID,
  91. servable_object: Callable,
  92. _deployment_config: DeploymentConfig,
  93. rank: ReplicaRank,
  94. world_size: int,
  95. handle_registration_callback: Optional[Callable[[str, str], None]] = None,
  96. ):
  97. global _INTERNAL_REPLICA_CONTEXT
  98. _INTERNAL_REPLICA_CONTEXT = ReplicaContext(
  99. replica_id=replica_id,
  100. servable_object=servable_object,
  101. _deployment_config=_deployment_config,
  102. rank=rank,
  103. world_size=world_size,
  104. _handle_registration_callback=handle_registration_callback,
  105. )
  106. def _connect(raise_if_no_controller_running: bool = True) -> ServeControllerClient:
  107. """Connect to an existing Serve application on this Ray cluster.
  108. If called from within a replica, this will connect to the same Serve
  109. app that the replica is running in.
  110. Returns:
  111. ServeControllerClient that encapsulates a Ray actor handle to the
  112. existing Serve application's Serve Controller. None if there is
  113. no running Serve controller actor and raise_if_no_controller_running
  114. is set to False.
  115. Raises:
  116. RayServeException: If there is no running Serve controller actor
  117. and raise_if_no_controller_running is set to True.
  118. """
  119. # Initialize ray if needed.
  120. ray._private.worker.global_worker._filter_logs_by_job = False
  121. if not ray.is_initialized():
  122. ray.init(namespace=SERVE_NAMESPACE)
  123. # Try to get serve controller if it exists
  124. try:
  125. controller = ray.get_actor(SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE)
  126. except ValueError:
  127. if raise_if_no_controller_running:
  128. raise RayServeException(
  129. "There is no Serve instance running on this Ray cluster."
  130. )
  131. return
  132. client = ServeControllerClient(
  133. controller,
  134. )
  135. _set_global_client(client)
  136. return client
# Serve request context var which is used for storing the internal
# request context information.
# route: HTTP URL route path, e.g. for http://127.0.0.1:8000/app
#   the route is "/app". When you send requests by handle,
#   the route is empty.
# request_id: the request ID is generated by the HTTP proxy; the value
#   shouldn't be changed once the variable is set.
#   It can also come from the client and is used for logging.
# _internal_request_id: the request ID generated by the proxy, used to track
#   the request objects in the system.
# Note:
# The request context is read-only to avoid potential
# async task conflicts when using it concurrently.
  150. @dataclass(frozen=True)
  151. class _RequestContext:
  152. route: str = ""
  153. request_id: str = ""
  154. _internal_request_id: str = ""
  155. app_name: str = ""
  156. multiplexed_model_id: str = ""
  157. grpc_context: Optional[RayServegRPCContext] = None
  158. is_http_request: bool = False
  159. cancel_on_parent_request_cancel: bool = False
  160. # Ray tracing context for this request (if tracing is enabled)
  161. # This is extracted from _ray_trace_ctx kwarg at the replica entry point
  162. # Advanced users can access this to propagate tracing to external systems
  163. _ray_trace_ctx: Optional[dict] = None
  164. _serve_request_context = contextvars.ContextVar(
  165. "Serve internal request context variable", default=None
  166. )
  167. _serve_batch_request_context = contextvars.ContextVar(
  168. "Serve internal batching request context variable", default=None
  169. )
  170. def _get_serve_request_context():
  171. """Get the current request context.
  172. Returns:
  173. The current request context
  174. """
  175. if _serve_request_context.get() is None:
  176. _serve_request_context.set(_RequestContext())
  177. return _serve_request_context.get()
  178. def _get_serve_batch_request_context():
  179. """Get the list of request contexts for the current batch."""
  180. if _serve_batch_request_context.get() is None:
  181. _serve_batch_request_context.set([])
  182. return _serve_batch_request_context.get()
  183. def _set_request_context(
  184. route: str = "",
  185. request_id: str = "",
  186. _internal_request_id: str = "",
  187. app_name: str = "",
  188. multiplexed_model_id: str = "",
  189. ):
  190. """Set the request context. If the value is not set,
  191. the current context value will be used."""
  192. current_request_context = _get_serve_request_context()
  193. _serve_request_context.set(
  194. _RequestContext(
  195. route=route or current_request_context.route,
  196. request_id=request_id or current_request_context.request_id,
  197. _internal_request_id=_internal_request_id
  198. or current_request_context._internal_request_id,
  199. app_name=app_name or current_request_context.app_name,
  200. multiplexed_model_id=multiplexed_model_id
  201. or current_request_context.multiplexed_model_id,
  202. )
  203. )
  204. def _unset_request_context():
  205. """Unset the request context."""
  206. _serve_request_context.set(_RequestContext())
  207. def _set_batch_request_context(request_contexts: List[_RequestContext]):
  208. """Add the request context to the batch request context."""
  209. _serve_batch_request_context.set(request_contexts)
  210. # `_requests_pending_assignment` is a map from request ID to a
  211. # dictionary of asyncio tasks.
  212. # The request ID points to an ongoing request that is executing on the
  213. # current replica, and the asyncio tasks are ongoing tasks started on
  214. # the router to assign child requests to downstream replicas.
  215. # A dictionary is used over a set to track the asyncio tasks for more
  216. # efficient addition and deletion time complexity. A uniquely generated
  217. # `response_id` is used to identify each task.
  218. _requests_pending_assignment: Dict[str, Dict[str, asyncio.Task]] = defaultdict(dict)
  219. # Note that the functions below that manipulate
  220. # `_requests_pending_assignment` are NOT thread-safe. They are only
  221. # expected to be called from the same thread/asyncio event-loop.
  222. def _get_requests_pending_assignment(parent_request_id: str) -> Dict[str, asyncio.Task]:
  223. if parent_request_id in _requests_pending_assignment:
  224. return _requests_pending_assignment[parent_request_id]
  225. return {}
  226. def _add_request_pending_assignment(parent_request_id: str, response_id: str, task):
  227. # NOTE: `parent_request_id` is the `internal_request_id` corresponding
  228. # to an ongoing Serve request, so it is always non-empty.
  229. _requests_pending_assignment[parent_request_id][response_id] = task
  230. def _remove_request_pending_assignment(parent_request_id: str, response_id: str):
  231. if response_id in _requests_pending_assignment[parent_request_id]:
  232. del _requests_pending_assignment[parent_request_id][response_id]
  233. if len(_requests_pending_assignment[parent_request_id]) == 0:
  234. del _requests_pending_assignment[parent_request_id]
  235. # `_in_flight_requests` is a map from request ID to a dictionary of replica results.
  236. # The request ID points to an ongoing Serve request, and the replica results are
  237. # in-flight child requests that have been assigned to a downstream replica.
  238. # A dictionary is used over a set to track the replica results for more
  239. # efficient addition and deletion time complexity. A uniquely generated
  240. # `response_id` is used to identify each replica result.
  241. _in_flight_requests: Dict[str, Dict[str, ReplicaResult]] = defaultdict(dict)
  242. # Note that the functions below that manipulate `_in_flight_requests`
  243. # are NOT thread-safe. They are only expected to be called from the
  244. # same thread/asyncio event-loop.
  245. def _get_in_flight_requests(parent_request_id):
  246. if parent_request_id in _in_flight_requests:
  247. return _in_flight_requests[parent_request_id]
  248. return {}
  249. def _add_in_flight_request(parent_request_id, response_id, replica_result):
  250. _in_flight_requests[parent_request_id][response_id] = replica_result
  251. def _remove_in_flight_request(parent_request_id, response_id):
  252. if response_id in _in_flight_requests[parent_request_id]:
  253. del _in_flight_requests[parent_request_id][response_id]
  254. if len(_in_flight_requests[parent_request_id]) == 0:
  255. del _in_flight_requests[parent_request_id]