batching.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009
  1. import asyncio
  2. import io
  3. import logging
  4. import time
  5. from collections import deque
  6. from dataclasses import dataclass
  7. from functools import wraps
  8. from inspect import isasyncgenfunction, iscoroutinefunction
  9. from typing import (
  10. Any,
  11. AsyncGenerator,
  12. Callable,
  13. Coroutine,
  14. Dict,
  15. Generic,
  16. Iterable,
  17. List,
  18. Literal,
  19. Optional,
  20. Protocol,
  21. Set,
  22. Tuple,
  23. TypeVar,
  24. overload,
  25. )
  26. from ray import serve
  27. from ray._common.signature import extract_signature, flatten_args, recover_args
  28. from ray._common.utils import get_or_create_event_loop
  29. from ray.serve._private.constants import (
  30. BATCH_EXECUTION_TIME_BUCKETS_MS,
  31. BATCH_SIZE_BUCKETS,
  32. BATCH_UTILIZATION_BUCKETS_PERCENT,
  33. BATCH_WAIT_TIME_BUCKETS_MS,
  34. SERVE_LOGGER_NAME,
  35. )
  36. from ray.serve._private.utils import extract_self_if_method_call
  37. from ray.serve.exceptions import RayServeException
  38. from ray.serve.metrics import Counter, Gauge, Histogram
  39. from ray.util.annotations import PublicAPI
  40. logger = logging.getLogger(SERVE_LOGGER_NAME)
  41. # The user can return these values in their streaming batch handler function to
  42. # indicate that a request is finished, so Serve can terminate the request.
  43. USER_CODE_STREAMING_SENTINELS = [StopIteration, StopAsyncIteration]
@dataclass
class _SingleRequest:
    """A single caller's invocation waiting in a ``_BatchQueue``."""

    # Bound instance when the batched function is a method, otherwise None;
    # the batch executor passes it as the first argument when it is not None.
    self_arg: Any
    # The caller's arguments in Ray's flattened [key, value, key, value, ...]
    # format (keys are None for positional arguments).
    flattened_args: List[Any]
    # Resolved with this caller's result (or exception) once its batch runs.
    future: asyncio.Future
    # The caller's request context; its multiplexed_model_id is used to split
    # batches, and the contexts of a batch are set as the batch request context.
    request_context: serve.context._RequestContext
