| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761 |
- # Copyright 2025 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import inspect
- from math import floor, gcd, sqrt
- from typing import Any
- import torch
- from ...configuration_utils import PreTrainedConfig
- from ...generation.configuration_utils import ContinuousBatchingConfig
- from ...utils.generic import is_flash_attention_requested
- from ...utils.metrics import attach_tracer, traced
- from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
- from .requests import RequestState, RequestStatus, get_device_and_memory_breakdown, logger
def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
    """
    Group layers depending on the attention mix, according to VLLM's hybrid allocator rules:
        - Layers in each group need to have the same type of attention
        - All groups have the same number of layers

    For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
    we get four groups of two layers each: [0, 3], [1, 2], [4, 5] and [6, 7].

    Args:
        config: Model configuration. Its `layer_types` attribute is used when present and not None. Otherwise all
            `config.num_hidden_layers` layers are given a single type: "sliding_attention" if `config.sliding_window`
            is set, "full_attention" if not.

    Returns:
        A tuple `(layer_groups, group_types)`: `layer_groups[k]` is the list of layer indices in group `k` and
        `group_types[k]` is the attention type shared by these layers. Groups of the same type are contiguous and
        types appear in the order in which they are first met in the layer list.
    """
    # If the config has no layer_types attribute, it means all layers have the same attention type
    layer_types = getattr(config, "layer_types", None)
    if layer_types is None:
        attn_type = "sliding_attention" if getattr(config, "sliding_window", None) is not None else "full_attention"
        layer_types = [attn_type] * config.num_hidden_layers

    # Collect the layer indices of each attention type. Appending in place keeps this linear in the number of
    # layers, whereas rebuilding the list for each layer would copy it every time.
    indices_per_type: dict[str, list[int]] = {}
    for layer_idx, layer_type in enumerate(layer_types):
        indices_per_type.setdefault(layer_type, []).append(layer_idx)

    # The common group size is the greatest common divisor of the per-type layer counts. With no layer at all,
    # gcd() is 0, but then there is nothing to iterate over below and two empty lists are returned.
    group_size = gcd(*(len(indices) for indices in indices_per_type.values()))

    # Slice the indices of each type into chunks of group_size layers, noting the type of each chunk as we go
    layer_groups: list[list[int]] = []
    group_types: list[str] = []
    for layer_type, indices in indices_per_type.items():
        for start in range(0, len(indices), group_size):
            layer_groups.append(indices[start : start + group_size])
            group_types.append(layer_type)
    return layer_groups, group_types
