| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- # Copyright 2026 The HuggingFace Inc. team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections import OrderedDict
- from math import ceil
- from typing import Any
- import torch
- from transformers.configuration_utils import PretrainedConfig
- from .requests import FutureRequestState, RequestState, RequestStatus, logger
- class CudaGraphBuffer:
- """A fixed-size dict for CUDA graphs with LRU eviction when full."""
- def __init__(self, max_size: int) -> None:
- if max_size <= 0:
- raise ValueError(f"max_size must be positive, but got {max_size}")
- self.max_size = max_size
- self._storage: OrderedDict[tuple[int, int], torch.cuda.CUDAGraph] = OrderedDict()
- def __del__(self) -> None:
- original_max_size = self.max_size
- self.max_size = 1 # 0 would cause an infinite loop, 1 is enough to clear all graphs
- self.plan_for_new_graph(silent=True)
- self.max_size = original_max_size
- def get_graph(self, q_len: int, kv_len: int) -> torch.cuda.CUDAGraph | None:
- graph = self._storage.get((q_len, kv_len))
- if graph is not None:
- self._storage.move_to_end((q_len, kv_len))
- return graph
- def plan_for_new_graph(self, silent: bool = False) -> None:
- while len(self._storage) >= self.max_size:
- evicted_key, evicted_graph = self._storage.popitem(last=False)
- if not silent:
- logger.info(f"Evicting graph for {evicted_key = }")
- evicted_graph.reset()
- def set_graph(self, q_len: int, kv_len: int, graph: torch.cuda.CUDAGraph) -> None:
- # In our use case, this should not have any effect because we plan for a new graph before it is captured
- self.plan_for_new_graph()
- logger.info(f"Setting graph for {q_len = }, {kv_len = }")
- self._storage[(q_len, kv_len)] = graph
def attn_mask_is_needed(config: PretrainedConfig) -> bool:
    """Tells whether the attention implementation selected in (config) needs an attention mask.

    This is the case only when `config._attn_implementation` is "paged|eager" or "paged|sdpa".
    """
    implementation = config._attn_implementation
    return implementation in ("paged|eager", "paged|sdpa")