@dataclass
class _GeneratorResult:
    """One streamed value delivered to a caller of a generator batch handler."""

    # The value produced for this caller at the current generator step.
    result: Any
    # Future that receives this caller's next _GeneratorResult, or is failed
    # with StopAsyncIteration once the caller's stream is finished.
    next_future: asyncio.Future
  54. @dataclass
  55. class _RuntimeSummaryStatistics:
  56. start_times: List[float]
  57. @property
  58. def min_start_time(self) -> Optional[float]:
  59. return min(self.start_times) if self.start_times else None
  60. @property
  61. def mean_start_time(self) -> Optional[float]:
  62. return (
  63. sum(self.start_times) / len(self.start_times) if self.start_times else None
  64. )
  65. @property
  66. def max_start_time(self) -> Optional[float]:
  67. return max(self.start_times) if self.start_times else None
  68. @property
  69. def num_requests(self) -> int:
  70. return len(self.start_times)
  71. def _batch_args_kwargs(
  72. list_of_flattened_args: List[List[Any]],
  73. ) -> Tuple[Tuple[Any], Dict[Any, Any]]:
  74. """Batch a list of flatten args and returns regular args and kwargs"""
  75. # Ray's flatten arg format is a list with alternating key and values
  76. # e.g. args=(1, 2), kwargs={"key": "val"} got turned into
  77. # [None, 1, None, 2, "key", "val"]
  78. arg_lengths = {len(args) for args in list_of_flattened_args}
  79. assert (
  80. len(arg_lengths) == 1
  81. ), "All batch requests should have the same number of parameters."
  82. arg_length = arg_lengths.pop()
  83. batched_flattened_args = []
  84. for idx in range(arg_length):
  85. if idx % 2 == 0:
  86. batched_flattened_args.append(list_of_flattened_args[0][idx])
  87. else:
  88. batched_flattened_args.append(
  89. [item[idx] for item in list_of_flattened_args]
  90. )
  91. return recover_args(batched_flattened_args)
  92. class _BatchQueue:
  93. def __init__(
  94. self,
  95. max_batch_size: int,
  96. batch_wait_timeout_s: float,
  97. max_concurrent_batches: int,
  98. handle_batch_func: Optional[Callable] = None,
  99. batch_size_fn: Optional[Callable[[List], int]] = None,
  100. ) -> None:
  101. """Async queue that accepts individual items and returns batches.
  102. Respects max_batch_size and batch_wait_timeout_s; a batch will be returned when
  103. max_batch_size elements are available or the timeout has passed since
  104. the previous get.
  105. If handle_batch_func is passed in, a background coroutine will run to
  106. poll from the queue and call handle_batch_func on the results.
  107. Cannot be pickled.
  108. Arguments:
  109. max_batch_size: max number of elements to return in a batch.
  110. batch_wait_timeout_s: time to wait before returning an incomplete
  111. batch.
  112. max_concurrent_batches: max number of batches to run concurrently.
  113. handle_batch_func(Optional[Callable]): callback to run in the
  114. background to handle batches if provided.
  115. batch_size_fn(Optional[Callable[[List], int]]): optional function to
  116. compute the effective batch size. If None, uses len(batch).
  117. The function takes a list of requests and returns an integer
  118. representing the batch size. This is useful for batching based
  119. on custom metrics such as total nodes in graphs, total tokens
  120. in sequences, etc.
  121. """
  122. self.queue: asyncio.Queue[_SingleRequest] = asyncio.Queue()
  123. self.max_batch_size = max_batch_size
  124. self.batch_wait_timeout_s = batch_wait_timeout_s
  125. self.max_concurrent_batches = max_concurrent_batches
  126. self.batch_size_fn = batch_size_fn
  127. self.semaphore = asyncio.Semaphore(max_concurrent_batches)
  128. self.requests_available_event = asyncio.Event()
  129. self.tasks: Set[asyncio.Task] = set()
  130. # Used for observability.
  131. self.curr_iteration_start_times: Dict[asyncio.Task, float] = {}
  132. # Initialize batching metrics.
  133. self._batch_wait_time_histogram = Histogram(
  134. "serve_batch_wait_time_ms",
  135. description="Time requests waited for batch to fill (in milliseconds).",
  136. boundaries=BATCH_WAIT_TIME_BUCKETS_MS,
  137. tag_keys=("function_name",),
  138. )
  139. self._batch_execution_time_histogram = Histogram(
  140. "serve_batch_execution_time_ms",
  141. description="Time to execute the batch function (in milliseconds).",
  142. boundaries=BATCH_EXECUTION_TIME_BUCKETS_MS,
  143. tag_keys=("function_name",),
  144. )
  145. self._batch_queue_length_gauge = Gauge(
  146. "serve_batch_queue_length",
  147. description="Number of requests waiting in the batch queue.",
  148. tag_keys=("function_name",),
  149. )
  150. self._batch_utilization_histogram = Histogram(
  151. "serve_batch_utilization_percent",
  152. description="Batch utilization as percentage (actual_batch_size / max_batch_size * 100).",
  153. boundaries=BATCH_UTILIZATION_BUCKETS_PERCENT,
  154. tag_keys=("function_name",),
  155. )
  156. self._batch_size_histogram = Histogram(
  157. "serve_actual_batch_size",
  158. description="The actual number of requests in each batch.",
  159. boundaries=BATCH_SIZE_BUCKETS,
  160. tag_keys=("function_name",),
  161. )
  162. self._batches_processed_counter = Counter(
  163. "serve_batches_processed",
  164. description="Counter of batches executed.",
  165. tag_keys=("function_name",),
  166. )
  167. self._function_name = (
  168. handle_batch_func.__name__ if handle_batch_func is not None else "unknown"
  169. )
  170. self._handle_batch_task = None
  171. self._loop = get_or_create_event_loop()
  172. if handle_batch_func is not None:
  173. self._handle_batch_task = self._loop.create_task(
  174. self._process_batches(handle_batch_func)
  175. )
  176. self._warn_if_max_batch_size_exceeds_max_ongoing_requests()
  177. def _warn_if_max_batch_size_exceeds_max_ongoing_requests(self):
  178. """Helper to check whether the max_batch_size is bounded.
  179. Log a warning to configure `max_ongoing_requests` if it's bounded.
  180. """
  181. max_ongoing_requests = (
  182. serve.get_replica_context()._deployment_config.max_ongoing_requests
  183. )
  184. if max_ongoing_requests < self.max_batch_size * self.max_concurrent_batches:
  185. logger.warning(
  186. f"`max_batch_size` ({self.max_batch_size}) * `max_concurrent_batches` "
  187. f"({self.max_concurrent_batches}) is larger than `max_ongoing_requests` "
  188. f"({max_ongoing_requests}). This means the replica will never achieve "
  189. "the configured `max_batch_size` concurrently. Please update "
  190. "`max_ongoing_requests` to be >= `max_batch_size` * `max_concurrent_batches`."
  191. )
  192. def set_max_batch_size(self, new_max_batch_size: int) -> None:
  193. """Updates queue's max_batch_size."""
  194. self.max_batch_size = new_max_batch_size
  195. self._warn_if_max_batch_size_exceeds_max_ongoing_requests()
  196. def put(self, request: Tuple[_SingleRequest, asyncio.Future]) -> None:
  197. self.queue.put_nowait(request)
  198. self.requests_available_event.set()
  199. def _compute_batch_size(self, batch: List[_SingleRequest]) -> int:
  200. """Compute the effective batch size using batch_size_fn or len()."""
  201. if self.batch_size_fn is None:
  202. return len(batch)
  203. # Extract the actual data items from requests to pass to batch_size_fn.
  204. # We need to reconstruct the original arguments from flattened_args.
  205. items = []
  206. for request in batch:
  207. # Recover the original arguments from flattened format
  208. args, kwargs = recover_args(request.flattened_args)
  209. # The batch function expects a single positional argument (the item)
  210. # after 'self' has been extracted (if it was a method)
  211. items.append(args[0])
  212. return self.batch_size_fn(items)
  213. async def wait_for_batch(self) -> Tuple[List[_SingleRequest], int]:
  214. """Wait for batch respecting self.max_batch_size and self.timeout_s.
  215. Returns a tuple of (batch, computed_batch_size) where batch contains
  216. up to self.max_batch_size items. Waits for up to self.timeout_s after
  217. receiving the first request that will be in the next batch. After the
  218. timeout, returns as many items as are ready.
  219. Always returns a batch with at least one item - will block
  220. indefinitely until an item comes in.
  221. """
  222. batch = []
  223. first_item = await self.queue.get() # Block until first item arrives
  224. # Cache current max_batch_size and batch_wait_timeout_s for this batch.
  225. max_batch_size = self.max_batch_size
  226. batch_wait_timeout_s = self.batch_wait_timeout_s
  227. # Check if first item alone exceeds max_batch_size (only with batch_size_fn)
  228. if self.batch_size_fn is not None:
  229. first_item_size = self._compute_batch_size([first_item])
  230. if first_item_size > max_batch_size:
  231. exc = RuntimeError(
  232. "Size of item is greater than max_batch_size. "
  233. "Please increase the max_batch_size or check the "
  234. "implementation of the batch_size_fn."
  235. )
  236. # Set exception on the future so the caller receives it
  237. first_item.future.set_exception(exc)
  238. return [], 0
  239. batch.append(first_item)
  240. # Wait self.timeout_s seconds for new queue arrivals.
  241. batch_start_time = time.time()
  242. while True:
  243. # Record queue length metric.
  244. self._batch_queue_length_gauge.set(
  245. self.queue.qsize(), tags={"function_name": self._function_name}
  246. )
  247. remaining_batch_time_s = max(
  248. batch_wait_timeout_s - (time.time() - batch_start_time), 0
  249. )
  250. try:
  251. # Wait for new arrivals.
  252. await asyncio.wait_for(
  253. self.requests_available_event.wait(), remaining_batch_time_s
  254. )
  255. except asyncio.TimeoutError:
  256. pass
  257. # Custom batch size function logic
  258. if self.batch_size_fn is not None:
  259. # Add all new arrivals to the batch.
  260. # Track items we need to put back if they don't fit
  261. deferred_item = None
  262. while not self.queue.empty():
  263. next_item = self.queue.get_nowait()
  264. # Temporarily add to check size
  265. batch.append(next_item)
  266. new_size = self._compute_batch_size(batch)
  267. if new_size > max_batch_size:
  268. # Would exceed limit, remove it and save for later
  269. batch.pop()
  270. deferred_item = next_item
  271. break
  272. # Size is OK, keep it in the batch (already added above)
  273. # Put deferred item back in queue for next batch
  274. if deferred_item is not None:
  275. # NOTE: The deferred item goes to the back of the queue (FIFO),
  276. # so newer requests may be processed before it. Consider using
  277. # asyncio.PriorityQueue if strict ordering is required.
  278. self.queue.put_nowait(deferred_item)
  279. # Compute final batch size before breaking (batch is now valid
  280. # after popping the deferred item).
  281. current_batch_size = self._compute_batch_size(batch)
  282. # break the loop early because the deferred item is too large to fit in the batch
  283. break
  284. else:
  285. # Default behavior: use original len() check logic
  286. while len(batch) < max_batch_size and not self.queue.empty():
  287. batch.append(self.queue.get_nowait())
  288. # Only clear the put event if the queue is empty. If it's not empty
  289. # we can start constructing a new batch immediately in the next loop.
  290. # The code that puts items into the queue runs on the same event loop
  291. # as this code, so there's no race condition between the time we
  292. # get objects in the queue (and clear the event) and when objects
  293. # get added to the queue.
  294. if self.queue.empty():
  295. self.requests_available_event.clear()
  296. current_batch_size = self._compute_batch_size(batch)
  297. if (
  298. time.time() - batch_start_time >= batch_wait_timeout_s
  299. or current_batch_size >= max_batch_size
  300. ):
  301. break
  302. # Record batch wait time metric (time spent waiting for batch to fill).
  303. batch_wait_time_ms = (time.time() - batch_start_time) * 1000
  304. self._batch_wait_time_histogram.observe(
  305. batch_wait_time_ms, tags={"function_name": self._function_name}
  306. )
  307. return batch, current_batch_size
  308. def _validate_results(
  309. self, results: Iterable[Any], input_batch_length: int
  310. ) -> None:
  311. if len(results) != input_batch_length:
  312. raise RayServeException(
  313. "Batched function doesn't preserve batch size. "
  314. f"The input list has length {input_batch_length} but the "
  315. f"returned list has length {len(results)}."
  316. )
  317. async def _consume_func_generator(
  318. self,
  319. func_generator: AsyncGenerator,
  320. initial_futures: List[asyncio.Future],
  321. input_batch_length: int,
  322. ) -> None:
  323. """Consumes batch function generator.
  324. This function only runs if the function decorated with @serve.batch
  325. is a generator.
  326. """
  327. FINISHED_TOKEN = None
  328. try:
  329. futures = deque(initial_futures)
  330. assert len(futures) == input_batch_length
  331. async for results in func_generator:
  332. self._validate_results(results, input_batch_length)
  333. for idx in range(input_batch_length):
  334. result, future = results[idx], futures[0]
  335. if future is FINISHED_TOKEN:
  336. # This caller has already terminated.
  337. futures.append(FINISHED_TOKEN)
  338. elif result in USER_CODE_STREAMING_SENTINELS:
  339. # User's code returned sentinel. No values left
  340. # for caller. Terminate iteration for caller.
  341. _set_exception_if_not_done(future, StopAsyncIteration)
  342. futures.append(FINISHED_TOKEN)
  343. else:
  344. next_future = get_or_create_event_loop().create_future()
  345. _set_result_if_not_done(
  346. future, _GeneratorResult(result, next_future)
  347. )
  348. futures.append(next_future)
  349. # Remove processed future. We remove the future at the very
  350. # end of the loop to ensure that if an exception occurs,
  351. # all pending futures will get set in the `except` block.
  352. futures.popleft()
  353. for future in futures:
  354. if future is not FINISHED_TOKEN:
  355. _set_exception_if_not_done(future, StopAsyncIteration)
  356. except Exception as e:
  357. for future in futures:
  358. if future is not FINISHED_TOKEN:
  359. _set_exception_if_not_done(future, e)
  360. async def _assign_func_results(
  361. self,
  362. func_future: asyncio.Future,
  363. futures: List[asyncio.Future],
  364. input_batch_length: int,
  365. ):
  366. """Assigns func's results to the list of futures."""
  367. try:
  368. results = await func_future
  369. self._validate_results(results, input_batch_length)
  370. for result, future in zip(results, futures):
  371. _set_result_if_not_done(future, result)
  372. except Exception as e:
  373. for future in futures:
  374. _set_exception_if_not_done(future, e)
  375. def _split_batch_by_model_id(
  376. self, batch: List[_SingleRequest]
  377. ) -> List[List[_SingleRequest]]:
  378. """Split a batch into sub-batches based on multiplexed_model_id.
  379. When using model multiplexing with batching, requests for different models
  380. may end up in the same batch. This method ensures that each sub-batch only
  381. contains requests for the same model, preventing issues where a single batch
  382. contains requests for different models.
  383. If no requests have a multiplexed_model_id set, returns the original batch
  384. as a single sub-batch.
  385. Args:
  386. batch: The batch of requests to split.
  387. Returns:
  388. A list of sub-batches, where each sub-batch contains requests for the
  389. same multiplexed_model_id.
  390. """
  391. # Group requests by their multiplexed_model_id
  392. model_id_to_requests: Dict[str, List[_SingleRequest]] = {}
  393. for request in batch:
  394. model_id = request.request_context.multiplexed_model_id
  395. if model_id not in model_id_to_requests:
  396. model_id_to_requests[model_id] = []
  397. model_id_to_requests[model_id].append(request)
  398. # Return sub-batches for each model_id
  399. return list(model_id_to_requests.values())
  400. async def _process_batches(self, func: Callable) -> None:
  401. """Loops infinitely and processes queued request batches."""
  402. # When asyncio task is created, the task will inherit the request context from the current context.
  403. # So we unset the request context so the current context is not inherited by the task, _process_batch.
  404. serve.context._unset_request_context()
  405. while not self._loop.is_closed():
  406. batch, _ = await self.wait_for_batch()
  407. # Split batch by multiplexed_model_id to ensure requests for different
  408. # models are processed in separate batches. This is necessary when using
  409. # model multiplexing with batching, as a single batch containing requests
  410. # for different models would not work correctly.
  411. sub_batches = self._split_batch_by_model_id(batch)
  412. # Process all sub-batches together under a single semaphore permit.
  413. # This ensures sub-batches from the same original batch run concurrently
  414. # rather than being serialized by the semaphore.
  415. promise = self._process_sub_batches(func, sub_batches)
  416. task = asyncio.create_task(promise)
  417. self.tasks.add(task)
  418. self.curr_iteration_start_times[task] = time.time()
  419. task.add_done_callback(self._handle_completed_task)
  420. async def _process_sub_batches(
  421. self, func: Callable, sub_batches: List[List[_SingleRequest]]
  422. ) -> None:
  423. """Processes multiple sub-batches concurrently under a single semaphore permit.
  424. This method acquires the semaphore once and then processes all sub-batches
  425. in parallel, ensuring that sub-batches from the same original batch don't
  426. compete for semaphore permits.
  427. """
  428. # NOTE: this semaphore caps the number of concurrent batches specified by `max_concurrent_batches`
  429. async with self.semaphore:
  430. # Create tasks for each sub-batch. We use asyncio.create_task() instead
  431. # of passing coroutines directly to asyncio.gather() because create_task
  432. # copies the current context, giving each sub-batch its own isolated
  433. # contextvars. This prevents concurrent sub-batches from overwriting
  434. # each other's _serve_batch_request_context, which would cause
  435. # get_multiplexed_model_id() to return wrong values.
  436. tasks = [
  437. asyncio.create_task(self._process_batch_inner(func, sub_batch))
  438. for sub_batch in sub_batches
  439. ]
  440. await asyncio.gather(*tasks)
  441. async def _process_batch_inner(
  442. self, func: Callable, batch: List[_SingleRequest]
  443. ) -> None:
  444. """Processes a single batch without acquiring the semaphore.
  445. This is the inner implementation called by _process_sub_batches after
  446. the semaphore has already been acquired.
  447. """
  448. # Remove requests that have been cancelled from the batch. If
  449. # all requests have been cancelled, simply return and wait for
  450. # the next batch.
  451. batch = [req for req in batch if not req.future.cancelled()]
  452. if len(batch) == 0:
  453. return
  454. # Compute batch size for this sub-batch. Each sub-batch may have a different
  455. # size, especially when splitting by model_id, so we compute it here.
  456. computed_batch_size = self._compute_batch_size(batch)
  457. # Calculate and record batch utilization percentage.
  458. batch_utilization_percent = (computed_batch_size / self.max_batch_size) * 100
  459. self._batch_utilization_histogram.observe(
  460. batch_utilization_percent, tags={"function_name": self._function_name}
  461. )
  462. # Record actual batch size (number of requests in the batch computed by the batch_size_fn).
  463. self._batch_size_histogram.observe(
  464. computed_batch_size, tags={"function_name": self._function_name}
  465. )
  466. # Increment batches processed counter.
  467. self._batches_processed_counter.inc(tags={"function_name": self._function_name})
  468. futures = [item.future for item in batch]
  469. # Most of the logic in the function should be wrapped in this try-
  470. # except block, so the futures' exceptions can be set if an exception
  471. # occurs. Otherwise, the futures' requests may hang indefinitely.
  472. batch_execution_start_time = time.time()
  473. try:
  474. self_arg = batch[0].self_arg
  475. args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])
  476. # Method call.
  477. if self_arg is not None:
  478. func_future_or_generator = func(self_arg, *args, **kwargs)
  479. # Normal function call.
  480. else:
  481. func_future_or_generator = func(*args, **kwargs)
  482. # Add individual request context to the batch request context
  483. serve.context._set_batch_request_context(
  484. [req.request_context for req in batch]
  485. )
  486. if isasyncgenfunction(func):
  487. func_generator = func_future_or_generator
  488. await self._consume_func_generator(func_generator, futures, len(batch))
  489. else:
  490. func_future = func_future_or_generator
  491. await self._assign_func_results(func_future, futures, len(batch))
  492. # Reset the batch request context after the batch is processed
  493. serve.context._set_batch_request_context([])
  494. except Exception as e:
  495. logger.exception("_process_batch ran into an unexpected exception.")
  496. for future in futures:
  497. _set_exception_if_not_done(future, e)
  498. finally:
  499. # Record batch execution time.
  500. batch_execution_time_ms = (time.time() - batch_execution_start_time) * 1000
  501. self._batch_execution_time_histogram.observe(
  502. batch_execution_time_ms, tags={"function_name": self._function_name}
  503. )
  504. def _handle_completed_task(self, task: asyncio.Task) -> None:
  505. self.tasks.remove(task)
  506. del self.curr_iteration_start_times[task]
  507. self._log_if_exception(task.exception())
  508. @staticmethod
  509. def _log_if_exception(exception_maybe: Optional[BaseException]) -> None:
  510. if exception_maybe is not None:
  511. if isinstance(exception_maybe, asyncio.CancelledError):
  512. logger.debug("Task was cancelled")
  513. else:
  514. logger.exception("Task failed unexpectedly")
  515. def __del__(self):
  516. if (
  517. self._handle_batch_task is None
  518. or not get_or_create_event_loop().is_running()
  519. ):
  520. return
  521. # TODO(edoakes): although we try to gracefully shutdown here, it still
  522. # causes some errors when the process exits due to the asyncio loop
  523. # already being destroyed.
  524. self._handle_batch_task.cancel()
  525. class _LazyBatchQueueWrapper:
  526. """Stores a _BatchQueue and updates its settings.
  527. _BatchQueue cannot be pickled, you must construct it lazily
  528. at runtime inside a replica. This class initializes a queue only upon
  529. first access.
  530. """
  531. def __init__(
  532. self,
  533. max_batch_size: int = 10,
  534. batch_wait_timeout_s: float = 0.0,
  535. max_concurrent_batches: int = 1,
  536. handle_batch_func: Optional[Callable] = None,
  537. batch_size_fn: Optional[Callable[[List], int]] = None,
  538. ):
  539. self._queue: Optional[_BatchQueue] = None
  540. self.max_batch_size = max_batch_size
  541. self.batch_wait_timeout_s = batch_wait_timeout_s
  542. self.max_concurrent_batches = max_concurrent_batches
  543. self.handle_batch_func = handle_batch_func
  544. self.batch_size_fn = batch_size_fn
  545. @property
  546. def queue(self) -> _BatchQueue:
  547. """Returns _BatchQueue.
  548. Initializes queue when called for the first time.
  549. """
  550. if self._queue is None:
  551. self._queue = _BatchQueue(
  552. self.max_batch_size,
  553. self.batch_wait_timeout_s,
  554. self.max_concurrent_batches,
  555. self.handle_batch_func,
  556. self.batch_size_fn,
  557. )
  558. return self._queue
  559. def set_max_batch_size(self, new_max_batch_size: int) -> None:
  560. """Updates queue's max_batch_size."""
  561. self.max_batch_size = new_max_batch_size
  562. if self._queue is not None:
  563. self._queue.set_max_batch_size(new_max_batch_size)
  564. def set_batch_wait_timeout_s(self, new_batch_wait_timeout_s: float) -> None:
  565. self.batch_wait_timeout_s = new_batch_wait_timeout_s
  566. if self._queue is not None:
  567. self._queue.batch_wait_timeout_s = new_batch_wait_timeout_s
  568. def get_max_batch_size(self) -> int:
  569. return self.max_batch_size
  570. def get_batch_wait_timeout_s(self) -> float:
  571. return self.batch_wait_timeout_s
  572. def _get_curr_iteration_start_times(self) -> _RuntimeSummaryStatistics:
  573. """Gets summary statistics of current iteration's start times."""
  574. return _RuntimeSummaryStatistics(
  575. list(self.queue.curr_iteration_start_times.values())
  576. )
  577. async def _is_batching_task_alive(self) -> bool:
  578. """Gets whether default _BatchQueue's background task is alive.
  579. Returns False if the batch handler doesn't use a default _BatchQueue.
  580. """
  581. if hasattr(self.queue, "_handle_batch_task"):
  582. return not self.queue._handle_batch_task.done()
  583. else:
  584. return False
  585. async def _get_handling_task_stack(self) -> Optional[str]:
  586. """Gets the stack for the default _BatchQueue's background task.
  587. Returns empty string if the batch handler doesn't use a default _BatchQueue.
  588. """
  589. if hasattr(self.queue, "_handle_batch_task"):
  590. str_buffer = io.StringIO()
  591. self.queue._handle_batch_task.print_stack(file=str_buffer)
  592. return str_buffer.getvalue()
  593. else:
  594. return None
  595. def _validate_max_batch_size(max_batch_size):
  596. if not isinstance(max_batch_size, int):
  597. if isinstance(max_batch_size, float) and max_batch_size.is_integer():
  598. max_batch_size = int(max_batch_size)
  599. else:
  600. raise TypeError(
  601. f"max_batch_size must be integer >= 1, got {max_batch_size}"
  602. )
  603. if max_batch_size < 1:
  604. raise ValueError(
  605. f"max_batch_size must be an integer >= 1, got {max_batch_size}"
  606. )
  607. def _validate_batch_wait_timeout_s(batch_wait_timeout_s):
  608. if not isinstance(batch_wait_timeout_s, (float, int)):
  609. raise TypeError(
  610. f"batch_wait_timeout_s must be a float >= 0, got {batch_wait_timeout_s}"
  611. )
  612. if batch_wait_timeout_s < 0:
  613. raise ValueError(
  614. f"batch_wait_timeout_s must be a float >= 0, got {batch_wait_timeout_s}"
  615. )
  616. def _validate_max_concurrent_batches(max_concurrent_batches: int) -> None:
  617. if not isinstance(max_concurrent_batches, int) or max_concurrent_batches < 1:
  618. raise TypeError(
  619. f"max_concurrent_batches must be an integer >= 1, got {max_concurrent_batches}"
  620. )
  621. def _validate_batch_size_fn(batch_size_fn: Optional[Callable[[List], int]]) -> None:
  622. if batch_size_fn is not None and not callable(batch_size_fn):
  623. raise TypeError(
  624. f"batch_size_fn must be a callable or None, got {type(batch_size_fn)}"
  625. )