@attach_tracer()
class PagedAttentionCache:
    """
    Paged KV cache, inspired by VLLM's hybrid allocator. Layers are bundled into groups, which keeps cache management
    simple and limits fragmentation.

    The cache is organised as a three-level hierarchy:

    - Pages: the smallest unit of cache. A page has a size of [num_heads, head_size], which is the space needed to
        store the key or the value states of one token for one layer. For a model with only full-attention layers,
        storing the KV cache of one token takes `2 * num_layers` pages: `num_layers` for the keys and as many for
        the values. Pages are grouped into blocks.
    - Blocks: a block is a collection of `block_size` pages. It is the allocation unit: cache is allocated and freed
        block by block, not page by page. One block is allocated to one layer group, which only has one attention
        type, like full-attention or sliding-attention. If all layers of the model have the same attention type,
        they all end up in the same group. There is more than one group if and only if the model mixes attention
        types, like layers with full-attention and layers with sliding-attention.
    - Cache tensors: the physical supports of the cache. There are as many cache tensors as there are layers in a
        layer group, and the shape of a cache tensor is `[num_blocks * block_size, num_heads, head_size]`.

    Grouping layers pays off because a block allocated to group N is the same block for every layer of group N, i.e.
    it is allocated across all cache tensors at once. This makes allocating and freeing blocks cheap, and it makes
    reading and writing key and value states efficient.

    For instance, take 8 blocks of cache and a model with two layer groups: a full-attention group with 3 layers and
    a sliding-attention group with 3 layers. At creation time, the physical cache tensors look like this:

    cache_tensor_0: □ □ □ □ □ □ □ □
    cache_tensor_1: □ □ □ □ □ □ □ □
    cache_tensor_2: □ □ □ □ □ □ □ □

    where □ means the block is not allocated to any layer group yet. There are 3 cache tensors because there are
    3 layers per group.
    After allocating 1 block to each group, the cache tensors look like this:

    cache_tensor_0: ✖ ◉ □ □ □ □ □ □
    cache_tensor_1: ✖ ◉ □ □ □ □ □ □
    cache_tensor_2: ✖ ◉ □ □ □ □ □ □

    where ✖ means the block is allocated to the full-attention group, and ◉ means the block is allocated to the
    sliding-attention group.
    If generation continues and the sliding window has been reached, only the full-attention group needs a new
    block, and the cache tensors look like this:

    cache_tensor_0: ✖ ◉ ✖ □ □ □ □ □
    cache_tensor_1: ✖ ◉ ✖ □ □ □ □ □
    cache_tensor_2: ✖ ◉ ✖ □ □ □ □ □

    And after further generation, when yet another block is needed:

    cache_tensor_0: ✖ ◉ ✖ ✖ □ □ □ □
    cache_tensor_1: ✖ ◉ ✖ ✖ □ □ □ □
    cache_tensor_2: ✖ ◉ ✖ ✖ □ □ □ □

    This would not have been possible with all layers in a single group: a new block would also have been allocated
    for the sliding-attention layers, although they do not need it.
    """

    def __init__(
        self,
        config: PreTrainedConfig,
        continuous_batching_config: ContinuousBatchingConfig,
        device: torch.device | str,
        dtype: torch.dtype = torch.float16,
        tp_size: int | None = None,
    ) -> None:
        """Build the paged attention cache: group the layers, size the cache, allocate the cache tensors and create
        one allocator per layer group. Prefix sharing is turned on only when block sharing is allowed and the model
        has a single full-attention layer group.

        Args:
            config: Model configuration
            continuous_batching_config: Continuous batching configuration containing cache parameters
            device: Device for the cache tensors
            dtype: Data type of the cache
            tp_size: Tensor parallelism size

        Raises:
            ValueError: if the block size is not positive, if the number of key-value heads is not divisible by
                `tp_size` (when `tp_size > 1`), or if a layer group has an unsupported attention type.
        """
        self.config = config
        self.dtype = dtype
        self.device = device

        # Model dimensions, with fallbacks for configs that do not define the KV-specific attributes
        num_kv_heads = getattr(config, "num_key_value_heads", None)
        if num_kv_heads is None:
            num_kv_heads = config.num_attention_heads
        self.num_key_value_heads: int = num_kv_heads
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim: int = head_dim

        # Cache dimensions. Default used to be 32, now it's 256 to be compatible with flash_with_kvcache.
        self.block_size = continuous_batching_config.block_size
        if self.block_size <= 0:
            raise ValueError(f"Block size must be positive, but got {self.block_size}")

        # Group layers depending on the attention mix, and remember for each layer where it sits (group index and
        # index inside the group) as well as its sliding window. A window of 1 stands for full attention.
        layer_groups, group_types = group_layers_by_attn_type(config)
        group_size = len(layer_groups[0])
        self.num_groups = len(layer_groups)
        self.sliding_windows = {}
        self.layer_index_to_group_indices = {}
        for group_idx, group in enumerate(layer_groups):
            window = config.sliding_window if group_types[group_idx] == "sliding_attention" else 1
            for idx_in_group, layer in enumerate(group):
                self.layer_index_to_group_indices[layer] = (group_idx, idx_in_group)
                self.sliding_windows[layer] = window

        # Handle TP (or don't): only the divisibility is checked, the number of heads is left untouched.
        # `self.num_key_value_heads //= tp_size` is disabled on purpose for now. TODO: find out why.
        if tp_size is not None and tp_size > 1 and self.num_key_value_heads % tp_size != 0:
            raise ValueError(
                f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
            )

        # Infer number of blocks and max batch tokens
        page_size = self.head_dim * self.num_key_value_heads
        if is_flash_attention_requested(self.config):
            num_attention_masks = 0  # only used to compute the default memory footprint args
        elif "sliding_attention" in group_types:
            # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
            num_attention_masks = 2
        else:
            num_attention_masks = 1

        memory_handler = PagedAttentionMemoryHandler(
            block_size=self.block_size,
            page_size=page_size,
            num_groups=self.num_groups,
            group_size=group_size,
            peak_activation_per_token=(config.hidden_size + config.vocab_size),
            num_attention_masks=num_attention_masks,
        )
        num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
            num_blocks=continuous_batching_config.num_blocks,
            max_batch_tokens=continuous_batching_config.max_batch_tokens,
            max_memory_percent=continuous_batching_config.max_memory_percent,
            cache_dtype=self.dtype,
        )

        # Store the inferred sizes on the instance
        self.num_blocks = num_blocks
        self.max_batch_tokens = max_batch_tokens
        self.num_pages = self.num_blocks * self.block_size
        logger.info(
            f"PagedAttentionCache initialized with {self.num_blocks = }, {self.block_size = }, {page_size = }, "
            f"{self.max_batch_tokens = } {num_attention_masks = }"
        )

        # When max_blocks_per_request is not set in the config, 0 is used.
        # TODO: log the value picked here once we have good defaults.
        max_blocks_per_request = continuous_batching_config.max_blocks_per_request
        self.max_blocks_per_request = 0 if max_blocks_per_request is None else max_blocks_per_request

        # Allocate the cache: one key tensor and one value tensor per layer of a group. Each tensor has room for
        # two blocks' worth of rows on top of the `num_blocks` managed blocks; the extra room is there to handle
        # padding and generally discard unwanted tokens.
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim)
        for _ in range(group_size):
            for cache_list in (self.key_cache, self.value_cache):
                layer_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
                torch._dynamo.mark_static_address(layer_cache)
                cache_list.append(layer_cache)
        logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")

        # Block management data structures: one allocator per layer group, picked from the group's attention type
        self.allow_block_sharing = continuous_batching_config.allow_block_sharing
        self.group_cache_managers: list[CacheAllocator] = []
        self.num_full_attention_groups = 0
        self.num_sliding_attention_groups = 0
        self.max_sliding_window_blocks_per_request = 0
        for group_idx, group_type in enumerate(group_types):
            if group_type == "full_attention":
                allocator = FullAttentionCacheAllocator(
                    group_idx, self.block_size, allow_block_sharing=self.allow_block_sharing
                )
                self.num_full_attention_groups += 1
            elif group_type == "sliding_attention":
                allocator = SlidingAttentionCacheAllocator(group_idx, self.block_size, config.sliding_window)
                self.num_sliding_attention_groups += 1
                self.max_sliding_window_blocks_per_request = allocator._max_blocks_per_request
            else:
                raise ValueError(f"Invalid group type: {group_type}")
            self.group_cache_managers.append(allocator)

        # We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
        self.use_prefix_sharing = self.allow_block_sharing and group_types == ["full_attention"]
        self._block_manager = BlockManager(num_blocks, self.block_size)
        self._total_prefix_length: int = 0  # a counter to measure the impact of prefix sharing, also used in tests

        # For block table support, the name of the block table key is lazily initialized
        self._block_table_key = None

    def will_allocation_be_successful(self, num_requested_blocks: int, allocated_blocks: int) -> bool:
        """Tell whether allocating (num_requested_blocks) blocks to a request would succeed, given the number of free
        blocks. The number of blocks that would actually be taken is predicted as follows:
        - full attention groups: there is no sliding window, so one requested block is one newly allocated block for
            EACH full attention group
        - sliding window groups: the number of blocks a request can hold is capped by the sliding window. From the
            number of blocks the request already holds (allocated_blocks) we get the number of blocks that can still
            be given to it, which may be lower than the number of requested blocks. That number is the same for all
            sliding window groups, as only one sliding window size is supported.
        """
        # Not guarded by a branch: models without any full attention layer are very rare
        total_needed = num_requested_blocks * self.num_full_attention_groups
        # Only relevant for models that have sliding window attention layers
        if self.num_sliding_attention_groups:
            remaining_window_blocks = max(0, self.max_sliding_window_blocks_per_request - allocated_blocks)
            blocks_per_sliding_group = min(num_requested_blocks, remaining_window_blocks)
            total_needed += blocks_per_sliding_group * self.num_sliding_attention_groups
        return self.get_num_free_blocks() >= total_needed

    @traced
    def allocate_blocks(self, n_blocks: int, request_id: str, allocated_blocks: int) -> int | None:
        """Allocate cache blocks to a request across all layer groups. The allocation itself is delegated to the
        cache managers; this method returns the largest number of blocks allocated by any one of them, or None if
        the allocation was predicted to fail (in which case nothing is allocated)."""
        # Check feasibility up front, to avoid partial allocations
        if not self.will_allocation_be_successful(n_blocks, allocated_blocks):
            return None
        # Ask every cache manager for its blocks and keep track of the largest allocation
        max_allocated = 0
        for manager in self.group_cache_managers:
            num_allocated_blocks = manager.allocate_blocks(n_blocks, request_id, self._block_manager)
            if num_allocated_blocks is None:
                raise ValueError(f"Failed to allocate {n_blocks} blocks for request {request_id}")
            if num_allocated_blocks > max_allocated:
                max_allocated = num_allocated_blocks
        return max_allocated

    @traced
    def free_blocks(self, request_id: str) -> None:
        """Release the cache blocks held by a request in every layer group. The deallocation itself is delegated to
        the cache managers."""
        for manager in self.group_cache_managers:
            manager.free_blocks(request_id, self._block_manager)

    def get_num_free_blocks(self) -> int:
        """Return the number of free blocks, as reported by the block manager."""
        return self._block_manager.num_free_blocks

    @traced
    def extend_read_and_write_indices(
        self,
        request_id: str,
        past_length: int,
        query_length: int,
        read_index: list[list[int]],
        write_index: list[list[int]],
    ) -> None:
        """Append, for each layer group, the physical cache indices a request reads from and writes to. The lists
        in (read_index) and (write_index) are matched by position with the cache managers and extended in place."""
        for manager, group_read, group_write in zip(self.group_cache_managers, read_index, write_index):
            group_read.extend(manager.get_read_indices(request_id, past_length, query_length))
            group_write.extend(manager.get_write_indices(request_id, past_length, query_length))

    def fill_block_table(
        self, request_id: str, past_length: int, query_length: int, block_table: torch.Tensor
    ) -> None:
        """Let each cache manager fill its own entry of (block_table): the manager at position `i` is handed
        `block_table[i]`."""
        for group_idx, manager in enumerate(self.group_cache_managers):
            manager.fill_block_table(request_id, past_length, query_length, block_table[group_idx])

    @traced
    def get_seqlens_k(self, past_length: int, query_length: int) -> dict[str, int]:
        """Compute the key sequence length for each attention type present in the model. Returns a dictionary mapping
        "full_attention" and / or "sliding_attention" to the corresponding key sequence length."""
        lengths = {}
        if self.num_full_attention_groups > 0:
            lengths["full_attention"] = past_length + query_length
        if self.num_sliding_attention_groups > 0:
            # At most `sliding_window - 1` past tokens are counted for sliding attention
            visible_past = min(past_length, self.config.sliding_window - 1)
            lengths["sliding_attention"] = query_length + visible_past
        # NOTE: when we add more attention types / different sliding windows, we can go back to looping over CMs
        return lengths

    @traced
    def update(
        self,
        key_states: torch.Tensor,  # shape [1, num_kv_heads, seqlen_kv, head_dim]
        value_states: torch.Tensor,  # shape [1, num_kv_heads, seqlen_kv, head_dim]
        layer_idx: int,
        read_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_kv + past_length]
        write_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_q]
    ) -> tuple[torch.Tensor, torch.Tensor]:  # shape [seqlen_kv + past_length, num_kv_heads, head_dim]
        """Write the new key-value states of one layer into the cache and return the states to attend over. What
        happens depends on the attention type of the layer:
        - Full attention: the new KV states are written to the cache, then the complete sequence is read back.
        - Sliding window: the old KV states are read first, with placeholder rows (read index -1) that receive the
            new KV states, and only then are the new KV states written to the cache. Reading comes first because
            the write may overwrite old KV states that are still needed.

        Returns the complete KV states (cached + new) for attention computation.
        """
        # Locate the layer: which group it belongs to (for the indices) and its rank in the group (for the tensors)
        group_idx, layer_idx_in_group = self.layer_index_to_group_indices[layer_idx]
        group_read_index = read_index[group_idx]
        group_write_index = write_index[group_idx]
        layer_keys = self.key_cache[layer_idx_in_group]
        layer_values = self.value_cache[layer_idx_in_group]

        # Match the cache layout: after this the shape is [seqlen_kv, num_kv_heads, head_dim]
        key_states = key_states.transpose(1, 2).squeeze(0)
        value_states = value_states.transpose(1, 2).squeeze(0)

        if self.sliding_windows[layer_idx] == 1:
            # Full attention: write first, then read the whole sequence back
            layer_keys[group_write_index, :, :] = key_states
            layer_values[group_write_index, :, :] = value_states
            full_keys = layer_keys[group_read_index, :, :]
            full_values = layer_values[group_read_index, :, :]
        else:
            # Sliding window: the read / write order matters because of chunked prefill, the only case where the
            # write may land on cache entries that still need to be used. So read first, fill the rows flagged
            # with -1 with the new states, and write to the cache last.
            new_rows = (group_read_index == -1).unsqueeze(-1).unsqueeze(-1)  # TODO: should this be precomputed?
            full_keys = layer_keys[group_read_index, :, :]
            full_keys.masked_scatter_(new_rows, key_states)
            full_values = layer_values[group_read_index, :, :]
            full_values.masked_scatter_(new_rows, value_states)
            layer_keys[group_write_index, :, :] = key_states
            layer_values[group_write_index, :, :] = value_states
        return full_keys, full_values

    def get_block_table_key(self, flash_attn_with_kvcache_fn: Any) -> str:
        """Return the name of the block table keyword argument of the given flash_attn_with_kvcache_fn. Different
        versions of flash name it differently ("block_table" or "page_table"), so the signature is inspected, but
        only once: the result is kept for subsequent calls.

        Raises:
            ValueError: if the function accepts neither a `block_table` nor a `page_table` argument.
        """
        if self._block_table_key is not None:
            return self._block_table_key
        parameters = inspect.signature(flash_attn_with_kvcache_fn).parameters
        for candidate in ("block_table", "page_table"):
            if candidate in parameters:
                self._block_table_key = candidate
                return candidate
        raise ValueError(
            f"flash_attn_with_kvcache_fn does not have a block_table or page_table argument: {inspect.signature(flash_attn_with_kvcache_fn)}"
        )

    def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
        """Look in the cache for blocks matching the beginning of (prompt_ids). Matching blocks get their reference
        count increased and are registered as the blocks of (request_id). Returns the length of the matched prefix
        in tokens (a multiple of the block size), which is 0 if nothing matched."""
        current_hash = None
        allocated_blocks = []
        # Only full blocks can match, so the prompt is walked one whole block at a time
        num_full_blocks = len(prompt_ids) // self.block_size
        for start in range(0, num_full_blocks * self.block_size, self.block_size):
            tokens = prompt_ids[start : start + self.block_size]
            # Prefix sharing is only supported when there is only one full attention layer group, so group_id=0.
            current_hash = self._block_manager.compute_hash(current_hash, tokens, group_id=0)
            block_id = self._block_manager._hash_to_id.get(current_hash)
            if block_id is None:
                break  # the prefix stops matching at the first unknown block
            allocated_blocks.append(block_id)
            self._block_manager.increase_ref_count(block_id)
        # If a prefix matched, its blocks become the first blocks of the request
        if allocated_blocks:
            logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
            self.group_cache_managers[0].block_table[request_id] = allocated_blocks

        prefix_length = len(allocated_blocks) * self.block_size
        self._total_prefix_length += prefix_length
        return prefix_length

    def mark_shareable_blocks_as_complete(self, state: RequestState, num_complete_blocks: int) -> None:
        """Mark the blocks of a request (state) as complete, for the cache managers that use block sharing. A complete
        block is a block whose KV cache has been fully computed: if the block can hold the cache of N tokens, it is
        complete once the cache data is present for all N tokens. Nothing happens for cache managers that do not use
        block sharing."""
        # The status can be FINISHED in async mode, because batch N+1 offloaded the request before batch N was over.
        # In that case the blocks are no longer in the block table, so we must not look for them.
        if num_complete_blocks == 0 or state.status == RequestStatus.FINISHED:
            return None
        for manager in self.group_cache_managers:
            if not manager.uses_block_sharing:
                continue
            self._block_manager.mark_shareable_blocks_as_complete(
                num_complete_blocks=num_complete_blocks,
                allocated_blocks=manager.block_table[state.request_id],
                prompt_ids=(state.initial_tokens + state.generated_tokens),
            )

    def copy_cache(self, list_source_blocks: list[int], list_forked_blocks: list[int]) -> None:
        """Copy the content of the source blocks into the forked blocks, in every key and value cache tensor."""
        source_blocks = torch.tensor(list_source_blocks, device=self.device, dtype=torch.int32)
        forked_blocks = torch.tensor(list_forked_blocks, device=self.device, dtype=torch.int32)
        # View each cache tensor block by block, so that whole blocks can be indexed by their block id
        blocked_shape = (-1, self.block_size, self.num_key_value_heads, self.head_dim)
        for layer_keys, layer_values in zip(self.key_cache, self.value_cache):
            for layer_cache in (layer_keys, layer_values):
                blocked_cache = layer_cache.view(*blocked_shape)
                blocked_cache[forked_blocks] = blocked_cache[source_blocks]
        # FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape)
        # This will allow for better .update and a single copy instead of one per cache tensor

    def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]:
        """Fork the blocks of the source request into the destination requests, for every cache manager. Returns the
        source blocks and the destination blocks gathered from all cache managers, as two lists."""
        source_blocks: list[int] = []
        destination_blocks: list[int] = []
        for manager in self.group_cache_managers:
            group_sources, group_destinations = manager.fork_blocks(
                source_request_id, destination_request_ids, self._block_manager
            )
            source_blocks.extend(group_sources)
            destination_blocks.extend(group_destinations)
        return source_blocks, destination_blocks

    def free_all_requests(self) -> None:
        """Free the blocks of every request known to any cache manager. This preserves prefix hashes in the block
        manager (blocks become initialized rather than uninitialized if they were complete), allowing prefix sharing
        to work across generation sessions."""
        # Gather the ids first: a request may appear in the block table of several cache managers
        known_request_ids = set()
        for manager in self.group_cache_managers:
            known_request_ids.update(manager.block_table.keys())
        for request_id in known_request_ids:
            self.free_blocks(request_id)
