# metrics_utils.py (17 KB)

  1. import asyncio
  2. import bisect
  3. import heapq
  4. import logging
  5. import statistics
  6. from collections import defaultdict
  7. from dataclasses import dataclass
  8. from itertools import chain
  9. from typing import (
  10. Awaitable,
  11. Callable,
  12. DefaultDict,
  13. Dict,
  14. Hashable,
  15. Iterable,
  16. List,
  17. Optional,
  18. Tuple,
  19. Union,
  20. )
  21. from ray.serve._private.common import TimeSeries, TimeStampedValue
  22. from ray.serve._private.constants import (
  23. METRICS_PUSHER_GRACEFUL_SHUTDOWN_TIMEOUT_S,
  24. SERVE_LOGGER_NAME,
  25. )
  26. from ray.serve.config import AggregationFunction
# Presumably the metrics key used for queued-request counts; it is not
# referenced anywhere else in this file -- confirm against callers.
QUEUED_REQUESTS_KEY = "queued"
# Module-level logger, named by the SERVE_LOGGER_NAME constant.
logger = logging.getLogger(SERVE_LOGGER_NAME)
  29. @dataclass
  30. class _MetricsTask:
  31. task_func: Union[Callable, Callable[[], Awaitable]]
  32. interval_s: float
  33. class MetricsPusher:
  34. """Periodically runs registered asyncio tasks."""
  35. def __init__(
  36. self,
  37. *,
  38. async_sleep: Optional[Callable[[int], None]] = None,
  39. ):
  40. self._async_sleep = async_sleep or asyncio.sleep
  41. self._tasks: Dict[str, _MetricsTask] = dict()
  42. self._async_tasks: Dict[str, asyncio.Task] = dict()
  43. # The event needs to be lazily initialized because this class may be constructed
  44. # on the main thread but its methods called on a separate asyncio loop.
  45. self._stop_event: Optional[asyncio.Event] = None
  46. @property
  47. def stop_event(self) -> asyncio.Event:
  48. if self._stop_event is None:
  49. self._stop_event = asyncio.Event()
  50. return self._stop_event
  51. def start(self):
  52. self.stop_event.clear()
  53. async def metrics_task(self, name: str):
  54. """Periodically runs `task_func` every `interval_s` until `stop_event` is set.
  55. If `task_func` raises an error, an exception will be logged.
  56. Supports both sync and async task functions.
  57. """
  58. wait_for_stop_event = asyncio.create_task(self.stop_event.wait())
  59. while True:
  60. if wait_for_stop_event.done():
  61. return
  62. try:
  63. task_func = self._tasks[name].task_func
  64. # Check if the function is a coroutine function
  65. if asyncio.iscoroutinefunction(task_func):
  66. await task_func()
  67. else:
  68. task_func()
  69. except Exception as e:
  70. logger.exception(f"Failed to run metrics task '{name}': {e}")
  71. sleep_task = asyncio.create_task(
  72. self._async_sleep(self._tasks[name].interval_s)
  73. )
  74. await asyncio.wait(
  75. [sleep_task, wait_for_stop_event],
  76. return_when=asyncio.FIRST_COMPLETED,
  77. )
  78. if not sleep_task.done():
  79. sleep_task.cancel()
  80. def register_or_update_task(
  81. self,
  82. name: str,
  83. task_func: Union[Callable, Callable[[], Awaitable]],
  84. interval_s: int,
  85. ) -> None:
  86. """Register a sync or async task under the provided name, or update it.
  87. This method is idempotent - if a task is already registered with
  88. the specified name, it will update it with the most recent info.
  89. Args:
  90. name: Unique name for the task.
  91. task_func: Either a sync function or async function (coroutine function).
  92. interval_s: Interval in seconds between task executions.
  93. """
  94. self._tasks[name] = _MetricsTask(task_func, interval_s)
  95. if name not in self._async_tasks or self._async_tasks[name].done():
  96. self._async_tasks[name] = asyncio.create_task(self.metrics_task(name))
  97. def stop_tasks(self):
  98. self.stop_event.set()
  99. self._tasks.clear()
  100. self._async_tasks.clear()
  101. async def graceful_shutdown(self):
  102. """Shutdown metrics pusher gracefully.
  103. This method will ensure idempotency of shutdown call.
  104. """
  105. self.stop_event.set()
  106. if self._async_tasks:
  107. await asyncio.wait(
  108. list(self._async_tasks.values()),
  109. timeout=METRICS_PUSHER_GRACEFUL_SHUTDOWN_TIMEOUT_S,
  110. )
  111. self._tasks.clear()
  112. self._async_tasks.clear()
  113. class InMemoryMetricsStore:
  114. """A very simple, in memory time series database"""
  115. def __init__(self):
  116. self.data: DefaultDict[Hashable, TimeSeries] = defaultdict(list)
  117. def add_metrics_point(self, data_points: Dict[Hashable, float], timestamp: float):
  118. """Push new data points to the store.
  119. Args:
  120. data_points: dictionary containing the metrics values. The
  121. key should uniquely identify this time series
  122. and to be used to perform aggregation.
  123. timestamp: the unix epoch timestamp the metrics are
  124. collected at.
  125. """
  126. for name, value in data_points.items():
  127. # Using in-sort to insert while maintaining sorted ordering.
  128. bisect.insort(a=self.data[name], x=TimeStampedValue(timestamp, value))
  129. def prune_keys_and_compact_data(self, start_timestamp_s: float):
  130. """Prune keys and compact data that are outdated.
  131. For keys that haven't had new data recorded after the timestamp,
  132. remove them from the database.
  133. For keys that have, compact the datapoints that were recorded
  134. before the timestamp.
  135. """
  136. for key, datapoints in list(self.data.items()):
  137. if len(datapoints) == 0 or datapoints[-1].timestamp < start_timestamp_s:
  138. del self.data[key]
  139. else:
  140. self.data[key] = self._get_datapoints(key, start_timestamp_s)
  141. def _get_datapoints(
  142. self, key: Hashable, window_start_timestamp_s: float
  143. ) -> TimeSeries:
  144. """Get all data points given key after window_start_timestamp_s"""
  145. datapoints = self.data[key]
  146. idx = bisect.bisect(
  147. a=datapoints,
  148. x=TimeStampedValue(
  149. timestamp=window_start_timestamp_s, value=0 # dummy value
  150. ),
  151. )
  152. return datapoints[idx:]
  153. def _aggregate_reduce(
  154. self,
  155. keys: Iterable[Hashable],
  156. aggregate_fn: Callable[[Iterable[float]], float],
  157. ) -> Tuple[Optional[float], int]:
  158. """Reduce the entire set of timeseries values across the specified keys.
  159. Args:
  160. keys: Iterable of keys to aggregate across.
  161. aggregate_fn: Function to apply across all float values, e.g., sum, max.
  162. Returns:
  163. A tuple of (float, int) where the first element is the aggregated value
  164. and the second element is the number of valid keys used.
  165. Returns (None, 0) if no valid keys have data.
  166. Example:
  167. Suppose the store contains:
  168. >>> store = InMemoryMetricsStore()
  169. >>> store.data.update({
  170. ... "a": [TimeStampedValue(0, 1.0), TimeStampedValue(1, 2.0)],
  171. ... "b": [],
  172. ... "c": [TimeStampedValue(0, 10.0)],
  173. ... })
  174. Using sum across keys:
  175. >>> store._aggregate_reduce(keys=["a", "b", "c"], aggregate_fn=sum)
  176. (13.0, 2)
  177. Here:
  178. - The aggregated value is 1.0 + 2.0 + 10.0 = 13.0
  179. - Only keys "a" and "c" contribute values, so report_count = 2
  180. """
  181. valid_key_count = 0
  182. def _values_generator():
  183. """Generator that yields values from valid keys without storing them all in memory."""
  184. nonlocal valid_key_count
  185. for key in keys:
  186. series = self.data.get(key, [])
  187. if not series:
  188. continue
  189. valid_key_count += 1
  190. for timestamp_value in series:
  191. yield timestamp_value.value
  192. # Create the generator and check if it has any values
  193. values_gen = _values_generator()
  194. try:
  195. first_value = next(values_gen)
  196. except StopIteration:
  197. # No valid data found
  198. return None, 0
  199. # Apply aggregation to the generator (memory efficient)
  200. aggregated_result = aggregate_fn(chain([first_value], values_gen))
  201. return aggregated_result, valid_key_count
  202. def get_latest(
  203. self,
  204. key: Hashable,
  205. ) -> Optional[float]:
  206. """Get the latest value for a given key."""
  207. if not self.data.get(key, None):
  208. return None
  209. return self.data[key][-1].value
  210. def aggregate_sum(
  211. self,
  212. keys: Iterable[Hashable],
  213. ) -> Tuple[Optional[float], int]:
  214. """Sum the entire set of timeseries values across the specified keys.
  215. Args:
  216. keys: Iterable of keys to aggregate across.
  217. Returns:
  218. A tuple of (float, int) where the first element is the sum across
  219. all values found at `keys`, and the second is the number of valid
  220. keys used to compute the sum.
  221. Returns (None, 0) if no valid keys have data.
  222. """
  223. return self._aggregate_reduce(keys, sum)
  224. def aggregate_avg(
  225. self,
  226. keys: Iterable[Hashable],
  227. ) -> Tuple[Optional[float], int]:
  228. """Average the entire set of timeseries values across the specified keys.
  229. Args:
  230. keys: Iterable of keys to aggregate across.
  231. Returns:
  232. A tuple of (float, int) where the first element is the mean across
  233. all values found at `keys`, and the second is the number of valid
  234. keys used to compute the mean.
  235. Returns (None, 0) if no valid keys have data.
  236. """
  237. return self._aggregate_reduce(keys, statistics.mean)
  238. def timeseries_count(
  239. self,
  240. key: Hashable,
  241. ) -> int:
  242. """Count the number of values across all timeseries values at the specified keys."""
  243. series = self.data.get(key, [])
  244. if not series:
  245. return 0
  246. return len(series)
  247. def time_weighted_average(
  248. step_series: TimeSeries,
  249. window_start: Optional[float] = None,
  250. window_end: Optional[float] = None,
  251. last_window_s: float = 1.0,
  252. ) -> Optional[float]:
  253. """
  254. Compute time-weighted average of a step function over a time interval.
  255. Args:
  256. step_series: Step function as list of (timestamp, value) points, sorted by time.
  257. Values are right-continuous (constant until next change).
  258. window_start: Start of averaging window (inclusive). If None, uses the start of the series.
  259. window_end: End of averaging window (exclusive). If None, uses the end of the series.
  260. last_window_s: when window_end is None, uses the last_window_s to compute the end of the window.
  261. Returns:
  262. Time-weighted average over the interval, or None if no data overlaps.
  263. """
  264. if not step_series:
  265. return None
  266. # Handle None values by using full timeseries bounds
  267. if window_start is None:
  268. window_start = step_series[0].timestamp
  269. if window_end is None:
  270. # Use timestamp after the last point to include the final segment
  271. window_end = step_series[-1].timestamp + last_window_s
  272. if window_end <= window_start:
  273. return None
  274. total_weighted_value = 0.0
  275. total_duration = 0.0
  276. current_value = 0.0 # Default if no data before window_start
  277. current_time = window_start
  278. # Process each segment that overlaps with the window
  279. for point in step_series:
  280. if point.timestamp <= window_start:
  281. # Find the value at window_start (LOCF)
  282. current_value = point.value
  283. continue
  284. if point.timestamp >= window_end:
  285. break # Beyond our window
  286. # Add contribution of current segment
  287. segment_end = min(point.timestamp, window_end)
  288. duration = segment_end - current_time
  289. if duration > 0:
  290. total_weighted_value += current_value * duration
  291. total_duration += duration
  292. current_value = point.value
  293. current_time = segment_end
  294. # Add final segment if it extends to window_end
  295. if current_time < window_end:
  296. duration = window_end - current_time
  297. total_weighted_value += current_value * duration
  298. total_duration += duration
  299. return total_weighted_value / total_duration if total_duration > 0 else None
  300. def aggregate_timeseries(
  301. timeseries: TimeSeries,
  302. aggregation_function: AggregationFunction,
  303. last_window_s: float = 1.0,
  304. ) -> Optional[float]:
  305. """Aggregate the values in a timeseries using a specified function."""
  306. if aggregation_function == AggregationFunction.MEAN:
  307. return time_weighted_average(timeseries, last_window_s=last_window_s)
  308. elif aggregation_function == AggregationFunction.MAX:
  309. return max(ts.value for ts in timeseries) if timeseries else None
  310. elif aggregation_function == AggregationFunction.MIN:
  311. return min(ts.value for ts in timeseries) if timeseries else None
  312. else:
  313. raise ValueError(f"Invalid aggregation function: {aggregation_function}")
  314. def merge_instantaneous_total(
  315. replicas_timeseries: List[TimeSeries],
  316. ) -> TimeSeries:
  317. """
  318. Merge multiple gauge time series (right-continuous, LOCF) into an
  319. instantaneous total time series as a step function.
  320. This approach treats each replica's gauge as right-continuous, last-observation-
  321. carried-forward (LOCF), which matches gauge semantics. It produces an exact
  322. instantaneous total across replicas without bias from arbitrary windowing.
  323. Uses a k-way merge algorithm for O(n log k) complexity where k is the number
  324. of timeseries and n is the total number of events.
  325. Timestamps are rounded to 10ms precision (2 decimal places) and datapoints
  326. with the same rounded timestamp are combined, keeping the most recent value.
  327. Args:
  328. replicas_timeseries: List of time series, one per replica. Each time series
  329. is a list of TimeStampedValue objects sorted by timestamp.
  330. Returns:
  331. A list of TimeStampedValue representing the instantaneous total at event times.
  332. Between events, the total remains constant (step function). Timestamps are
  333. rounded to 10ms precision and duplicate timestamps are combined.
  334. """
  335. # Filter out empty timeseries
  336. active_series = [series for series in replicas_timeseries if series]
  337. if not active_series:
  338. return []
  339. if len(active_series) == 1:
  340. return active_series[0]
  341. # True k-way merge: heap maintains exactly k elements (one per series)
  342. # Each element is (timestamp, replica_id, iterator)
  343. merge_heap = []
  344. current_values = [0.0] * len(active_series) # Current value for each replica (LOCF)
  345. # Initialize heap with first element from each series
  346. for replica_idx, series in enumerate(active_series):
  347. if series: # Non-empty series
  348. iterator = iter(series)
  349. try:
  350. first_point = next(iterator)
  351. heapq.heappush(
  352. merge_heap,
  353. (first_point.timestamp, replica_idx, first_point.value, iterator),
  354. )
  355. except StopIteration:
  356. pass
  357. merged: TimeSeries = []
  358. running_total = 0.0
  359. while merge_heap:
  360. # Pop the earliest event (heap size stays ≤ k)
  361. timestamp, replica_idx, value, iterator = heapq.heappop(merge_heap)
  362. old_value = current_values[replica_idx]
  363. current_values[replica_idx] = value
  364. running_total += value - old_value
  365. # Try to get the next point from this replica's series and push it back
  366. try:
  367. next_point: TimeStampedValue = next(iterator)
  368. heapq.heappush(
  369. merge_heap,
  370. (next_point.timestamp, replica_idx, next_point.value, iterator),
  371. )
  372. except StopIteration:
  373. pass # This series is exhausted
  374. # Only add a point if the total actually changed
  375. if value != old_value: # Equivalent to new_total != old_total
  376. # Round timestamp to 10ms precision (2 decimal places)
  377. rounded_timestamp = round(timestamp, 2)
  378. # Check if we already have a point with this rounded timestamp
  379. # If so, update its value; otherwise, add a new point
  380. if merged and merged[-1].timestamp == rounded_timestamp:
  381. # Update the last point's value since timestamps match
  382. merged[-1] = TimeStampedValue(rounded_timestamp, running_total)
  383. else:
  384. # Add new point with rounded timestamp
  385. merged.append(TimeStampedValue(rounded_timestamp, running_total))
  386. return merged
  387. def merge_timeseries_dicts(
  388. *timeseries_dicts: DefaultDict[Hashable, TimeSeries],
  389. ) -> DefaultDict[Hashable, TimeSeries]:
  390. """
  391. Merge multiple time-series dictionaries using instantaneous merge approach.
  392. """
  393. merged: DefaultDict[Hashable, TimeSeries] = defaultdict(list)
  394. for ts_dict in timeseries_dicts:
  395. for key, ts in ts_dict.items():
  396. merged[key].append(ts)
  397. return {key: merge_instantaneous_total(ts_list) for key, ts_list in merged.items()}