# Type variables for the typing overloads of `batch`.
# SelfType only appears in parameter position (the `self_` argument of the
# batching-method protocols), which is why it is declared contravariant.
SelfType = TypeVar("SelfType", contravariant=True)
# Element type of the list a batch handler receives.
T = TypeVar("T")
# Element type of the list a batch handler returns.
R = TypeVar("R")
class _SyncBatchingMethod(Protocol, Generic[SelfType, T, R]):
    """Structural type of a synchronous batching method.

    Matches a callable taking the instance and a list of inputs (both
    positional-only) and returning a list of outputs.
    """

    def __call__(self, self_: SelfType, __batch: List[T], /) -> List[R]:
        ...
class _AsyncBatchingMethod(Protocol, Generic[SelfType, T, R]):
    """Structural type of an ``async`` batching method.

    Matches a coroutine function taking the instance and a list of inputs
    (both positional-only) and producing a list of outputs.
    """

    async def __call__(self, self_: SelfType, __batch: List[T], /) -> List[R]:
        ...
# Typing-only overloads for `batch`. In each of the first four (bare
# decorator) forms, the decorated callable takes a list and returns a list,
# while the resulting callable takes and returns a single item. The last form
# covers `batch(...)` called with keyword configuration and returns a
# decorator.


@overload  # Sync function for `batch` called WITHOUT arguments
def batch(_sync_func: Callable[[List[T]], List[R]], /) -> Callable[[T], R]:
    ...