- # TODO: rework computation with the groups and their sizes
- class PagedAttentionMemoryHandler:
- """A helper class to determine the best number of pages and maximum number of tokens per batch for the paged
- attention cache, providing automatic sizing based on available GPU memory.
- The helper works using the number of pages, which is tied to the number of blocks by:
- num_blocks = num_pages // block_size
- The memory footprint consists of three main components:
- - Cache memory: the space needed to store the cache tensors:
- 2 * layer_group_size * [num_pages, page_size] * cache_dtype
- - Activation memory: the space temporarily taken by the largest activation during the model forward pass:
- peak_activation_per_token * max_tokens_per_batch * activation_dtype_size
- - Static tensors: the space taken by the input/output buffers and metadata tensors for batch processing, sum of:
- - inputs_ids + outputs_ids + position_ids + logits_indices: 4 * max_tokens_per_batch * int32_size
- - attention_mask: num_attention_masks * num_pages * max_tokens_per_batch * activation_dtype_size
- - cumulative_seqlens_q + cumulative_seqlens_k: (1 + 2) * max_tokens_per_batch * int32_size
- - write_index_tensor: num_groups * max_tokens_per_batch * int32_size
- - read_index_tensor: num_groups * (num_pages + max_tokens_per_batch) * int32_size
- The handler can operate in three modes:
- 1. Auto-sizing: Determines both number of pages and maximum number of tokens per batch using quadratic optimization
- 2. Fixed cache: Calculates max batch tokens given a fixed number of pages
- 3. Fixed batch: Calculates number of pages given a fixed maximum batch size
- """
- _activation_dtype = torch.bfloat16
- _input_dtype = torch.int32
- _upper_bound_max_batch_tokens = 256
- _upper_bound_num_blocks = 4096
- def __init__(
- self,
- block_size: int,
- page_size: int,
- num_groups: int,
- group_size: int,
- peak_activation_per_token: int,
- num_attention_masks: int,
- ) -> None:
- """Initialize the memory handler with the parameters that cannot be automatically inferred.
- Args:
- block_size: Size of the cache blocks
- page_size: Size of the cache pages
- num_groups: Number of layer groups
- group_size: Number of layers per layer group
- peak_activation_per_token: Maximum size of activation tensor per token, = hidden_size + vocab_size
- num_attention_masks: Number of attention masks, 0 if no attention mask is used, 2 if hybrid model, else 1
- """
- self.block_size = block_size
- self.page_size = page_size
- self.num_groups = num_groups
- self.group_size = group_size
- self.peak_activation_per_token = peak_activation_per_token
- self.num_attention_masks = num_attention_masks
@staticmethod
def get_available_memory(max_memory_percent: float = 1.0) -> int:
    """Compute how much device memory can be handed to the cache, given what is already in use.

    The memory in use is taken as the larger of the allocated and reserved amounts reported by
    `get_device_and_memory_breakdown`. What remains of the total is scaled by (max_memory_percent) and truncated to
    an integer.

    Args:
        max_memory_percent: Fraction of available memory to use (0.0-1.0). 1.0 means use all available memory.

    Returns:
        int: Available memory in bytes for cache allocation
    """
    _, total, reserved, allocated = get_device_and_memory_breakdown()
    memory_in_use = max(allocated, reserved)
    unused_memory = total - memory_in_use
    return int(unused_memory * max_memory_percent)