def pad_to_interval(size: int, interval_size: int, max_value: int) -> int:
    """Return the smallest multiple of (interval_size) >= (size), capped at (max_value).

    Two special cases apply:
      - a non-positive (interval_size) disables padding: (max_value) is returned as is;
      - a non-positive (size) is padded to exactly one interval (still capped at (max_value)).
    """
    if interval_size <= 0:
        return max_value
    if size <= 0:
        return min(interval_size, max_value)
    # Integer ceil-division stays exact for any int, whereas ceil(size / interval_size) rounds through a float and
    # can return a value smaller than (size) for very large inputs
    num_intervals = -(-size // interval_size)
    return min(num_intervals * interval_size, max_value)
def aligned_divide(x: int, divide_by: int, align_to: int) -> int:
    """Return ceil((x) / (divide_by)), rounded up to the next multiple of (align_to).

    For instance, aligned_divide(10, 2, 4) first computes ceil(10 / 2) = 5, then rounds it up to 8.
    """
    # Integer ceil-division stays exact for any int, whereas ceil(x / divide_by) rounds through a float
    quotient = -(-x // divide_by)
    remainder = quotient % align_to
    if remainder:
        quotient += align_to - remainder
    return quotient
def build_attention_mask(
    attention_mask: torch.Tensor,
    cumulative_seqlens_q: list[int],
    cumulative_seqlens_k: list[int],
    sliding_window: int = 1,
) -> None:
    """Writes a block-diagonal attention mask into (attention_mask), in place.

    The mask is additive rather than boolean: a visible position holds 0 and a hidden position holds the most
    negative finite value of the mask's dtype, so it acts as a bias on the attention scores.

    Request number i owns one block: its rows span cumulative_seqlens_q[i]:cumulative_seqlens_q[i + 1] and its
    columns span cumulative_seqlens_k[i]:cumulative_seqlens_k[i + 1] (both taken on the last two dimensions).
    Entries that belong to no block are left untouched.

    Inside a block, two masks are summed:
      - a causal mask. When the request has fewer queries than keys (and at least one query), its diagonal is
        shifted by seqlen_k - seqlen_q so that the last query sees every key. Otherwise the regular causal diagonal
        is used.
      - a sliding window mask, only if (sliding_window) is greater than 1. It hides every key on or below the
        diagonal of offset seqlen_k - seqlen_q - sliding_window.

    In the examples below, █ is a visible position and ░ is a hidden one.

    With seqlen_k = 8, seqlen_q = 4 and sliding_window = 6 (sliding window offset = 8 - 4 - 6 = -2):

        CAUSAL MASK            SLIDING WINDOW MASK    ATTENTION MASK (sum of both)
        █ █ █ █ █ ░ ░ ░        █ █ █ █ █ █ █ █        █ █ █ █ █ ░ ░ ░
        █ █ █ █ █ █ ░ ░        █ █ █ █ █ █ █ █        █ █ █ █ █ █ ░ ░
        █ █ █ █ █ █ █ ░        ░ █ █ █ █ █ █ █        ░ █ █ █ █ █ █ ░
        █ █ █ █ █ █ █ █        ░ ░ █ █ █ █ █ █        ░ ░ █ █ █ █ █ █

    With seqlen_k = 5, seqlen_q = 3 and sliding_window = 2 (sliding window offset = 5 - 3 - 2 = 0):

        CAUSAL MASK            SLIDING WINDOW MASK    ATTENTION MASK (sum of both)
        █ █ █ ░ ░              ░ █ █ █ █              ░ █ █ ░ ░
        █ █ █ █ ░              ░ ░ █ █ █              ░ ░ █ █ ░
        █ █ █ █ █              ░ ░ ░ █ █              ░ ░ ░ █ █
    """
    fill_value = torch.finfo(attention_mask.dtype).min
    use_sliding_window = sliding_window > 1

    for request_idx in range(len(cumulative_seqlens_q) - 1):
        q_start, q_end = cumulative_seqlens_q[request_idx], cumulative_seqlens_q[request_idx + 1]
        k_start, k_end = cumulative_seqlens_k[request_idx], cumulative_seqlens_k[request_idx + 1]
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start
        rows, cols = slice(q_start, q_end), slice(k_start, k_end)

        # Start from a fully hidden block with the same dtype and device as the mask
        block_shape = attention_mask[..., rows, cols].shape
        fully_hidden = attention_mask.new_full(block_shape, fill_value)

        # Causal part: keep the hidden values on and above the causal diagonal, zero elsewhere
        causal_offset = seqlen_k - seqlen_q + 1 if 1 <= seqlen_q < seqlen_k else 1
        block = torch.triu(fully_hidden, diagonal=causal_offset)

        # Sliding window part: additionally hide the keys on and below the sliding window diagonal
        if use_sliding_window:
            block += torch.tril(fully_hidden, diagonal=seqlen_k - seqlen_q - sliding_window)

        # Write the finished block back into the mask
        attention_mask[..., rows, cols] = block
def create_warmup_future_states(
    num: int,
    status: RequestStatus,
    num_query_tokens: int,
    num_cache_tokens: int,
    cache: Any,  # not annotated to avoid circular import
) -> list[FutureRequestState]:
    """A utility function that creates a list of up to (num) FutureRequestStates for the warmup of CB.

    Each state wraps a dummy request made of (num_query_tokens) + (num_cache_tokens) zero tokens, of which
    (num_query_tokens) are left to process, with a position offset of (num_cache_tokens) and the given (status).
    Enough blocks to hold all of a request's tokens are allocated in (cache) for each request. As soon as one
    allocation returns None, the states created so far are returned, so the list may be shorter than (num).
    """
    warmup_ids = [f"__warmup_{status.name}_{i}__" for i in range(num)]
    sequence_length = num_query_tokens + num_cache_tokens
    num_blocks = ceil(sequence_length / cache.block_size)

    warmup_states = []
    for warmup_id in warmup_ids:
        request = RequestState(
            request_id=warmup_id,
            initial_tokens=[0] * sequence_length,
            max_new_tokens=1,
        )
        # Written directly to the private attribute: going through the property setter would trigger the lifecycle
        # side effects
        request._status = status
        request.tokens_to_process = [0] * num_query_tokens
        request.position_offset = num_cache_tokens

        # A failed allocation ends the warmup early
        if cache.allocate_blocks(num_blocks, request.request_id, 0) is None:
            break
        warmup_states.append(FutureRequestState(request, has_new_token=True, complete_blocks=0))
    return warmup_states
|