@overload  # Async function for `batch` called WITHOUT arguments
def batch(
    _async_func: Callable[[List[T]], Coroutine[Any, Any, List[R]]], /
) -> Callable[[T], Coroutine[Any, Any, R]]:
    ...


@overload  # Sync method for `batch` called WITHOUT arguments
def batch(
    _sync_meth: _SyncBatchingMethod[SelfType, T, R], /
) -> Callable[[SelfType, T], R]:
    ...


@overload  # Async method for `batch` called WITHOUT arguments
def batch(
    _async_meth: _AsyncBatchingMethod[SelfType, T, R], /
) -> Callable[[SelfType, T], Coroutine[Any, Any, R]]:
    ...


@overload  # `batch` called WITH arguments
def batch(
    # Positional slot must be None (or omitted) in this form; it is what
    # distinguishes `batch(...)` from the bare-decorator forms above.
    _: Literal[None] = None,
    /,
    max_batch_size: int = 10,
    batch_wait_timeout_s: float = 0.01,
    max_concurrent_batches: int = 1,
    batch_size_fn: Optional[Callable[[List], int]] = None,
) -> "_BatchDecorator":
    ...
  663. class _BatchDecorator(Protocol):
  664. """Descibes behaviour of decorator produced by calling `batch` with arguments"""
  665. @overload # Sync function
  666. def __call__(self, _sync_func: Callable[[List[T]], List[R]], /) -> Callable[[T], R]:
  667. ...
  668. @overload # Async function
  669. def __call__(
  670. self, _async_func: Callable[[List[T]], Coroutine[Any, Any, List[R]]], /
  671. ) -> Callable[[T], Coroutine[Any, Any, R]]:
  672. ...
  673. @overload # Sync method
  674. def __call__(
  675. self, _sync_meth: _SyncBatchingMethod[SelfType, T, R], /
  676. ) -> Callable[[SelfType, T], R]:
  677. ...
  678. @overload # Async method
  679. def __call__(
  680. self, _async_meth: _AsyncBatchingMethod[SelfType, T, R], /
  681. ) -> Callable[[SelfType, T], Coroutine[Any, Any, R]]:
  682. ...
  683. @PublicAPI(stability="stable")
  684. def batch(
  685. _func: Optional[Callable] = None,
  686. /,
  687. max_batch_size: int = 10,
  688. batch_wait_timeout_s: float = 0.01,
  689. max_concurrent_batches: int = 1,
  690. batch_size_fn: Optional[Callable[[List], int]] = None,
  691. ) -> Callable:
  692. """Converts a function to asynchronously handle batches.
  693. The function can be a standalone function or a class method. In both
  694. cases, the function must be `async def` and take a list of objects as
  695. its sole argument and return a list of the same length as a result.
  696. When invoked, the caller passes a single object. These will be batched
  697. and executed asynchronously once there is a batch of `max_batch_size`
  698. or `batch_wait_timeout_s` has elapsed, whichever occurs first.
  699. `max_batch_size` and `batch_wait_timeout_s` can be updated using setter
  700. methods from the batch_handler (`set_max_batch_size` and
  701. `set_batch_wait_timeout_s`).
  702. Example:
  703. .. code-block:: python
  704. from ray import serve
  705. from starlette.requests import Request
  706. @serve.deployment
  707. class BatchedDeployment:
  708. @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1)
  709. async def batch_handler(self, requests: List[Request]) -> List[str]:
  710. response_batch = []
  711. for r in requests:
  712. name = (await requests.json())["name"]
  713. response_batch.append(f"Hello {name}!")
  714. return response_batch
  715. def update_batch_params(self, max_batch_size, batch_wait_timeout_s):
  716. self.batch_handler.set_max_batch_size(max_batch_size)
  717. self.batch_handler.set_batch_wait_timeout_s(batch_wait_timeout_s)
  718. async def __call__(self, request: Request):
  719. return await self.batch_handler(request)
  720. app = BatchedDeployment.bind()
  721. Arguments:
  722. max_batch_size: the maximum batch size that will be executed in
  723. one call to the underlying function.
  724. batch_wait_timeout_s: the maximum duration to wait for
  725. `max_batch_size` elements before running the current batch.
  726. max_concurrent_batches: the maximum number of batches that can be
  727. executed concurrently. If the number of concurrent batches exceeds
  728. this limit, the batch handler will wait for a batch to complete
  729. before sending the next batch to the underlying function.
  730. batch_size_fn: optional function to compute the effective batch size.
  731. If provided, this function takes a list of items and returns an
  732. integer representing the batch size. This is useful for batching
  733. based on custom metrics such as total nodes in graphs, total tokens
  734. in sequences, or other domain-specific measures. If None, the batch
  735. size is computed as len(batch).
  736. """
  737. # `_func` will be None in the case when the decorator is parametrized.
  738. # See the comment at the end of this function for a detailed explanation.
  739. if _func is not None:
  740. if not callable(_func):
  741. raise TypeError(
  742. "@serve.batch can only be used to decorate functions or methods."
  743. )
  744. if not iscoroutinefunction(_func):
  745. raise TypeError("Functions decorated with @serve.batch must be 'async def'")
  746. _validate_max_batch_size(max_batch_size)
  747. _validate_batch_wait_timeout_s(batch_wait_timeout_s)
  748. _validate_max_concurrent_batches(max_concurrent_batches)
  749. _validate_batch_size_fn(batch_size_fn)
  750. def _batch_decorator(_func):
  751. lazy_batch_queue_wrapper = _LazyBatchQueueWrapper(
  752. max_batch_size,
  753. batch_wait_timeout_s,
  754. max_concurrent_batches,
  755. _func,
  756. batch_size_fn,
  757. )
  758. async def batch_handler_generator(
  759. first_future: asyncio.Future,
  760. ) -> AsyncGenerator:
  761. """Generator that handles generator batch functions."""
  762. future = first_future
  763. while True:
  764. try:
  765. async_response: _GeneratorResult = await future
  766. future = async_response.next_future
  767. yield async_response.result
  768. except StopAsyncIteration:
  769. break
  770. def enqueue_request(args, kwargs) -> asyncio.Future:
  771. flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)
  772. # If the function is a method, remove self as an argument.
  773. self = extract_self_if_method_call(args, _func)
  774. if self is not None:
  775. flattened_args = flattened_args[2:]
  776. batch_queue = lazy_batch_queue_wrapper.queue
  777. future = get_or_create_event_loop().create_future()
  778. request_context = serve.context._get_serve_request_context()
  779. batch_queue.put(
  780. _SingleRequest(self, flattened_args, future, request_context)
  781. )
  782. return future
  783. @wraps(_func)
  784. def generator_batch_wrapper(*args, **kwargs):
  785. first_future = enqueue_request(args, kwargs)
  786. return batch_handler_generator(first_future)
  787. @wraps(_func)
  788. async def batch_wrapper(*args, **kwargs):
  789. # This will raise if the underlying call raised an exception.
  790. return await enqueue_request(args, kwargs)
  791. if isasyncgenfunction(_func):
  792. wrapper = generator_batch_wrapper
  793. else:
  794. wrapper = batch_wrapper
  795. # We store the lazy_batch_queue_wrapper's getters and setters as
  796. # batch_wrapper attributes, so they can be accessed in user code.
  797. wrapper._get_max_batch_size = lazy_batch_queue_wrapper.get_max_batch_size
  798. wrapper._get_batch_wait_timeout_s = (
  799. lazy_batch_queue_wrapper.get_batch_wait_timeout_s
  800. )
  801. wrapper.set_max_batch_size = lazy_batch_queue_wrapper.set_max_batch_size
  802. wrapper.set_batch_wait_timeout_s = (
  803. lazy_batch_queue_wrapper.set_batch_wait_timeout_s
  804. )
  805. # Store debugging methods in the lazy_batch_queue wrapper
  806. wrapper._get_curr_iteration_start_times = (
  807. lazy_batch_queue_wrapper._get_curr_iteration_start_times
  808. )
  809. wrapper._is_batching_task_alive = (
  810. lazy_batch_queue_wrapper._is_batching_task_alive
  811. )
  812. wrapper._get_handling_task_stack = (
  813. lazy_batch_queue_wrapper._get_handling_task_stack
  814. )
  815. return wrapper
  816. # Unfortunately, this is required to handle both non-parametrized
  817. # (@serve.batch) and parametrized (@serve.batch(**kwargs)) usage.
  818. # In the former case, `serve.batch` will be called with the underlying
  819. # function as the sole argument. In the latter case, it will first be
  820. # called with **kwargs, then the result of that call will be called
  821. # with the underlying function as the sole argument (i.e., it must be a
  822. # "decorator factory.").
  823. return _batch_decorator(_func) if callable(_func) else _batch_decorator
  824. def _set_result_if_not_done(future: asyncio.Future, result: Any):
  825. """Sets the future's result if the future is not done."""
  826. if not future.done():
  827. future.set_result(result)
  828. def _set_exception_if_not_done(future: asyncio.Future, exception: Any):
  829. """Sets the future's exception if the future is not done."""
  830. if not future.done():
  831. future.set_exception(exception)