def infer_num_blocks_and_max_batch_tokens(
    self,
    num_blocks: int | None = None,
    max_batch_tokens: int | None = None,
    max_memory_percent: float = 0.8,  # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
    cache_dtype: torch.dtype = torch.float16,
) -> tuple[int, int]:
    """Fill in whichever of (num_blocks) and (max_batch_tokens) is missing, based on available memory, then check
    that the resulting memory footprint fits. Check the class docstring for more details. Naming the number of pages
    as N and the maximum number of tokens per batch as M, the equation solved is:

    available_memory = sum([
        MN * num_attention_masks * activation_dtype_size,
        2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
        M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
    ])

    where we already simplified int32_size = 4.

    Args:
        num_blocks: Number of cache blocks, or None to have it computed
        max_batch_tokens: Maximum number of tokens per batch, or None to have it computed
        max_memory_percent: Fraction of the available memory that may be used
        cache_dtype: Data type of the cache

    Returns:
        The pair `(num_blocks, max_batch_tokens)`. Values passed by the caller are returned unchanged.

    Raises:
        MemoryError: if the memory footprint of the pair exceeds the available memory.
    """
    blocks_missing = num_blocks is None
    tokens_missing = max_batch_tokens is None
    if blocks_missing and tokens_missing:
        # Neither is provided: both are computed together, using a second-order polynomial
        num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens(max_memory_percent, cache_dtype)
    elif blocks_missing:
        # Only max_batch_tokens is provided: the number of blocks is derived from it
        num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype)
    elif tokens_missing:
        # Only num_blocks is provided: the maximum number of batch tokens is derived from it
        max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype)
    # Otherwise both are provided and used as they are

    # The memory footprint is checked in all cases, including when both values come from the caller
    available_memory = self.get_available_memory(max_memory_percent)
    memory_footprint = self.compute_memory_footprint(
        max_batch_tokens=max_batch_tokens, num_blocks=num_blocks, cache_dtype=cache_dtype
    )
    if memory_footprint > available_memory:
        raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}")
    return num_blocks, max_batch_tokens
