long_poll.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. import asyncio
  2. import contextvars
  3. import logging
  4. import os
  5. import random
  6. import time
  7. from asyncio import sleep
  8. from asyncio.events import AbstractEventLoop
  9. from collections import defaultdict
  10. from collections.abc import Mapping
  11. from dataclasses import dataclass
  12. from enum import Enum, auto
  13. from typing import Any, Callable, DefaultDict, Dict, Optional, Set, Tuple, Union
  14. import ray
  15. from ray._common.utils import get_or_create_event_loop
  16. from ray.serve._private.constants import DEFAULT_LATENCY_BUCKET_MS, SERVE_LOGGER_NAME
  17. from ray.serve.generated.serve_pb2 import (
  18. DeploymentTargetInfo,
  19. EndpointInfo as EndpointInfoProto,
  20. EndpointSet,
  21. LongPollRequest,
  22. LongPollResult,
  23. UpdatedObject as UpdatedObjectProto,
  24. )
  25. from ray.util import metrics
  26. logger = logging.getLogger(SERVE_LOGGER_NAME)
  27. # Each LongPollClient will send requests to LongPollHost to poll changes
  28. # as blocking awaitable. This doesn't scale if we have many client instances
  29. # that will slow down, or even block controller actor's event loop if near
  30. # its max_concurrency limit. Therefore we timeout a polling request after
  31. # a few seconds and let each client retry on their end.
  32. # We randomly select a timeout within this range to avoid a "thundering herd"
  33. # when there are many clients subscribing at the same time.
  34. LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S = (
  35. float(os.environ.get("LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S_LOWER_BOUND", "30")),
  36. float(os.environ.get("LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S_UPPER_BOUND", "60")),
  37. )
  38. class LongPollNamespace(Enum):
  39. def __repr__(self):
  40. return f"{self.__class__.__name__}.{self.name}"
  41. DEPLOYMENT_TARGETS = auto()
  42. ROUTE_TABLE = auto()
  43. GLOBAL_LOGGING_CONFIG = auto()
  44. DEPLOYMENT_CONFIG = auto()
  45. @dataclass
  46. class UpdatedObject:
  47. object_snapshot: Any
  48. # The identifier for the object's version. There is not sequential relation
  49. # among different object's snapshot_ids.
  50. snapshot_id: int
  51. # Timestamp (in seconds since epoch) when notify_changed was called.
  52. # Used by clients to measure end-to-end propagation latency.
  53. notify_timestamp: float
  54. # Type signature for the update state callbacks. E.g.
  55. # async def update_state(updated_object: Any):
  56. # do_something(updated_object)
  57. UpdateStateCallable = Callable[[Any], None]
  58. KeyType = Union[str, LongPollNamespace, Tuple[LongPollNamespace, str]]
  59. class LongPollState(Enum):
  60. TIME_OUT = auto()
  61. class LongPollClient:
  62. """The asynchronous long polling client.
  63. Args:
  64. host_actor: handle to actor embedding LongPollHost.
  65. key_listeners: a dictionary mapping keys to
  66. callbacks to be called on state update for the corresponding keys.
  67. call_in_event_loop: an asyncio event loop
  68. to post the callback into.
  69. """
  70. def __init__(
  71. self,
  72. host_actor,
  73. key_listeners: Dict[KeyType, UpdateStateCallable],
  74. call_in_event_loop: AbstractEventLoop,
  75. ) -> None:
  76. # We used to allow this to be optional, but due to Ray Client issue
  77. # we now enforce all long poll client to post callback to event loop
  78. # See https://github.com/ray-project/ray/issues/20971
  79. assert call_in_event_loop is not None
  80. self.host_actor = host_actor
  81. self.key_listeners = key_listeners
  82. self.event_loop = call_in_event_loop
  83. self.snapshot_ids: Dict[KeyType, int] = {
  84. # The initial snapshot id for each key is < 0,
  85. # but real snapshot keys in the long poll host are always >= 0,
  86. # so this will always trigger an initial update.
  87. key: -1
  88. for key in self.key_listeners.keys()
  89. }
  90. self.is_running = True
  91. # Metric to track end-to-end latency from controller to client
  92. self.long_poll_latency_histogram = metrics.Histogram(
  93. "serve_long_poll_latency_ms",
  94. description=(
  95. "The time in milliseconds for updates to propagate from "
  96. "controller to clients."
  97. ),
  98. boundaries=DEFAULT_LATENCY_BUCKET_MS,
  99. tag_keys=("namespace",),
  100. )
  101. # NOTE(edoakes): we schedule the initial _poll_next call with an empty context
  102. # so that Ray will not recursively cancel the underlying `listen_for_change`
  103. # task. See: https://github.com/ray-project/ray/issues/52476.
  104. self.event_loop.call_soon_threadsafe(
  105. self._poll_next, context=contextvars.Context()
  106. )
  107. def stop(self) -> None:
  108. """Stop the long poll client after the next RPC returns."""
  109. self.is_running = False
  110. def add_key_listeners(
  111. self, key_listeners: Dict[KeyType, UpdateStateCallable]
  112. ) -> None:
  113. """Add more key listeners to the client.
  114. The new listeners will only be included in the *next* long poll request;
  115. the current request will continue with the existing listeners.
  116. If a key is already in the client, the new listener will replace the old one,
  117. but the snapshot ID will be preserved, so the new listener will only be called
  118. on the *next* update to that key.
  119. """
  120. # We need to run the underlying method in the same event loop that runs
  121. # the long poll loop, because we need to mutate the mapping of snapshot IDs,
  122. # which also needs to be serialized by the long poll's RPC to the
  123. # Serve Controller. If those happened concurrently in different threads,
  124. # we could get a `RuntimeError: dictionary changed size during iteration`.
  125. # See https://github.com/ray-project/ray/pull/52793 for more details.
  126. self.event_loop.call_soon_threadsafe(self._add_key_listeners, key_listeners)
  127. def _add_key_listeners(
  128. self, key_listeners: Dict[KeyType, UpdateStateCallable]
  129. ) -> None:
  130. """Inner method that actually adds the key listeners, to be called
  131. via call_soon_threadsafe for thread safety.
  132. """
  133. # Only initialize snapshot ids for *new* keys.
  134. self.snapshot_ids.update(
  135. {key: -1 for key in key_listeners.keys() if key not in self.key_listeners}
  136. )
  137. self.key_listeners.update(key_listeners)
  138. def _on_callback_completed(self, trigger_at: int):
  139. """Called after a single callback is completed.
  140. When the total number of callback completed equals to trigger_at,
  141. _poll_next() will be called. This is designed to make sure we only
  142. _poll_next() after all the state callbacks completed. This is a
  143. way to serialize the callback invocations between object versions.
  144. """
  145. self._callbacks_processed_count += 1
  146. if self._callbacks_processed_count == trigger_at:
  147. self._poll_next()
  148. def _poll_next(self):
  149. """Poll the update. The callback is expected to scheduler another
  150. _poll_next call.
  151. """
  152. if not self.is_running:
  153. return
  154. self._callbacks_processed_count = 0
  155. self._current_ref = self.host_actor.listen_for_change.remote(self.snapshot_ids)
  156. self._current_ref._on_completed(lambda update: self._process_update(update))
  157. def _schedule_to_event_loop(self, callback):
  158. # Schedule the next iteration only if the loop is running.
  159. # The event loop might not be running if users used a cached
  160. # version across loops.
  161. if self.event_loop.is_running():
  162. self.event_loop.call_soon_threadsafe(callback)
  163. else:
  164. logger.error("The event loop is closed, shutting down long poll client.")
  165. self.is_running = False
  166. def _process_update(self, updates: Dict[str, UpdatedObject]):
  167. if isinstance(updates, (ray.exceptions.RayActorError)):
  168. # This can happen during shutdown where the controller is
  169. # intentionally killed, the client should just gracefully
  170. # exit.
  171. logger.debug("LongPollClient failed to connect to host. Shutting down.")
  172. self.is_running = False
  173. return
  174. if isinstance(updates, ConnectionError):
  175. logger.warning("LongPollClient connection failed, shutting down.")
  176. self.is_running = False
  177. return
  178. if isinstance(updates, (ray.exceptions.RayTaskError)):
  179. # Some error happened in the controller. It could be a bug or
  180. # some undesired state.
  181. logger.error("LongPollHost errored\n" + updates.traceback_str)
  182. # We must call this in event loop so it works in Ray Client.
  183. # See https://github.com/ray-project/ray/issues/20971
  184. self._schedule_to_event_loop(self._poll_next)
  185. return
  186. if updates == LongPollState.TIME_OUT:
  187. logger.debug("LongPollClient polling timed out. Retrying.")
  188. self._schedule_to_event_loop(self._poll_next)
  189. return
  190. logger.debug(
  191. f"LongPollClient {self} received updates for keys: "
  192. f"{list(updates.keys())}.",
  193. extra={"log_to_stderr": False},
  194. )
  195. if not updates: # no updates, no callbacks to run, just poll again
  196. self._schedule_to_event_loop(self._poll_next)
  197. # Record latency metrics for received updates
  198. receive_time = time.time()
  199. for key, update in updates.items():
  200. # Record end-to-end latency from controller to client
  201. latency_ms = (receive_time - update.notify_timestamp) * 1000
  202. self.long_poll_latency_histogram.observe(
  203. latency_ms, tags={"namespace": str(key)}
  204. )
  205. self.snapshot_ids[key] = update.snapshot_id
  206. callback = self.key_listeners[key]
  207. # Bind the parameters because closures are late-binding.
  208. # https://docs.python-guide.org/writing/gotchas/#late-binding-closures # noqa: E501
  209. def chained(callback=callback, arg=update.object_snapshot):
  210. callback(arg)
  211. self._on_callback_completed(trigger_at=len(updates))
  212. self._schedule_to_event_loop(chained)
  213. class LongPollHost:
  214. """The server side object that manages long pulling requests.
  215. The desired use case is to embed this in an Ray actor. Client will be
  216. expected to call actor.listen_for_change.remote(...). On the host side,
  217. you can call host.notify_changed({key: object}) to update the state and
  218. potentially notify whoever is polling for these values.
  219. Internally, we use snapshot_ids for each object to identify client with
  220. outdated object and immediately return the result. If the client has the
  221. up-to-date version, then the listen_for_change call will only return when
  222. the object is updated.
  223. """
  224. def __init__(
  225. self,
  226. listen_for_change_request_timeout_s: Tuple[
  227. int, int
  228. ] = LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S,
  229. ):
  230. # Map object_key -> int
  231. self.snapshot_ids: Dict[KeyType, int] = {}
  232. # Map object_key -> object
  233. self.object_snapshots: Dict[KeyType, Any] = {}
  234. # Map object_key -> set(asyncio.Event waiting for updates)
  235. self.notifier_events: DefaultDict[KeyType, Set[asyncio.Event]] = defaultdict(
  236. set
  237. )
  238. # Map object_key -> timestamp when notify_changed was called
  239. # Used to track latency for propagating updates to clients
  240. self._notify_timestamps: Dict[KeyType, float] = {}
  241. self._listen_for_change_request_timeout_s = listen_for_change_request_timeout_s
  242. self.transmission_counter = metrics.Counter(
  243. "serve_long_poll_host_transmission_counter",
  244. description="The number of times the long poll host transmits data.",
  245. tag_keys=("namespace_or_state",),
  246. )
  247. self.pending_clients_gauge = metrics.Gauge(
  248. "serve_long_poll_pending_clients",
  249. description=("The number of clients waiting for updates per namespace."),
  250. tag_keys=("namespace",),
  251. )
  252. def _get_num_notifier_events(self, key: Optional[KeyType] = None):
  253. """Used for testing."""
  254. if key is not None:
  255. return len(self.notifier_events[key])
  256. else:
  257. return sum(len(events) for events in self.notifier_events.values())
  258. def _count_send(
  259. self, timeout_or_data: Union[LongPollState, Dict[KeyType, UpdatedObject]]
  260. ):
  261. """Helper method that tracks the data sent by listen_for_change.
  262. Records number of times long poll host sends data in the
  263. ray_serve_long_poll_host_send_counter metric.
  264. """
  265. if isinstance(timeout_or_data, LongPollState):
  266. # The only LongPollState is TIME_OUT– the long poll
  267. # connection has timed out.
  268. self.transmission_counter.inc(
  269. value=1, tags={"namespace_or_state": "TIMEOUT"}
  270. )
  271. else:
  272. data = timeout_or_data
  273. for key in data.keys():
  274. self.transmission_counter.inc(
  275. value=1, tags={"namespace_or_state": str(key)}
  276. )
  277. async def listen_for_change(
  278. self,
  279. keys_to_snapshot_ids: Dict[KeyType, int],
  280. ) -> Union[LongPollState, Dict[KeyType, UpdatedObject]]:
  281. """Listen for changed objects.
  282. This method will return a dictionary of updated objects. It returns
  283. immediately if any of the snapshot_ids are outdated,
  284. otherwise it will block until there's an update.
  285. """
  286. # If there are no keys to listen for,
  287. # just wait for a short time to provide backpressure,
  288. # then return an empty update.
  289. if not keys_to_snapshot_ids:
  290. await sleep(1)
  291. updated_objects = {}
  292. self._count_send(updated_objects)
  293. return updated_objects
  294. # If there are any keys with outdated snapshot ids,
  295. # return their updated values immediately.
  296. updated_objects = {}
  297. for key, client_snapshot_id in keys_to_snapshot_ids.items():
  298. try:
  299. existing_id = self.snapshot_ids[key]
  300. except KeyError:
  301. # The caller may ask for keys that we don't know about (yet),
  302. # just ignore them.
  303. # This can happen when, for example,
  304. # a deployment handle is manually created for an app
  305. # that hasn't been deployed yet (by bypassing the safety checks).
  306. continue
  307. if existing_id != client_snapshot_id:
  308. updated_objects[key] = UpdatedObject(
  309. self.object_snapshots[key],
  310. existing_id,
  311. self._notify_timestamps[key],
  312. )
  313. if len(updated_objects) > 0:
  314. self._count_send(updated_objects)
  315. return updated_objects
  316. # Otherwise, register asyncio events to be waited.
  317. async_task_to_events = {}
  318. async_task_to_watched_keys = {}
  319. for key in keys_to_snapshot_ids.keys():
  320. # Create a new asyncio event for this key.
  321. event = asyncio.Event()
  322. # Make sure future caller of notify_changed will unblock this
  323. # asyncio Event.
  324. self.notifier_events[key].add(event)
  325. # Update pending clients gauge for this key
  326. self.pending_clients_gauge.set(
  327. len(self.notifier_events[key]), tags={"namespace": str(key)}
  328. )
  329. task = get_or_create_event_loop().create_task(event.wait())
  330. async_task_to_events[task] = event
  331. async_task_to_watched_keys[task] = key
  332. done, not_done = await asyncio.wait(
  333. async_task_to_watched_keys.keys(),
  334. return_when=asyncio.FIRST_COMPLETED,
  335. timeout=random.uniform(*self._listen_for_change_request_timeout_s),
  336. )
  337. for task in not_done:
  338. task.cancel()
  339. try:
  340. event = async_task_to_events[task]
  341. key = async_task_to_watched_keys[task]
  342. self.notifier_events[key].remove(event)
  343. # Update pending clients gauge after removing
  344. self.pending_clients_gauge.set(
  345. len(self.notifier_events[key]), tags={"namespace": str(key)}
  346. )
  347. except KeyError:
  348. # Because we use `FIRST_COMPLETED` above, a task in `not_done` may
  349. # actually have had its event removed in `notify_changed`.
  350. pass
  351. if len(done) == 0:
  352. self._count_send(LongPollState.TIME_OUT)
  353. return LongPollState.TIME_OUT
  354. else:
  355. updated_objects = {}
  356. for task in done:
  357. updated_object_key = async_task_to_watched_keys[task]
  358. updated_objects[updated_object_key] = UpdatedObject(
  359. self.object_snapshots[updated_object_key],
  360. self.snapshot_ids[updated_object_key],
  361. self._notify_timestamps[updated_object_key],
  362. )
  363. self._count_send(updated_objects)
  364. return updated_objects
  365. async def listen_for_change_java(
  366. self,
  367. keys_to_snapshot_ids_bytes: bytes,
  368. ) -> bytes:
  369. """Listen for changed objects. only call by java proxy/router now.
  370. Args:
  371. keys_to_snapshot_ids_bytes (Dict[str, int]): the protobuf bytes of
  372. keys_to_snapshot_ids (Dict[str, int]).
  373. """
  374. request_proto = LongPollRequest.FromString(keys_to_snapshot_ids_bytes)
  375. keys_to_snapshot_ids = {
  376. self._parse_xlang_key(xlang_key): snapshot_id
  377. for xlang_key, snapshot_id in request_proto.keys_to_snapshot_ids.items()
  378. }
  379. keys_to_updated_objects = await self.listen_for_change(keys_to_snapshot_ids)
  380. return self._listen_result_to_proto_bytes(keys_to_updated_objects)
  381. def _parse_poll_namespace(self, name: str):
  382. if name == LongPollNamespace.ROUTE_TABLE.name:
  383. return LongPollNamespace.ROUTE_TABLE
  384. elif name == LongPollNamespace.DEPLOYMENT_TARGETS.name:
  385. return LongPollNamespace.DEPLOYMENT_TARGETS
  386. else:
  387. return name
  388. def _parse_xlang_key(self, xlang_key: str) -> KeyType:
  389. if xlang_key is None:
  390. raise ValueError("func _parse_xlang_key: xlang_key is None")
  391. if xlang_key.startswith("(") and xlang_key.endswith(")"):
  392. fields = xlang_key[1:-1].split(",")
  393. if len(fields) == 2:
  394. enum_field = self._parse_poll_namespace(fields[0].strip())
  395. if isinstance(enum_field, LongPollNamespace):
  396. return enum_field, fields[1].strip()
  397. else:
  398. return self._parse_poll_namespace(xlang_key)
  399. raise ValueError("can not parse key type from xlang_key {}".format(xlang_key))
  400. def _build_xlang_key(self, key: KeyType) -> str:
  401. if isinstance(key, tuple):
  402. return "(" + key[0].name + "," + key[1] + ")"
  403. elif isinstance(key, LongPollNamespace):
  404. return key.name
  405. else:
  406. return key
  407. def _object_snapshot_to_proto_bytes(
  408. self, key: KeyType, object_snapshot: Any
  409. ) -> bytes:
  410. if key == LongPollNamespace.ROUTE_TABLE:
  411. # object_snapshot is Dict[DeploymentID, EndpointInfo]
  412. # NOTE(zcin): the endpoint dictionary broadcasted to Java
  413. # HTTP proxies should use string as key because Java does
  414. # not yet support 2.x or applications
  415. xlang_endpoints = {
  416. str(endpoint_tag): EndpointInfoProto(route=endpoint_info.route)
  417. for endpoint_tag, endpoint_info in object_snapshot.items()
  418. }
  419. return EndpointSet(endpoints=xlang_endpoints).SerializeToString()
  420. elif isinstance(key, tuple) and key[0] == LongPollNamespace.DEPLOYMENT_TARGETS:
  421. # object_snapshot.running_replicas is List[RunningReplicaInfo]
  422. actor_name_list = [
  423. replica_info.replica_id.to_full_id_str()
  424. for replica_info in object_snapshot.running_replicas
  425. ]
  426. return DeploymentTargetInfo(
  427. replica_names=actor_name_list,
  428. is_available=object_snapshot.is_available,
  429. ).SerializeToString()
  430. else:
  431. return str.encode(str(object_snapshot))
  432. def _listen_result_to_proto_bytes(
  433. self, keys_to_updated_objects: Dict[KeyType, UpdatedObject]
  434. ) -> bytes:
  435. xlang_keys_to_updated_objects = {
  436. self._build_xlang_key(key): UpdatedObjectProto(
  437. snapshot_id=updated_object.snapshot_id,
  438. object_snapshot=self._object_snapshot_to_proto_bytes(
  439. key, updated_object.object_snapshot
  440. ),
  441. )
  442. for key, updated_object in keys_to_updated_objects.items()
  443. }
  444. data = {
  445. "updated_objects": xlang_keys_to_updated_objects,
  446. }
  447. proto = LongPollResult(**data)
  448. return proto.SerializeToString()
  449. def notify_changed(self, updates: Mapping[KeyType, Any]) -> None:
  450. """
  451. Update the current snapshot of some objects
  452. and notify any long poll clients.
  453. """
  454. notify_time = time.time()
  455. for object_key, updated_object in updates.items():
  456. try:
  457. self.snapshot_ids[object_key] += 1
  458. except KeyError:
  459. # Initial snapshot id must be >= 0, so that the long poll client
  460. # can send a negative initial snapshot id to get a fast update.
  461. # They should also be randomized; see
  462. # https://github.com/ray-project/ray/pull/45881#discussion_r1645243485
  463. self.snapshot_ids[object_key] = random.randint(0, 1_000_000)
  464. self.object_snapshots[object_key] = updated_object
  465. # Record timestamp for latency tracking
  466. self._notify_timestamps[object_key] = notify_time
  467. logger.debug(f"LongPollHost: Notify change for key {object_key}.")
  468. events_to_notify = self.notifier_events.pop(object_key, set())
  469. if events_to_notify:
  470. # Update pending clients gauge (now 0 for this key since we popped all)
  471. self.pending_clients_gauge.set(0, tags={"namespace": str(object_key)})
  472. for event in events_to_notify:
  473. event.set()