| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- # Copyright 2025 The HuggingFace Inc. team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import threading
- from abc import ABC, abstractmethod
- from collections import deque
- from ...utils.metrics import attach_tracer, traced
- from .cache import PagedAttentionCache
- from .requests import FutureRequestState, RequestState, RequestStatus, logger
class Scheduler(ABC):
    """
    Abstract base class for scheduling requests in the continuous batch processor. Schedulers manage the lifecycle of
    requests from when they are added to the waiting queue to when they are scheduled for processing. Different
    schedulers implement different strategies for prioritizing and batching requests.

    Subclasses only have to implement `schedule_batch`; the budget checks, block allocation and state transitions
    shared by all strategies live in `_process_candidates` and its helpers.
    """

    def __init__(self, cache: PagedAttentionCache):
        """Binds the scheduler to `cache`, derives the cache-dependent limits and initializes the mutable state."""
        self.cache = cache
        # Held whenever `_requests_to_cancel` is updated or drained
        self._cancellation_lock = threading.Lock()
        # This is to compute the read cache used by a new request being scheduled: no limit if the cache has at least
        # one full attention group, otherwise the sliding window size
        self.read_cache_limit = None if self.cache.num_full_attention_groups else self.cache.config.sliding_window
        # A request whose position offset reaches this many tokens can no longer use the decode fast path
        self.max_decode_fast_path_length = self.cache.max_blocks_per_request * self.cache.block_size
        # Initialize mutable states via reset()
        self.reset()

    def reset(self) -> None:
        """Reset scheduler state for a new generation loop."""
        # Requests that have been scheduled at least once (or matched a prefix), keyed by request id
        self.active_requests: dict[str, RequestState] = {}
        # Requests that have not been scheduled yet, keyed by request id
        self.waiting_requests: dict[str, RequestState] = {}
        # Ids of the waiting requests, in arrival order
        self.waiting_requests_order: deque[str] = deque()
        self._requests_to_cancel: set[str] = set()
        # Requests with children, collected by `_schedule_request`
        self._requests_to_fork: list[RequestState] = []
        # When True, the schedulers defined in this module stop adding waiting requests to their candidates
        self.block_new_requests = False

    @traced
    def add_waiting_request(self, state: RequestState):
        """Adds a request to the waiting list."""
        self.waiting_requests[state.request_id] = state
        self.waiting_requests_order.append(state.request_id)

    @abstractmethod
    def schedule_batch(
        self, token_budget: int, cache_budget: int
    ) -> tuple[list[FutureRequestState] | None, bool, int, int]:
        """Schedules requests for the next batch based on available token and cache budgets. This method selects which
        requests should be processed in the current batch, considering the budgets and the scheduler's prioritization
        rules. The token_budget is the maximum number of tokens that can be processed in a batch, and the cache_budget
        is the maximum number of KV cache entries that can be read in a batch.
        Returns the list of scheduled requests in their "FutureRequestState" form, a boolean indicating if the decode
        fast path can be used, the total number of query tokens and the maximum number of kv tokens read."""

    @traced
    def has_pending_requests(self) -> bool:
        """Checks if there are requests ready to be processed, i.e. if any request is active or waiting."""
        return bool(self.active_requests or self.waiting_requests)

    @traced
    def finish_request(self, request_id: str) -> None:
        """Completes processing of a request and frees its allocated cache blocks. This method is called
        when a request has finished generation or encountered an error.
        """
        self.cache.free_blocks(request_id)
        self.active_requests.pop(request_id, None)

    @traced
    def get_active_request_static_outputs(self, request_id: str) -> list[int]:
        """Gets generated tokens for an active request, or an empty list if the request is not active."""
        state = self.active_requests.get(request_id)
        if state is None:
            return []
        return state.generated_tokens

    @traced
    def set_request_cancellation(self, request_id: str):
        """Marks a request for cancellation. The request is only removed by `clear_cancelled_requests`."""
        with self._cancellation_lock:
            self._requests_to_cancel.add(request_id)

    @traced
    def clear_cancelled_requests(self):
        """Remove all cancelled requests from active and waiting queues, and free their cache blocks."""
        with self._cancellation_lock:
            for request_id in self._requests_to_cancel:
                self.active_requests.pop(request_id, None)
                self.waiting_requests.pop(request_id, None)
                # One scan of the deque instead of a membership test followed by a removal; a request that is not
                # waiting simply has nothing to remove
                try:
                    self.waiting_requests_order.remove(request_id)
                except ValueError:
                    pass
                self.cache.free_blocks(request_id)
            self._requests_to_cancel = set()

    @traced
    def request_is_cancelled(self, request_id: str) -> bool:
        """Checks if a request has been cancelled or removed: it is either marked for cancellation, or it is neither
        active nor waiting."""
        return request_id in self._requests_to_cancel or (
            request_id not in self.active_requests and request_id not in self.waiting_requests
        )

    @traced
    def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
        """Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
        accommodate the next tokens. It calculates how many blocks are needed based on the request's current
        cache occupancy and the number of tokens to be processed. The allocation itself is delegated to
        `self.cache.allocate_blocks`. Returns a boolean indicating if the allocation was successful or not; when no
        allocation is needed, the method returns True.
        """
        # 1. we check that the free room in the allocated blocks is less than the requested length
        # 2. we allocate enough blocks to cover the requested length
        current_len = state.current_len()
        free_slots = state.allocated_blocks * self.cache.block_size - current_len
        if free_slots < len_next_tokens or state.allocated_blocks == 0:
            # NOTE(review): for block sizes > 1 this can request one block more than ceil(missing / block_size), e.g.
            # when the number of missing slots is an exact multiple of the block size. Kept as-is -- confirm whether
            # the extra block is intentional.
            blocks_needed = ((len_next_tokens - free_slots + 1) // self.cache.block_size) + 1
            allocated = self.cache.allocate_blocks(blocks_needed, state.request_id, state.allocated_blocks)
            # A None result is treated as a failed allocation
            if allocated is None:
                return False
            state.allocated_blocks += allocated
        return True

    def _infer_request_tokens(self, state: RequestState, request_ids_to_remove_from_waiting: set[str]) -> list[int]:
        """Prepares a request for processing in the current batch. If prefix sharing is enabled, and the request was
        pending, this is where we look for a prefix match and split the request if found.

        Returns the tokens the request would contribute to the batch if the token budget were large enough. On a
        prefix match, `state` is mutated in place (status, allocated blocks, remaining prefill tokens and position
        offset) and its id is added to `request_ids_to_remove_from_waiting`."""
        # If prefix sharing is enabled, we look for a prefix match and split the request if found
        if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
            prefill_length = self.cache.search_prefix_match(state.request_id, state.remaining_prefill_tokens)
            if prefill_length > 0:
                self.active_requests[state.request_id] = state
                request_ids_to_remove_from_waiting.add(state.request_id)
                state.status = RequestStatus.PREFILLING
                # We keep track of the number of allocated blocks to avoid double allocation
                state.allocated_blocks += prefill_length // self.cache.block_size
                # Even if we match the whole request, we keep at least 1 token to start decoding
                prefill_length = min(prefill_length, len(state.remaining_prefill_tokens) - 1)
                state.remaining_prefill_tokens = state.remaining_prefill_tokens[prefill_length:]
                state.position_offset += prefill_length

        # If the request is decoding, the tokens to process are already set
        if state.status == RequestStatus.DECODING:
            request_tokens = state.tokens_to_process
        # Otherwise, the tokens to process are the remaining prefill tokens
        else:
            request_tokens = state.remaining_prefill_tokens
        return request_tokens

    def _schedule_request(
        self,
        state: RequestState,
        request_tokens: list[int],
        token_budget: int,
        request_ids_to_remove_from_waiting: set[str],
    ) -> None:
        """Schedules a request for the current batch, updating the request's status according to the token budget left.
        After a request is scheduled, it is part of the next batch unless there is an error.
        If the request has children (for parallel decoding), it ensures at least one token remains before the request is
        forked.

        The method returns nothing: `state.tokens_to_process`, `state.remaining_prefill_tokens` and `state.status` are
        updated in place, and pending requests are registered as active and marked for removal from the waiting
        queue."""
        # If the request has one or more children we make sure not to prefill it entirely
        # This does not check the request state, but DECODING request already have children set to 0.
        if state.num_children > 0 and token_budget >= len(request_tokens) - 1:
            token_budget = len(request_tokens) - 1
            self._requests_to_fork.append(state)

        # Case: we can process the entire prompt/remainder
        if len(request_tokens) <= token_budget:
            if state.status == RequestStatus.PENDING:
                self.active_requests[state.request_id] = state
                request_ids_to_remove_from_waiting.add(state.request_id)
            # NOTE(review): relies on RequestStatus being ordered with PENDING and PREFILLING before DECODING -- confirm
            if state.status <= RequestStatus.PREFILLING:
                state.tokens_to_process = state.remaining_prefill_tokens
                state.remaining_prefill_tokens = []
                # Although prefill will only be done after the batch being scheduled now, we set the status to DECODING
                # to stay coherent when using asynchronous batching
                state.status = RequestStatus.DECODING
        # Otherwise: we need to split the request
        else:
            if state.status == RequestStatus.PENDING:
                self.active_requests[state.request_id] = state
                state.status = RequestStatus.PREFILLING
                request_ids_to_remove_from_waiting.add(state.request_id)
            state.remaining_prefill_tokens = request_tokens[token_budget:]
            state.tokens_to_process = request_tokens[:token_budget]

    def _process_candidates(
        self,
        candidates: list[RequestState],
        token_budget: int,
        cache_budget: int,
        request_ids_to_remove_from_waiting: set[str],
        safety_margin: float = 0.0,
    ) -> tuple[list[FutureRequestState], bool, bool, int, int]:
        """Schedules candidate requests for the current batch.
        This method contains the common logic shared by all schedulers: it checks token and cache budgets, allocates
        cache blocks if needed, updates request states, and tracks which waiting requests should be removed from the
        waiting queue.

        Candidates are visited in the given order. `safety_margin` is a fraction of `self.cache.num_blocks`: once the
        number of free blocks drops below it, only decoding requests (or the first request of the batch) are accepted.

        Returns a tuple with: the scheduled requests as `FutureRequestState` objects, whether at least one block
        allocation failed, whether the decode fast path can be used, the number of query tokens scheduled and the
        amount of cache budget consumed.
        """
        scheduled_requests = []
        one_allocation_failed = False
        decode_fast_path = True
        safety_margins = safety_margin * self.cache.num_blocks
        original_token_budget, original_cache_budget = token_budget, cache_budget

        for state in candidates:
            num_free_blocks = self.cache.get_num_free_blocks()
            # If we are out the safety margin, we only accept decoding requests or the first prefill request
            outside_safety_margin = num_free_blocks < safety_margins
            if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING:
                logger.info(
                    f"Outside safety margin, breaking out of scheduling loop. {num_free_blocks = } {safety_margins = }"
                )
                break

            # Check cache budget
            read_cache_needed = state.current_len()
            if self.read_cache_limit is not None:
                read_cache_needed = min(read_cache_needed, self.read_cache_limit)
            if cache_budget < read_cache_needed:
                continue

            # Infer the tokens that will be present in the batch if token budget is enough
            request_tokens = self._infer_request_tokens(state, request_ids_to_remove_from_waiting)
            # Account for token budget
            request_len = min(len(request_tokens), token_budget)
            # Check there will be enough cache for the new tokens
            allocation_successful = self._allocate_blocks_if_needed(state, request_len)

            # If the allocation would not be successful, we move on to the next request
            if not allocation_successful:
                one_allocation_failed = True
                # If we reached a waiting request and the cache is full, all subsequent waiting requests will need
                # allocation as well, so we can safely break out of the scheduling loop.
                if num_free_blocks == 0 and state.request_id in self.waiting_requests:
                    logger.info(f"Breaking mid-loop for request {state.request_id} because the cache is full")
                    break
                continue

            # If this point is reached, it means we can safely schedule the request
            self._schedule_request(state, request_tokens, token_budget, request_ids_to_remove_from_waiting)
            request_len = len(state.tokens_to_process)  # it may change after scheduling
            # The decode fast path is only used if every scheduled request contributes a single token and its position
            # offset is below the maximum length covered by the blocks of one request
            decode_fast_path &= request_len == 1 and state.position_offset < self.max_decode_fast_path_length

            # Update the token and cache budgets
            token_budget -= request_len
            cache_budget -= read_cache_needed

            # If block sharing is allowed, we make note of the blocks that will be computed in the forward pass
            if self.cache.allow_block_sharing:
                tokens_in_current_block = state.current_len() % self.cache.block_size
                tokens_after_forward = tokens_in_current_block + request_len
                complete_blocks = tokens_after_forward // self.cache.block_size
            else:
                complete_blocks = 0

            # Store the future request state: a new token is produced only once no prefill tokens remain
            has_new_token = not state.remaining_prefill_tokens
            scheduled_requests.append(FutureRequestState(state, has_new_token, complete_blocks))

            # Remove the request from the waiting queue and mark it as removed
            req_id = state.request_id
            was_waiting = self.waiting_requests.pop(req_id, None) is not None
            if was_waiting:
                request_ids_to_remove_from_waiting.add(req_id)

            # Early exit of the loop if we have no budget left
            if token_budget == 0 or cache_budget == 0:
                break

        num_q_tokens = original_token_budget - token_budget
        max_kv_read = original_cache_budget - cache_budget
        return scheduled_requests, one_allocation_failed, decode_fast_path, num_q_tokens, max_kv_read

    def _cleanup_waiting_queue(self, request_ids_to_remove_from_waiting: set[str]) -> None:
        """Removes processed requests from the waiting queue order, keeping the relative order of the others."""
        # Nothing was scheduled out of the waiting queue: no need to rebuild it
        if not request_ids_to_remove_from_waiting:
            return
        self.waiting_requests_order = deque(
            req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting
        )
# TODO: further common-ize the two classes
@attach_tracer()
class FIFOScheduler(Scheduler):
    """Scheduler that serves active decoding requests first, then active prefilling requests, then waiting requests
    in their arrival order. It also applies a safety margin on the cache to avoid exhausting it: with the default
    value, once less than 20% of the cache blocks are free, requests that are not decoding stop being added to a batch
    that already contains at least one request."""

    def __init__(self, cache: PagedAttentionCache, safety_margin: float = 0.2):
        """Initializes the FIFO scheduler. `safety_margin` is the fraction of free blocks under which we stop
        scheduling new prefill requests: safety_margin = 0.1 means that when less than 10% of the blocks are free
        (i.e. more than 90% are allocated), new prefill requests are no longer scheduled.
        """
        super().__init__(cache)
        self.safety_margin = safety_margin

    @traced
    def schedule_batch(
        self, token_budget: int, cache_budget: int
    ) -> tuple[list[FutureRequestState] | None, bool, int, int]:
        """Builds the next batch, giving priority to decoding requests. Returns the scheduled requests (or None if
        nothing could be scheduled and at least one block allocation failed), whether the decode fast path can be
        used, the number of query tokens and the amount of cache budget consumed."""
        active_states = list(self.active_requests.values())
        decoding_states = [s for s in active_states if s.status == RequestStatus.DECODING]
        prefilling_states = [s for s in active_states if s.status == RequestStatus.PREFILLING]

        # Waiting requests come last, in arrival order, unless new requests are blocked
        if self.block_new_requests:
            waiting_states = []
        else:
            waiting_states = [self.waiting_requests[req_id] for req_id in self.waiting_requests_order]

        ordered_candidates = [*decoding_states, *prefilling_states, *waiting_states]
        ids_leaving_waiting_queue: set[str] = set()
        outcome = self._process_candidates(
            ordered_candidates,
            token_budget,
            cache_budget,
            ids_leaving_waiting_queue,
            safety_margin=self.safety_margin,
        )
        scheduled, allocation_failed, fast_path, query_tokens, kv_read = outcome

        # The waiting queue is cleaned up even if nothing was scheduled, because prefix matches may have moved
        # requests out of it
        self._cleanup_waiting_queue(ids_leaving_waiting_queue)

        # An empty batch caused by a failed allocation is signalled by returning None instead of a list
        if allocation_failed and not scheduled:
            return None, fast_path, 0, 0
        return scheduled, fast_path, query_tokens, kv_read
# FIXME: prioritize adding from waiting reqs before scheduling `RequestStatus.DECODING` when cache space allows it
# TODO: further consolidate the code by making more of it common. The reference Scheduler is FIFO, not this one.
@attach_tracer()
class PrefillFirstScheduler(Scheduler):
    """Scheduler that serves active requests still prefilling (prompts that were only partially processed) before
    active decoding requests. Waiting requests come last, in their arrival order. No cache safety margin is
    applied."""

    @traced
    def schedule_batch(
        self, token_budget: int, cache_budget: int
    ) -> tuple[list[FutureRequestState] | None, bool, int, int]:
        """Builds the next batch, giving priority to prefilling requests. Returns the scheduled requests (or None if
        nothing could be scheduled and at least one block allocation failed), whether the decode fast path can be
        used, the number of query tokens and the amount of cache budget consumed."""
        active_states = list(self.active_requests.values())
        # Active requests that are still PREFILLING go first, whatever the reason they were not completed
        prefilling_states = [s for s in active_states if s.status == RequestStatus.PREFILLING]
        decoding_states = [s for s in active_states if s.status == RequestStatus.DECODING]

        # Waiting requests come last, in arrival order, unless new requests are blocked
        if self.block_new_requests:
            waiting_states = []
        else:
            waiting_states = [self.waiting_requests[req_id] for req_id in self.waiting_requests_order]

        ordered_candidates = [*prefilling_states, *decoding_states, *waiting_states]
        ids_leaving_waiting_queue: set[str] = set()
        outcome = self._process_candidates(
            ordered_candidates,
            token_budget,
            cache_budget,
            ids_leaving_waiting_queue,
            safety_margin=0.0,
        )
        scheduled, allocation_failed, fast_path, query_tokens, kv_read = outcome

        # The waiting queue is cleaned up even if nothing was scheduled, because prefix matches may have moved
        # requests out of it
        self._cleanup_waiting_queue(ids_leaving_waiting_queue)

        # An empty batch caused by a failed allocation is signalled by returning None instead of a list
        if allocation_failed and not scheduled:
            return None, fast_path, 0, 0
        return scheduled, fast_path, query_tokens, kv_read
# Registry of the scheduler classes defined in this module, keyed by their short name
SCHEDULER_MAPPING: dict[str, type[Scheduler]] = {
    "fifo": FIFOScheduler,
    "prefill_first": PrefillFirstScheduler,
}
|