def compute_num_blocks_and_max_batch_tokens(
    self,
    max_memory_percent: float,
    cache_dtype: torch.dtype = torch.float16,
    m: float = 0.01,
) -> tuple[int, int]:
    """Solve for both the number of blocks and the maximum number of tokens per batch when neither is fixed.

    The two unknowns are tied together by M = m * N, with N the number of pages and M the maximum number of tokens
    per batch: m is the share of the cache one batch can fill (m=0.01 means at most 1%). Substituting into the
    memory equation gives a polynomial in N:
        available_memory = sum([
            m * N^2 * num_attention_masks * activation_dtype_size,
            2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
            m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
        ])
    whose greatest root is kept. When num_attention_masks is 0 the quadratic term vanishes and the equation is
    solved as a 1st degree polynomial instead.

    Returns:
        The `(num_blocks, max_batch_tokens)` pair, each capped by its upper bound attribute.

    Raises:
        ValueError: If the discriminant or the retained solution is negative.
    """
    memory_budget = self.get_available_memory(max_memory_percent)
    logger.info(f"Cache memory: {memory_budget}")

    # Bytes that scale with one page (N) and with one batch token (M), excluding the attention mask term
    activation_itemsize = self._activation_dtype.itemsize
    per_page_bytes = self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups
    per_token_bytes = self.peak_activation_per_token * activation_itemsize + 28 + 4 * self.num_groups

    # Polynomial a * N^2 + b * N + c = 0
    a = m * self.num_attention_masks * activation_itemsize
    b = 2 * per_page_bytes + m * per_token_bytes
    c = -memory_budget
    logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")

    if self.num_attention_masks != 0:
        # General case: keep the greatest root of the quadratic
        discriminant = b**2 - 4 * a * c
        if discriminant < 0:
            raise ValueError(f"Discriminant is negative: {discriminant = }")
        greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
    else:
        # No attention mask: the quadratic term is zero, the equation is linear
        greatest_solution = -c / b
    if greatest_solution < 0:
        raise ValueError(f"Greatest solution is negative: {greatest_solution = }")

    # Pages are grouped into blocks of block_size pages, only full blocks are kept
    num_blocks = floor(greatest_solution) // self.block_size
    if num_blocks > self._upper_bound_num_blocks:
        logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
        num_blocks = self._upper_bound_num_blocks

    # M = m * N, truncated to an integer number of tokens
    max_batch_tokens = int(greatest_solution * m)
    if max_batch_tokens > self._upper_bound_max_batch_tokens:
        logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
        max_batch_tokens = self._upper_bound_max_batch_tokens

    return num_blocks, max_batch_tokens
