scheduler.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # Copyright 2025 The HuggingFace Inc. team
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import threading
  15. from abc import ABC, abstractmethod
  16. from collections import deque
  17. from ...utils.metrics import attach_tracer, traced
  18. from .cache import PagedAttentionCache
  19. from .requests import FutureRequestState, RequestState, RequestStatus, logger
  20. class Scheduler(ABC):
  21. """
  22. Abstract base class for scheduling requests in the continuous batch processor. Schedulers manage the lifecycle of
  23. requests from when they are added to the waiting queue to when they are scheduled for processing. Different
  24. schedulers implement different strategies for prioritizing and batching requests.
  25. """
  26. def __init__(self, cache: PagedAttentionCache):
  27. self.cache = cache
  28. self._cancellation_lock = threading.Lock()
  29. # This is to compute the read cache used by a new request being scheduled
  30. self.read_cache_limit = None if self.cache.num_full_attention_groups else self.cache.config.sliding_window
  31. self.max_decode_fast_path_length = self.cache.max_blocks_per_request * self.cache.block_size
  32. # Initialize mutable states via reset()
  33. self.reset()
  34. def reset(self) -> None:
  35. """Reset scheduler state for a new generation loop."""
  36. self.active_requests: dict[str, RequestState] = {}
  37. self.waiting_requests: dict[str, RequestState] = {}
  38. self.waiting_requests_order: deque[str] = deque()
  39. self._requests_to_cancel: set[str] = set()
  40. self._requests_to_fork: list[RequestState] = []
  41. self.block_new_requests = False
  42. @traced
  43. def add_waiting_request(self, state: RequestState):
  44. """Adds a request to the waiting list."""
  45. self.waiting_requests[state.request_id] = state
  46. self.waiting_requests_order.append(state.request_id)
  47. @abstractmethod
  48. def schedule_batch(
  49. self, token_budget: int, cache_budget: int
  50. ) -> tuple[list[FutureRequestState] | None, bool, int, int]:
  51. """Schedules requests for the next batch based on available token and cache budgets. This method selects which
  52. requests should be processed in the current batch, considering the budgets and the scheduler's prioritization
  53. rules. The token_budget is the maximum number of tokens that can be processed in a batch, and the cache_budget
  54. is the maximum number of KV cache entries that can be read in a batch.
  55. Returns the list of scheduled requests in their "FutureRequestState" form, a boolean indicating if the decode
  56. fast path can be used, the total number of query tokens and the maximum number of kv tokens read."""
  57. @traced
  58. def has_pending_requests(self) -> bool:
  59. """Checks if there are requests ready to be processed."""
  60. return bool(len(self.active_requests) or len(self.waiting_requests))
  61. @traced
  62. def finish_request(self, request_id: str) -> None:
  63. """Completes processing of a request and frees its allocated cache blocks. This method is called
  64. when a request has finished generation or encountered an error.
  65. """
  66. self.cache.free_blocks(request_id)
  67. self.active_requests.pop(request_id, None)
  68. @traced
  69. def get_active_request_static_outputs(self, request_id: str) -> list[int]:
  70. """Gets generated tokens for an active request."""
  71. if request_id in self.active_requests:
  72. return self.active_requests[request_id].generated_tokens
  73. return []
  74. @traced
  75. def set_request_cancellation(self, request_id: str):
  76. """Marks a request for cancellation."""
  77. with self._cancellation_lock:
  78. self._requests_to_cancel.add(request_id)
  79. @traced
  80. def clear_cancelled_requests(self):
  81. """Remove all cancelled requests from active and waiting queues."""
  82. with self._cancellation_lock:
  83. for request_id in self._requests_to_cancel:
  84. self.active_requests.pop(request_id, None)
  85. self.waiting_requests.pop(request_id, None)
  86. if request_id in self.waiting_requests_order:
  87. self.waiting_requests_order.remove(request_id)
  88. self.cache.free_blocks(request_id)
  89. self._requests_to_cancel = set()
  90. @traced
  91. def request_is_cancelled(self, request_id: str) -> bool:
  92. """Checks if a request has been cancelled or removed."""
  93. return request_id in self._requests_to_cancel or (
  94. request_id not in self.active_requests and request_id not in self.waiting_requests
  95. )
  96. @traced
  97. def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
  98. """Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
  99. accommodate the next tokens. It calculates how many blocks are needed based on the request's current
  100. cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
  101. objects. Returns a boolean indicating if the allocation was successful or not.
  102. """
  103. # 1. we check that the occupancy is less than the requested length
  104. # 2. we allocate enough blocks to cover the requested length
  105. current_len = state.current_len()
  106. occupancy = state.allocated_blocks * self.cache.block_size - current_len
  107. if occupancy < len_next_tokens or state.allocated_blocks == 0:
  108. blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
  109. allocated = self.cache.allocate_blocks(blocks_needed, state.request_id, state.allocated_blocks)
  110. if allocated is None:
  111. return False
  112. state.allocated_blocks += allocated
  113. return True
  114. def _infer_request_tokens(self, state: RequestState, request_ids_to_remove_from_waiting: set[str]) -> list[int]:
  115. """Prepares a request for processing in the current batch. If prefix sharing is enabled, and the request was
  116. pending, this is where we look for a prefix match and split the request if found."""
  117. # If prefix sharing is enabled, we look for a prefix match and split the request if found
  118. if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
  119. prefill_length = self.cache.search_prefix_match(state.request_id, state.remaining_prefill_tokens)
  120. if prefill_length > 0:
  121. self.active_requests[state.request_id] = state
  122. request_ids_to_remove_from_waiting.add(state.request_id)
  123. state.status = RequestStatus.PREFILLING
  124. # We keep track of the number of allocated blocks to avoid double allocation
  125. state.allocated_blocks += prefill_length // self.cache.block_size
  126. # Even if we match the whole request, we keep at least 1 token to start decoding
  127. prefill_length = min(prefill_length, len(state.remaining_prefill_tokens) - 1)
  128. state.remaining_prefill_tokens = state.remaining_prefill_tokens[prefill_length:]
  129. state.position_offset += prefill_length
  130. # If the request is decoding, the tokens to process are already set
  131. if state.status == RequestStatus.DECODING:
  132. request_tokens = state.tokens_to_process
  133. # Otherwise, the tokens to process are the remaining prefill tokens
  134. else:
  135. request_tokens = state.remaining_prefill_tokens
  136. return request_tokens
  137. def _schedule_request(
  138. self,
  139. state: RequestState,
  140. request_tokens: list[int],
  141. token_budget: int,
  142. request_ids_to_remove_from_waiting: set[str],
  143. ) -> None:
  144. """Schedules a request for the current batch, updating the request's status according to the token budget left.
  145. After a request is scheduled, it is part of the next batch unless there is an error.
  146. If the request has children (for parallel decoding), it ensures at least one token remains before the request is
  147. forked."""
  148. # If the request has one or more children we make sure not to prefill it entirely
  149. # This does not check the request state, but DECODING request already have children set to 0.
  150. if state.num_children > 0 and token_budget >= len(request_tokens) - 1:
  151. token_budget = len(request_tokens) - 1
  152. self._requests_to_fork.append(state)
  153. # Case: we can process the entire prompt/remainder
  154. if len(request_tokens) <= token_budget:
  155. if state.status == RequestStatus.PENDING:
  156. self.active_requests[state.request_id] = state
  157. request_ids_to_remove_from_waiting.add(state.request_id)
  158. if state.status <= RequestStatus.PREFILLING:
  159. state.tokens_to_process = state.remaining_prefill_tokens
  160. state.remaining_prefill_tokens = []
  161. # Although prefill will only be done after the batch being scheduled now, we set the status to DECODING
  162. # to stay coherent when using asynchronous batching
  163. state.status = RequestStatus.DECODING
  164. # Otherwise: we need to split the request
  165. else:
  166. if state.status == RequestStatus.PENDING:
  167. self.active_requests[state.request_id] = state
  168. state.status = RequestStatus.PREFILLING
  169. request_ids_to_remove_from_waiting.add(state.request_id)
  170. state.remaining_prefill_tokens = request_tokens[token_budget:]
  171. state.tokens_to_process = request_tokens[:token_budget]
  172. def _process_candidates(
  173. self,
  174. candidates: list[RequestState],
  175. token_budget: int,
  176. cache_budget: int,
  177. request_ids_to_remove_from_waiting: set[str],
  178. safety_margin: float = 0.0,
  179. ) -> tuple[list[FutureRequestState], bool, bool, int, int]:
  180. """Schedules candidate requests for the current batch.
  181. This method contains the common logic shared by all schedulers: it checks token and cache budgets, allocates
  182. cache blocks if needed, updates request states, and tracks which waiting requests should be removed from the
  183. waiting queue.
  184. """
  185. scheduled_requests = []
  186. one_allocation_failed = False
  187. decode_fast_path = True
  188. safety_margins = safety_margin * self.cache.num_blocks
  189. original_token_budget, original_cache_budget = token_budget, cache_budget
  190. for state in candidates:
  191. num_free_blocks = self.cache.get_num_free_blocks()
  192. # If we are out the safety margin, we only accept decoding requests or the first prefill request
  193. outside_safety_margin = num_free_blocks < safety_margins
  194. if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING:
  195. logger.info(
  196. f"Outside safety margin, breaking out of scheduling loop. {num_free_blocks = } {safety_margins = }"
  197. )
  198. break
  199. # Check cache budget
  200. read_cache_needed = state.current_len()
  201. if self.read_cache_limit is not None:
  202. read_cache_needed = min(read_cache_needed, self.read_cache_limit)
  203. if cache_budget < read_cache_needed:
  204. continue
  205. # Infer the tokens that will be present in the batch if token budget is enough
  206. request_tokens = self._infer_request_tokens(state, request_ids_to_remove_from_waiting)
  207. # Account for token budget
  208. request_len = min(len(request_tokens), token_budget)
  209. # Check there will be enough cache for the new tokens
  210. allocation_successful = self._allocate_blocks_if_needed(state, request_len)
  211. # If the allocation would not be successful, we move on to the next request
  212. if not allocation_successful:
  213. one_allocation_failed = True
  214. # If we reached a waiting request and the cache is full, all subsequent waiting requests will need
  215. # allocation as well, so we can safely break out of the scheduling loop.
  216. if num_free_blocks == 0 and state.request_id in self.waiting_requests:
  217. logger.info(f"Breaking mid-loop for request {state.request_id} because the cache is full")
  218. break
  219. continue
  220. # If this point is reached, it means we can safely schedule the request
  221. self._schedule_request(state, request_tokens, token_budget, request_ids_to_remove_from_waiting)
  222. request_len = len(state.tokens_to_process) # it may change after scheduling
  223. # The decode fast path is only used if the request is a single token and its length is less than the max blocks per request
  224. decode_fast_path &= request_len == 1 and state.position_offset < self.max_decode_fast_path_length
  225. # Update the token and cache budgets
  226. token_budget -= request_len
  227. cache_budget -= read_cache_needed
  228. # If using prefix sharing, we make note of the blocks that will be computed in the forward pass
  229. if self.cache.allow_block_sharing:
  230. tokens_in_current_block = state.current_len() % self.cache.block_size
  231. tokens_after_forward = tokens_in_current_block + request_len
  232. complete_blocks = tokens_after_forward // self.cache.block_size
  233. else:
  234. complete_blocks = 0
  235. # Store the future request state
  236. has_new_token = not state.remaining_prefill_tokens
  237. scheduled_requests.append(FutureRequestState(state, has_new_token, complete_blocks))
  238. # Remove the request from the waiting queue and mark it as removed
  239. req_id = state.request_id
  240. was_waiting = self.waiting_requests.pop(req_id, None) is not None
  241. if was_waiting:
  242. request_ids_to_remove_from_waiting.add(req_id)
  243. # Early exit of the loop if we have no budget left
  244. if token_budget == 0 or cache_budget == 0:
  245. break
  246. num_q_tokens = original_token_budget - token_budget
  247. max_kv_read = original_cache_budget - cache_budget
  248. return scheduled_requests, one_allocation_failed, decode_fast_path, num_q_tokens, max_kv_read
  249. def _cleanup_waiting_queue(self, request_ids_to_remove_from_waiting: set[str]) -> None:
  250. """Removes processed requests from the waiting queue order."""
  251. self.waiting_requests_order = deque(
  252. [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
  253. )
  254. # TODO: further common-ize the two classes
  255. @attach_tracer()
  256. class FIFOScheduler(Scheduler):
  257. """This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
  258. prefilling requests. Additionally, it includes a safety margin mechanism to prevent cache exhaustion. By default,
  259. when 80% of the cache is full, new requests will not be scheduled to prioritize decoding active requests."""
  260. def __init__(self, cache: PagedAttentionCache, safety_margin: float = 0.2):
  261. """Initializes the FIFO scheduler. The safety margin is the percentage of free blocks under which we stop
  262. scheduling new prefill requests, so safety_margin = 0.1 means that when there is less than 10% of free blocks,
  263. or equivalently when more than 90% of blocks are already allocated, we stop scheduling new prefill requests.
  264. """
  265. super().__init__(cache)
  266. self.safety_margin = safety_margin
  267. @traced
  268. def schedule_batch(
  269. self, token_budget: int, cache_budget: int
  270. ) -> tuple[list[FutureRequestState] | None, bool, int, int]:
  271. priority_states: list[RequestState] = []
  272. second_priority_states: list[RequestState] = []
  273. for state in self.active_requests.values():
  274. if state.status == RequestStatus.DECODING:
  275. priority_states.append(state)
  276. elif state.status == RequestStatus.PREFILLING:
  277. second_priority_states.append(state)
  278. # Add waiting requests to second priority
  279. if not self.block_new_requests:
  280. for req_id in self.waiting_requests_order:
  281. second_priority_states.append(self.waiting_requests[req_id])
  282. candidates = priority_states + second_priority_states
  283. request_ids_to_remove_from_waiting = set()
  284. scheduled_requests, one_allocation_failed, decode_fast_path, num_q_tokens, max_kv_read = (
  285. self._process_candidates(
  286. candidates,
  287. token_budget,
  288. cache_budget,
  289. request_ids_to_remove_from_waiting,
  290. safety_margin=self.safety_margin,
  291. )
  292. )
  293. # We remove waiting requests before checking requests were scheduled, because there might have been prefill matches
  294. self._cleanup_waiting_queue(request_ids_to_remove_from_waiting)
  295. # If no requests were scheduled and the cache is full, we signal it by returning None
  296. if not scheduled_requests and one_allocation_failed:
  297. return None, decode_fast_path, 0, 0
  298. return scheduled_requests, decode_fast_path, num_q_tokens, max_kv_read
  299. # FIXME: prioritize adding from waiting reqs before scheduling `RequestStatus.DECODING` when cache space allows it
  300. # TODO: further consolidate the code by making more of it common. The reference Scheduler is FIFO, not this one.
  301. @attach_tracer()
  302. class PrefillFirstScheduler(Scheduler):
  303. """Scheduler that prioritizes split prefill requests over decoding requests. This scheduler ensures that split
  304. prefill requests (which are continuations of partially processed prompts) are completed before processing new
  305. decoding requests."""
  306. @traced
  307. def schedule_batch(
  308. self, token_budget: int, cache_budget: int
  309. ) -> tuple[list[FutureRequestState] | None, bool, int, int]:
  310. priority_states: list[RequestState] = []
  311. second_priority_states: list[RequestState] = []
  312. for state in self.active_requests.values():
  313. # XXX: when cache is full, state can stay on `PREFILLING_SPLIT` so we need to take those into account
  314. if state.status == RequestStatus.PREFILLING:
  315. priority_states.append(state)
  316. elif state.status == RequestStatus.DECODING:
  317. second_priority_states.append(state)
  318. # Add waiting requests to second priority
  319. if not self.block_new_requests:
  320. for req_id in self.waiting_requests_order:
  321. second_priority_states.append(self.waiting_requests[req_id])
  322. candidates = priority_states + second_priority_states
  323. request_ids_to_remove_from_waiting = set()
  324. scheduled_requests, one_allocation_failed, decode_fast_path, num_q_tokens, max_kv_read = (
  325. self._process_candidates(
  326. candidates,
  327. token_budget,
  328. cache_budget,
  329. request_ids_to_remove_from_waiting,
  330. safety_margin=0.0,
  331. )
  332. )
  333. # We remove waiting requests before checking requests were scheduled, because there might have been prefill matches
  334. self._cleanup_waiting_queue(request_ids_to_remove_from_waiting)
  335. # If no requests were scheduled and the cache is full, we signal it by returning None
  336. if not scheduled_requests and one_allocation_failed:
  337. return None, decode_fast_path, 0, 0
  338. return scheduled_requests, decode_fast_path, num_q_tokens, max_kv_read
# Maps a scheduler name to the scheduler class implementing it.
SCHEDULER_MAPPING: dict[str, type[Scheduler]] = {
    "fifo": FIFOScheduler,
    "prefill_first": PrefillFirstScheduler,
}