def compute_max_batch_tokens(
    self,
    num_blocks: int,
    max_memory_percent: float,
    cache_dtype: torch.dtype = torch.float16,
) -> int:
    """Infer the maximum number of tokens per batch M from a fixed number of cache blocks.

    With N the number of pages (num_blocks * block_size), M is the floor of:
        (available_memory - 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group))
        / (activation_dtype_size * (N * num_attention_masks + peak_activation_per_token) + 28 + 4 * num_group)
    and is capped by `_upper_bound_max_batch_tokens`.
    """
    memory_budget = self.get_available_memory(max_memory_percent)
    num_pages = num_blocks * self.block_size

    # Memory left once everything that scales with the number of pages only has been set aside
    per_page_bytes = self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups
    remaining_memory = memory_budget - 2 * num_pages * per_page_bytes

    # Bytes needed for each additional batch token
    per_token_bytes = (
        self._activation_dtype.itemsize * (num_pages * self.num_attention_masks + self.peak_activation_per_token)
        + 28
        + 4 * self.num_groups
    )

    max_batch_tokens = floor(remaining_memory / per_token_bytes)
    if max_batch_tokens > self._upper_bound_max_batch_tokens:
        logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
        max_batch_tokens = self._upper_bound_max_batch_tokens
    return max_batch_tokens
def compute_num_blocks(
    self,
    max_batch_tokens: int,
    max_memory_percent: float,
    cache_dtype: torch.dtype = torch.float16,
) -> int:
    """Infer the number of cache blocks from a fixed maximum number of tokens per batch M.

    The number of pages N is the floor of:
        (available_memory - M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group))
        / (2 * (layer_group_size * page_size * cache_dtype + 2 * num_group)
           + M * (num_attention_masks * activation_dtype_size)
           + M * activation_dtype_size)
    and the number of blocks is N // block_size, capped by `_upper_bound_num_blocks`.

    NOTE(review): the last denominator term (M * activation_dtype_size) has no counterpart in the footprint
    equation solved by the sibling methods; it only makes the result smaller. Confirm whether it is intentional.
    """
    memory_budget = self.get_available_memory(max_memory_percent)
    activation_itemsize = self._activation_dtype.itemsize

    # Memory left once everything that scales with the number of batch tokens only has been set aside
    remaining_memory = (
        memory_budget
        - max_batch_tokens * self.peak_activation_per_token * activation_itemsize
        - max_batch_tokens * (28 + 4 * self.num_groups)
    )

    # Bytes needed for each additional page
    per_page_bytes = (
        2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
        + max_batch_tokens * (self.num_attention_masks * activation_itemsize)
        + max_batch_tokens * activation_itemsize
    )

    # Only full blocks of block_size pages are kept
    num_blocks = floor(remaining_memory / per_page_bytes) // self.block_size
    if num_blocks > self._upper_bound_num_blocks:
        logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
        num_blocks = self._upper_bound_num_blocks
    return num_blocks
- def compute_memory_footprint(
- self,
- num_blocks: int,
- max_batch_tokens: int,
- cache_dtype: torch.dtype,
- ) -> int:
- """Calculate the memory footprint breakdown for a given number of blocks and maximum batch tokens. The memory
- footprint is given by:
- available_memory = sum([
- MN * num_attention_masks * activation_dtype_size,
- 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
- M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
- ])
- but is broken down below.
- """
- num_pages = num_blocks * self.block_size
- cache_memory_footprint = 2 * self.group_size * num_pages * self.page_size * cache_dtype.itemsize
- activation_memory_footprint = self.peak_activation_per_token * self._activation_dtype.itemsize
- activation_memory_footprint *= max_batch_tokens
- inputs_outputs_positions_and_logits_memory_footprint = 4 * max_batch_tokens * 4 # second 4 is for int32 size
- attention_memory_footprint = self.num_attention_masks * self._activation_dtype.itemsize
- attention_memory_footprint *= num_pages * max_batch_tokens
- cumulative_seqlens_memory_footprint = 3 * max_batch_tokens * 4 # 4 is for int32 size
- write_index_memory_footprint = self.num_groups * max_batch_tokens * 4 # 4 is for int32 size
- read_index_memory_footprint = self.num_groups * (num_pages + max_batch_tokens) * 4 # 4 is for int32 size
- total_memory_footprint = sum(
- [
- cache_memory_footprint,
- activation_memory_footprint,
- inputs_outputs_positions_and_logits_memory_footprint,
- attention_memory_footprint,
- cumulative_seqlens_memory_footprint,
- write_index_memory_footprint,
- read_index_memory_footprint,
- ]
- )
- return total_memory_footprint
|