cache.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. from math import floor, gcd, sqrt
  16. from typing import Any
  17. import torch
  18. from ...configuration_utils import PreTrainedConfig
  19. from ...generation.configuration_utils import ContinuousBatchingConfig
  20. from ...utils.generic import is_flash_attention_requested
  21. from ...utils.metrics import attach_tracer, traced
  22. from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
  23. from .requests import RequestState, RequestStatus, get_device_and_memory_breakdown, logger
  24. def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
  25. """
  26. Group layers depending on the attention mix, according to VLLM's hybrid allocator rules:
  27. - Layers in each group need to have the same type of attention
  28. - All groups have the same number of layers
  29. For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
  30. We would get four groups: [0, 3], [1, 2], [4,5] and [6,7].
  31. """
  32. # If the config has no layer_type attribute, it means all layers are the same attention type
  33. layer_types = getattr(config, "layer_types", None)
  34. if layer_types is None:
  35. attn_type = "sliding_attention" if getattr(config, "sliding_window", None) is not None else "full_attention"
  36. layer_types = [attn_type for _ in range(config.num_hidden_layers)]
  37. # We then count the number of layers of each type
  38. layer_counts = {}
  39. for i, layer_type in enumerate(layer_types):
  40. layer_counts[layer_type] = layer_counts.get(layer_type, []) + [i]
  41. # The size of all groups is the greatest common divisor of the number of layers of each type
  42. group_size = gcd(*[len(indices) for indices in layer_counts.values()])
  43. # We then group the layers by type
  44. layer_groups = []
  45. for layer_type, indices in layer_counts.items():
  46. for i in range(0, len(indices), group_size):
  47. layer_groups.append(indices[i : i + group_size])
  48. # And note the layer types
  49. group_types = [layer_types[lg[0]] for lg in layer_groups]
  50. return layer_groups, group_types
@attach_tracer()
class PagedAttentionCache:
    """
    Manages the cache for a paged attention mechanism, inspired by VLLM's hybrid allocator. The cache relies on making
    groups of layers to reduce the complexity of cache management and fragmentation.

    The cache uses a three-level hierarchy:
    - Pages: The smallest unit of cache, a page has a size of [num_heads, head_size], which is the space needed to
      store the key or value states for one token and one layer. For a model with only full-attention layers, to store
      the KV cache of one token, we need `2 * num_layers` pages: keys and values each take `num_layers` pages.
      Pages are grouped into blocks:
    - Blocks: A block is a collection of `block_size` pages, serving as the allocation unit to reduce management
      complexity and fragmentation. Cache is allocated and freed block by block, not page by page. One block is
      allocated to one layer group, which only has one attention type, like full-attention or sliding-attention.
      If all layers in the model have the same attention type, then all layers will be in the same group. There is
      more than one group if and only if the model has mixed attention types, like layers with full-attention and
      layers with sliding-attention.
    - Cache tensors: The physical supports for the cache. There are as many cache tensors as there are layers in a
      layer group, kept in two lists (`key_cache` for keys, `value_cache` for values). Each cache tensor has shape
      `[(num_blocks + 2) * block_size, num_heads, head_size]`: the two extra blocks are spare room used to handle
      padding and to discard unwanted tokens.

    Grouping layers into groups is useful because when we allocate one block to a group N, the block allocated is the
    same for all layers in group N, equivalently it is allocated across all cache tensors. This allows us to
    efficiently allocate and free blocks, and to efficiently read and write key and value states. Concretely, layer
    `j` of every group uses cache tensor `j`; the blocks of that tensor are shared out between groups by allocation.

    For instance, imagine we have 8 blocks of cache and a model with two layer groups: a full-attention group with 3
    layers and a sliding-attention group with 3 layers. At creation time, the physical cache tensors look like this:

    cache_tensor_0: □ □ □ □ □ □ □ □
    cache_tensor_1: □ □ □ □ □ □ □ □
    cache_tensor_2: □ □ □ □ □ □ □ □

    where □ means the block is not allocated to any layer group yet. We have 3 cache tensors because there are
    3 layers per group.
    We allocate 1 block to each group, after allocation, the cache tensors look like this:

    cache_tensor_0: ✖ ◉ □ □ □ □ □ □
    cache_tensor_1: ✖ ◉ □ □ □ □ □ □
    cache_tensor_2: ✖ ◉ □ □ □ □ □ □

    where ✖ means the block is allocated to the full-attention group, and ◉ means the block is allocated to the
    sliding-attention group.
    Now, if we continue to generate, and the sliding window has been reached, we only need to allocate a new block
    for the full-attention group, and the cache tensors look like this:

    cache_tensor_0: ✖ ◉ ✖ □ □ □ □ □
    cache_tensor_1: ✖ ◉ ✖ □ □ □ □ □
    cache_tensor_2: ✖ ◉ ✖ □ □ □ □ □

    And after further generation, when we need a new block allocated:

    cache_tensor_0: ✖ ◉ ✖ ✖ □ □ □ □
    cache_tensor_1: ✖ ◉ ✖ ✖ □ □ □ □
    cache_tensor_2: ✖ ◉ ✖ ✖ □ □ □ □

    This would not have been possible if all layers were in the same group: we would have had to allocate a new block
    for the sliding-attention group, although it is not needed.
    """

    def __init__(
        self,
        config: PreTrainedConfig,
        continuous_batching_config: ContinuousBatchingConfig,
        device: torch.device | str,
        dtype: torch.dtype = torch.float16,
        tp_size: int | None = None,
    ) -> None:
        """Initialize a paged attention cache for efficient memory usage. Also turns on prefix sharing if the model has
        only full attention layers (a single full-attention layer group) and block sharing is allowed by
        `continuous_batching_config.allow_block_sharing`.

        The number of blocks and the maximum number of tokens per batch are taken from `continuous_batching_config`
        when set, otherwise inferred from available memory by `PagedAttentionMemoryHandler`. The cache tensors are
        allocated eagerly on `device`.

        Args:
            config: Model configuration
            continuous_batching_config: Continuous batching configuration containing cache parameters
            device: Device for the cache tensors
            dtype: Data type of the cache
            tp_size: Tensor parallelism size. It is only validated here (the number of KV heads must be divisible by
                it); the heads are not split across ranks, see the TODO below.

        Raises:
            ValueError: If the block size is not positive, if the number of KV heads is not divisible by `tp_size`,
                or if a layer group has an attention type other than full or sliding attention.
        """
        self.config = config
        self.dtype = dtype
        self.device = device

        # Extract model dimensions, falling back to num_attention_heads (no GQA) and hidden_size // num_heads
        kv_heads = getattr(config, "num_key_value_heads", None)
        self.num_key_value_heads: int = kv_heads if kv_heads is not None else config.num_attention_heads
        head_dim = getattr(config, "head_dim", None)
        self.head_dim: int = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads

        # Extract cache dimensions. Default used to be 32, now it's 256 to be compatible with flash_with_kvcache.
        self.block_size = continuous_batching_config.block_size
        if self.block_size <= 0:
            raise ValueError(f"Block size must be positive, but got {self.block_size}")

        # Group layers depending on the attention mix
        layer_groups, group_types = group_layers_by_attn_type(config)
        group_size = len(layer_groups[0])
        self.num_groups = len(layer_groups)
        # Per-layer sliding window; full-attention layers are registered with the sentinel value 1, which `update`
        # uses to pick the full-attention code path
        self.sliding_windows: dict[int, int] = {}
        # Maps a layer index to (group index, position of the layer inside its group)
        self.layer_index_to_group_indices: dict[int, tuple[int, int]] = {}
        for i, group in enumerate(layer_groups):
            sliding_window = config.sliding_window if group_types[i] == "sliding_attention" else 1
            for j, layer in enumerate(group):
                self.layer_index_to_group_indices[layer] = (i, j)
                self.sliding_windows[layer] = sliding_window

        # Handle TP (or don't): only validate that the KV heads can be split evenly
        if tp_size is not None and tp_size > 1:
            if self.num_key_value_heads % tp_size != 0:
                raise ValueError(
                    f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
                )
            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
            # self.num_key_value_heads //= tp_size  # TODO: why is this commented out?

        # Infer number of blocks and max batch tokens. A page holds one token of one layer, for keys or for values.
        page_size = self.head_dim * self.num_key_value_heads
        if is_flash_attention_requested(self.config):
            num_attention_masks = 0  # only used to compute the default memory footprint args
        elif "sliding_attention" in group_types:
            # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=len(set(group_types))`
            num_attention_masks = 2
        else:
            num_attention_masks = 1

        memory_handler = PagedAttentionMemoryHandler(
            block_size=self.block_size,
            page_size=page_size,
            num_groups=self.num_groups,
            group_size=group_size,
            peak_activation_per_token=(config.hidden_size + config.vocab_size),
            num_attention_masks=num_attention_masks,
        )
        # Values left as None in the config are inferred from available memory by the handler
        num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
            num_blocks=continuous_batching_config.num_blocks,
            max_batch_tokens=continuous_batching_config.max_batch_tokens,
            max_memory_percent=continuous_batching_config.max_memory_percent,
            cache_dtype=self.dtype,
        )

        # Add the inferred attributes to the class
        self.num_blocks = num_blocks
        self.max_batch_tokens = max_batch_tokens
        self.num_pages = self.num_blocks * self.block_size
        logger.info(
            f"PagedAttentionCache initialized with {self.num_blocks = }, {self.block_size = }, {page_size = }, "
            f"{self.max_batch_tokens = } {num_attention_masks = }"
        )

        # If max_blocks_per_request is not set, it falls back to 0 (no non-zero default is applied here).
        # NOTE(review): how the decode fast path interprets a value of 0 is decided outside this class — confirm.
        max_blocks_per_request = continuous_batching_config.max_blocks_per_request
        if max_blocks_per_request is None:
            max_blocks_per_request = 0
            # logger.info( TODO: uncomment when we have good defaults
            #     f"max_blocks_per_request was not set, using {max_blocks_per_request}. This means max sequence "
            #     f"length for the decode fast path is {max_blocks_per_request * self.block_size}."
            # )
        self.max_blocks_per_request = max_blocks_per_request

        # Initialize the cache: one key tensor and one value tensor per layer position in a group (see class docstring)
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        # We add two extra blocks (2 * block_size pages) to the cache to handle padding and generally discard unwanted
        # tokens
        self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim)
        for _ in range(group_size):
            new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
            new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
            # Tell torch.compile these tensors keep a fixed address, since they are reused across steps
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
        logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")

        # Block management data structures: one cache allocator per layer group, all drawing from one block manager
        self.allow_block_sharing = continuous_batching_config.allow_block_sharing
        self.group_cache_managers: list[CacheAllocator] = []
        self.num_full_attention_groups = 0
        self.num_sliding_attention_groups = 0
        self.max_sliding_window_blocks_per_request = 0
        for i, group_type in enumerate(group_types):
            if group_type == "full_attention":
                cm = FullAttentionCacheAllocator(i, self.block_size, allow_block_sharing=self.allow_block_sharing)
                self.num_full_attention_groups += 1
            elif group_type == "sliding_attention":
                cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window)
                self.num_sliding_attention_groups += 1
                # All sliding groups share config.sliding_window, so the last one's cap holds for all of them
                self.max_sliding_window_blocks_per_request = cm._max_blocks_per_request
            else:
                raise ValueError(f"Invalid group type: {group_type}")
            self.group_cache_managers.append(cm)

        # We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
        self.use_prefix_sharing = self.allow_block_sharing and group_types == ["full_attention"]
        self._block_manager = BlockManager(num_blocks, self.block_size)
        self._total_prefix_length: int = 0  # a counter to measure the impact of prefix sharing, also used in tests

        # For block table support, we lazy init the name of the block table key (see get_block_table_key)
        self._block_table_key: str | None = None

    def will_allocation_be_successful(self, num_requested_blocks: int, allocated_blocks: int) -> bool:
        """Returns a boolean indicating if the allocation of (num_requested_blocks) blocks will be successful, i.e. if
        the total number of newly allocated blocks needed does not exceed `get_num_free_blocks()`. The number of newly
        allocated blocks needed is predicted by the following rules:
        - for full attention groups: since there is no sliding window for full attention layers, one requested block is
          always equivalent to one newly allocated block for EACH full attention group
        - for sliding window groups: because of the sliding window, the number of blocks allocated to a request is
          capped. Using the number of already (allocated_blocks) we can compute the number of new blocks to actually
          allocate to the request, which can be lower than the number of requested blocks. That number is the same for
          all sliding window groups, as only one sliding window size is supported.
        """
        # This is not in a branch, because it is very rare to have zero full attention layer
        needed_blocks = num_requested_blocks * self.num_full_attention_groups
        # Only take this branch if the model has sliding window attention layers
        if self.num_sliding_attention_groups:
            blocks_left = max(self.max_sliding_window_blocks_per_request - allocated_blocks, 0)
            needed_blocks += min(blocks_left, num_requested_blocks) * self.num_sliding_attention_groups
        return needed_blocks <= self.get_num_free_blocks()

    @traced
    def allocate_blocks(self, n_blocks: int, request_id: str, allocated_blocks: int) -> int | None:
        """Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
        managers, and this method only returns the maximum number of blocks actually allocated across all managers.

        Returns None, without allocating anything, if `will_allocation_be_successful` predicts that the free blocks
        are not enough. Raises a ValueError if a cache manager still fails to allocate; NOTE(review): no rollback is
        done here, so managers visited before the failing one keep the blocks they allocated.
        """
        # First check allocation will be successful before starting, to avoid partial allocations
        if not self.will_allocation_be_successful(n_blocks, allocated_blocks):
            return None

        # Allocate blocks across all cache managers
        max_allocated = 0
        for cm in self.group_cache_managers:
            num_allocated_blocks = cm.allocate_blocks(n_blocks, request_id, self._block_manager)
            if num_allocated_blocks is None:
                raise ValueError(f"Failed to allocate {n_blocks} blocks for request {request_id}")
            max_allocated = max(max_allocated, num_allocated_blocks)
        return max_allocated

    @traced
    def free_blocks(self, request_id: str) -> None:
        """Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
        by the cache managers."""
        for cm in self.group_cache_managers:
            cm.free_blocks(request_id, self._block_manager)

    def get_num_free_blocks(self) -> int:
        """Get the current number of unallocated blocks available for new requests."""
        return self._block_manager.num_free_blocks

    @traced
    def extend_read_and_write_indices(
        self,
        request_id: str,
        past_length: int,
        query_length: int,
        read_index: list[list[int]],
        write_index: list[list[int]],
    ) -> None:
        """Append the physical cache indices used to read and to write KV states for (request_id), across all layer
        groups. For each group `i`, the indices computed by that group's cache manager are appended in place to
        `read_index[i]` (indices to read for attention) and `write_index[i]` (indices where the new KV states go).
        """
        for cm, read_indices, write_indices in zip(self.group_cache_managers, read_index, write_index):
            indices = cm.get_read_indices(request_id, past_length, query_length)
            read_indices.extend(indices)
            indices = cm.get_write_indices(request_id, past_length, query_length)
            write_indices.extend(indices)

    def fill_block_table(
        self, request_id: str, past_length: int, query_length: int, block_table: torch.Tensor
    ) -> None:
        """Fill the block table of (request_id) in place: `block_table[i]` is filled by the cache manager of layer
        group `i`. The content written is decided by each manager's `fill_block_table`."""
        for i, cm in enumerate(self.group_cache_managers):
            cm.fill_block_table(request_id, past_length, query_length, block_table[i])

    @traced
    def get_seqlens_k(self, past_length: int, query_length: int) -> dict[str, int]:
        """Compute the key sequence length of a request with (past_length) cached tokens and (query_length) new
        tokens, for each attention type present in the model. Returns a dictionary mapping layer types
        ("full_attention", "sliding_attention") to their key sequence lengths; for sliding attention, the cached part
        is capped at `sliding_window - 1` tokens."""
        seqlens_k = {}
        if self.num_full_attention_groups > 0:
            seqlens_k["full_attention"] = past_length + query_length
        if self.num_sliding_attention_groups > 0:
            seqlens_k["sliding_attention"] = query_length + min(past_length, self.config.sliding_window - 1)
        # NOTE: when we add more attention types / different sliding windows, we can go back to looping over CMs
        return seqlens_k

    @traced
    def update(
        self,
        key_states: torch.Tensor,  # shape [1, num_kv_heads, seqlen_kv, head_dim]
        value_states: torch.Tensor,  # shape [1, num_kv_heads, seqlen_kv, head_dim]
        layer_idx: int,
        read_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_kv + past_length]
        write_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_q]
    ) -> tuple[torch.Tensor, torch.Tensor]:  # shape [seqlen_kv + past_length, num_kv_heads, head_dim]
        """Update the cache with new key-value states for a specific layer. This method writes new KV states to the
        appropriate cache locations. The behavior differs based on the layer's attention type:
        - Full attention: New KV states are written to cache, then complete sequence is read from cache
        - Sliding window: Old KV is read from cache along with extra spaces for the new KV, then new KV is written to
          cache. This is because new KV might overwrite the old KV, so we need to read the old KV first.

        Returns the complete KV states (cached + new) for attention computation.
        """
        # Retrieve the layer read and write indices, and if there is a sliding window
        group_idx, layer_idx_in_group = self.layer_index_to_group_indices[layer_idx]
        layer_read_index = read_index[group_idx]
        layer_write_index = write_index[group_idx]

        # Select the correct cache: layer j of every group uses the j-th cache tensor
        k_cache = self.key_cache[layer_idx_in_group]
        v_cache = self.value_cache[layer_idx_in_group]

        # Transpose the key and value states to match the cache shape, after which shape is [seqlen_kv, num_kv_heads, head_dim]
        key_states = key_states.transpose(1, 2).squeeze(0)
        value_states = value_states.transpose(1, 2).squeeze(0)

        # Case: full attention (full-attention layers are registered with a sliding window of 1 in __init__)
        sliding_window = self.sliding_windows[layer_idx]
        if sliding_window == 1:
            k_cache[layer_write_index, :, :] = key_states
            v_cache[layer_write_index, :, :] = value_states
            key_states_with_cache = k_cache[layer_read_index, :, :]
            value_states_with_cache = v_cache[layer_read_index, :, :]

        # Case: sliding window -- we need to be careful of read/write order because of chunked prefill, because it's
        # the only case where you may write over cache you need to use
        else:
            # Entries equal to -1 in the read index are placeholders for the new KV states. Advanced indexing returns
            # a copy, so the -1 rows read here (the last rows of the cache tensor) are then overwritten in the copy,
            # in order, by masked_scatter_ without touching the cache itself.
            # NOTE(review): assumes the number of -1 entries equals seqlen_kv — confirm with the index builders.
            mask = (layer_read_index == -1).unsqueeze(-1).unsqueeze(-1)  # TODO: should this be precomputed?
            key_states_with_cache = k_cache[layer_read_index, :, :]
            key_states_with_cache.masked_scatter_(mask, key_states)
            value_states_with_cache = v_cache[layer_read_index, :, :]
            value_states_with_cache.masked_scatter_(mask, value_states)
            # Write new KV values to the cache, only after the old values have been read
            k_cache[layer_write_index, :, :] = key_states
            v_cache[layer_write_index, :, :] = value_states

        # Return the new KV values
        return key_states_with_cache, value_states_with_cache

    def get_block_table_key(self, flash_attn_with_kvcache_fn: Any) -> str:
        """Return the name of the block table keyword argument of (flash_attn_with_kvcache_fn): "block_table" or
        "page_table". The function's signature is only inspected once and the result is cached. This is necessary
        because different versions of flash attention have different names for the block table key.

        Raises a ValueError if the function accepts neither name.
        """
        if self._block_table_key is None:
            kwarg_names = inspect.signature(flash_attn_with_kvcache_fn).parameters.keys()
            if "block_table" in kwarg_names:
                self._block_table_key = "block_table"
            elif "page_table" in kwarg_names:
                self._block_table_key = "page_table"
            else:
                raise ValueError(
                    f"flash_attn_with_kvcache_fn does not have a block_table or page_table argument: {inspect.signature(flash_attn_with_kvcache_fn)}"
                )
        return self._block_table_key

    def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
        """Searches for a prefix match in the cache for the given (prompt_ids).

        The full blocks of the prompt are hashed in order, each hash chained on the previous one, and looked up in the
        block manager; the search stops at the first block without a match, and a trailing partial block is never
        matched. Matched blocks have their reference count increased and become the block table of (request_id) in
        the first cache manager. Prefix sharing only supports a single full-attention layer group, hence group 0.

        Returns the length of the matched prefix in tokens, i.e. `number_of_matched_blocks * block_size`, or 0 if no
        prefix match is found. The length is also added to the `_total_prefix_length` counter.
        """
        current_hash = None
        allocated_blocks = []
        for b in range(len(prompt_ids) // self.block_size):
            tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
            # Prefix sharing is only supported when there is only one full attention layer group, so group_id=0.
            current_hash = self._block_manager.compute_hash(current_hash, tokens, group_id=0)
            block_id = self._block_manager._hash_to_id.get(current_hash)
            if block_id is not None:
                allocated_blocks.append(block_id)
                self._block_manager.increase_ref_count(block_id)
            else:
                break

        # If we found a matching prefix, we reference the blocks in the request
        if allocated_blocks:
            logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
            cm = self.group_cache_managers[0]
            cm.block_table[request_id] = allocated_blocks

        prefix_length = len(allocated_blocks) * self.block_size
        self._total_prefix_length += prefix_length
        return prefix_length

    def mark_shareable_blocks_as_complete(self, state: RequestState, num_complete_blocks: int) -> None:
        """Marks the blocks allocated to a request (state) as complete if they are shareable and they have been computed
        in the forward pass. A complete block is a block where the KV cache has been fully computed: if the block has
        enough space to hold the cache for N tokens, the block is marked as complete when the cache data is present for
        the N tokens. Only cache managers with block sharing enabled are considered, so if block sharing is off, this is
        a no-op."""
        # The status can be FINISHED in async mode, because batch N+1 offloaded the request before batch N was over. So
        # we need to check for this case to avoid looking in the block table for blocks that no longer exist.
        if num_complete_blocks == 0 or state.status == RequestStatus.FINISHED:
            return None
        for cm in self.group_cache_managers:
            if cm.uses_block_sharing:
                self._block_manager.mark_shareable_blocks_as_complete(
                    num_complete_blocks=num_complete_blocks,
                    allocated_blocks=cm.block_table[state.request_id],
                    prompt_ids=(state.initial_tokens + state.generated_tokens),
                )

    def copy_cache(self, list_source_blocks: list[int], list_forked_blocks: list[int]) -> None:
        """Copy the cache from the source blocks to the forked blocks: in every key and value cache tensor, block
        `list_forked_blocks[k]` receives a copy of block `list_source_blocks[k]`."""
        source_blocks = torch.tensor(list_source_blocks, device=self.device, dtype=torch.int32)
        forked_blocks = torch.tensor(list_forked_blocks, device=self.device, dtype=torch.int32)
        for key_cache, value_cache in zip(self.key_cache, self.value_cache):
            # Views of shape [num_blocks + 2, block_size, num_kv_heads, head_dim] sharing storage with the cache, so
            # the block-wise assignment below writes into the cache tensors
            key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
            value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
            key_cache[forked_blocks] = key_cache[source_blocks]
            value_cache[forked_blocks] = value_cache[source_blocks]
        # FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape)
        # This will allow for better .update and a single copy instead of one per cache tensor

    def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]:
        """Fork the cache of a request (source_request_id) into the one of each request in (destination_request_ids).

        Each cache manager forks its own blocks; this method only concatenates their results. Returns the lists of
        source blocks and destination blocks, accumulated over all layer groups. They have the shape of the arguments
        of `copy_cache`; presumably the caller passes them there to copy the data — verify at call sites.
        """
        # These lists will be the accumulators for the source and destination blocks for the cache copy
        source_blocks, destination_blocks = [], []
        # Main fork loop
        for cm in self.group_cache_managers:
            src_blocks, dst_blocks = cm.fork_blocks(source_request_id, destination_request_ids, self._block_manager)
            source_blocks.extend(src_blocks)
            destination_blocks.extend(dst_blocks)
        return source_blocks, destination_blocks

    def free_all_requests(self) -> None:
        """Free all blocks allocated to requests across all cache managers. This preserves prefix hashes in the block
        manager (blocks become initialized rather than uninitialized if they were complete), allowing prefix sharing
        to work across generation sessions."""
        # Collect the ids first (as a set, since a request appears in every group) so freeing does not mutate the
        # block tables while they are being iterated
        all_request_ids = set()
        for cm in self.group_cache_managers:
            all_request_ids.update(cm.block_table.keys())
        for request_id in all_request_ids:
            self.free_blocks(request_id)
  428. # TODO: rework computation with the groups and their sizes
  429. class PagedAttentionMemoryHandler:
  430. """A helper class to determine the best number of pages and maximum number of tokens per batch for the paged
  431. attention cache, providing automatic sizing based on available GPU memory.
  432. The helper works using the number of pages, which is tied to the number of blocks by:
  433. num_blocks = num_pages // block_size
  434. The memory footprint consists of three main components:
  435. - Cache memory: the space needed to store the cache tensors:
  436. 2 * layer_group_size * [num_pages, page_size] * cache_dtype
  437. - Activation memory: the space temporarily taken by the largest activation during the model forward pass:
  438. peak_activation_per_token * max_tokens_per_batch * activation_dtype_size
  439. - Static tensors: the space taken by the input/output buffers and metadata tensors for batch processing, sum of:
  440. - inputs_ids + outputs_ids + position_ids + logits_indices: 4 * max_tokens_per_batch * int32_size
  441. - attention_mask: num_attention_masks * num_pages * max_tokens_per_batch * activation_dtype_size
  442. - cumulative_seqlens_q + cumulative_seqlens_k: (1 + 2) * max_tokens_per_batch * int32_size
  443. - write_index_tensor: num_groups * max_tokens_per_batch * int32_size
  444. - read_index_tensor: num_groups * (num_pages + max_tokens_per_batch) * int32_size
  445. The handler can operate in three modes:
  446. 1. Auto-sizing: Determines both number of pages and maximum number of tokens per batch using quadratic optimization
  447. 2. Fixed cache: Calculates max batch tokens given a fixed number of pages
  448. 3. Fixed batch: Calculates number of pages given a fixed maximum batch size
  449. """
  450. _activation_dtype = torch.bfloat16
  451. _input_dtype = torch.int32
  452. _upper_bound_max_batch_tokens = 256
  453. _upper_bound_num_blocks = 4096
  454. def __init__(
  455. self,
  456. block_size: int,
  457. page_size: int,
  458. num_groups: int,
  459. group_size: int,
  460. peak_activation_per_token: int,
  461. num_attention_masks: int,
  462. ) -> None:
  463. """Initialize the memory handler with the parameters that cannot be automatically inferred.
  464. Args:
  465. block_size: Size of the cache blocks
  466. page_size: Size of the cache pages
  467. num_groups: Number of layer groups
  468. group_size: Number of layers per layer group
  469. peak_activation_per_token: Maximum size of activation tensor per token, = hidden_size + vocab_size
  470. num_attention_masks: Number of attention masks, 0 if no attention mask is used, 2 if hybrid model, else 1
  471. """
  472. self.block_size = block_size
  473. self.page_size = page_size
  474. self.num_groups = num_groups
  475. self.group_size = group_size
  476. self.peak_activation_per_token = peak_activation_per_token
  477. self.num_attention_masks = num_attention_masks
  478. @staticmethod
  479. def get_available_memory(max_memory_percent: float = 1.0) -> int:
  480. """Calculate available GPU memory for cache allocation, accounting for already allocated tensors.
  481. This method queries the current memory state and applies the specified percentage limit to determine
  482. how much memory can be safely used for the paged attention cache.
  483. Args:
  484. max_memory_percent: Fraction of available memory to use (0.0-1.0). 1.0 means use all available memory.
  485. Returns:
  486. int: Available memory in bytes for cache allocation
  487. """
  488. _, total, reserved, allocated = get_device_and_memory_breakdown()
  489. available_memory = total - max(allocated, reserved)
  490. available_memory = int(available_memory * max_memory_percent)
  491. return available_memory
  492. def infer_num_blocks_and_max_batch_tokens(
  493. self,
  494. num_blocks: int | None = None,
  495. max_batch_tokens: int | None = None,
  496. max_memory_percent: float = 0.8, # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
  497. cache_dtype: torch.dtype = torch.float16,
  498. ) -> tuple[int, int]:
  499. """Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
  500. constraints. Check the class docstring for more details. Naming the number of pages as N and the maximum number
  501. of tokens per batch as M, the equation solved is:
  502. available_memory = sum([
  503. MN * num_attention_masks * activation_dtype_size,
  504. 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
  505. M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
  506. ])
  507. where we already simplified int32_size = 4.
  508. """
  509. if num_blocks is None:
  510. if max_batch_tokens is None:
  511. # If neither num_blocks nor max_batch_tokens are provided, we use a second-order polynomial
  512. num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens(
  513. max_memory_percent, cache_dtype
  514. )
  515. else:
  516. # If only max_batch_tokens is provided, we infer the num_blocks
  517. num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype)
  518. elif max_batch_tokens is None:
  519. # If only num_blocks is provided, we infer the max_batch_tokens
  520. max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype)
  521. else:
  522. # If both num_blocks and max_batch_tokens are provided, we use them (useless, but helps with typing)
  523. max_batch_tokens = max_batch_tokens
  524. # We check if the memory footprint is too large in all cases
  525. available_memory = self.get_available_memory(max_memory_percent)
  526. memory_footprint = self.compute_memory_footprint(
  527. max_batch_tokens=max_batch_tokens, num_blocks=num_blocks, cache_dtype=cache_dtype
  528. )
  529. if memory_footprint > available_memory:
  530. raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}")
  531. return num_blocks, max_batch_tokens
  532. def compute_num_blocks_and_max_batch_tokens(
  533. self,
  534. max_memory_percent: float,
  535. cache_dtype: torch.dtype = torch.float16,
  536. m: float = 0.01,
  537. ) -> tuple[int, int]:
  538. """Calculate optimal number of blocks and maximum number of tokens per batch using quadratic optimization when
  539. neither is fixed. This method assumes a relationship M = m * N where m is a small ratio below 1 and solves the
  540. resulting quadratic equation to find the optimal N that maximizes utilization within memory constraints. m is
  541. the amount of cache we can fill with one batch: m=0.01 means a batch fills at most 1% of the cache. The equation
  542. to solve is:
  543. available_memory = sum([
  544. m * N^2 * num_attention_masks * activation_dtype_size,
  545. 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
  546. m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
  547. ])
  548. If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
  549. """
  550. cache_memory = self.get_available_memory(max_memory_percent)
  551. logger.info(f"Cache memory: {cache_memory}")
  552. # Compute second-degree polynomial coefficients
  553. a = m * self.num_attention_masks * self._activation_dtype.itemsize
  554. b = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
  555. b += m * (self.peak_activation_per_token * self._activation_dtype.itemsize + 28 + 4 * self.num_groups)
  556. c = -cache_memory
  557. logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
  558. # If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
  559. if self.num_attention_masks == 0:
  560. greatest_solution = -c / b
  561. # Otherwise, we solve the quadratic equation
  562. else:
  563. discriminant = b**2 - 4 * a * c
  564. if discriminant < 0:
  565. raise ValueError(f"Discriminant is negative: {discriminant = }")
  566. greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
  567. if greatest_solution < 0:
  568. raise ValueError(f"Greatest solution is negative: {greatest_solution = }")
  569. # Infer number of blocks and max batch tokens
  570. num_pages = floor(greatest_solution)
  571. num_blocks = num_pages // self.block_size
  572. if num_blocks > self._upper_bound_num_blocks:
  573. logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
  574. num_blocks = self._upper_bound_num_blocks
  575. max_batch_tokens = int(greatest_solution * m)
  576. if max_batch_tokens > self._upper_bound_max_batch_tokens:
  577. logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
  578. max_batch_tokens = self._upper_bound_max_batch_tokens
  579. return num_blocks, max_batch_tokens
  580. def compute_max_batch_tokens(
  581. self,
  582. num_blocks: int,
  583. max_memory_percent: float,
  584. cache_dtype: torch.dtype = torch.float16,
  585. ) -> int:
  586. """Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
  587. M = (available_memory - 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group))
  588. / (activation_dtype_size * (N * num_attention_masks + peak_activation_per_token) + 28 + 4 * num_group)
  589. """
  590. cache_memory = self.get_available_memory(max_memory_percent)
  591. num_pages = num_blocks * self.block_size
  592. # Compute numerator
  593. num = cache_memory
  594. num -= 2 * num_pages * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
  595. # Compute denominator
  596. denum = self._activation_dtype.itemsize * (
  597. num_pages * self.num_attention_masks + self.peak_activation_per_token
  598. )
  599. denum += 28 + 4 * self.num_groups
  600. # Compute max batch tokens and return
  601. max_batch_tokens = floor(num / denum)
  602. if max_batch_tokens > self._upper_bound_max_batch_tokens:
  603. logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
  604. max_batch_tokens = self._upper_bound_max_batch_tokens
  605. return max_batch_tokens
  606. def compute_num_blocks(
  607. self,
  608. max_batch_tokens: int,
  609. max_memory_percent: float,
  610. cache_dtype: torch.dtype = torch.float16,
  611. ) -> int:
  612. """Calculate number of cache blocks N given a fixed maximum token per token M. The formula for N is given by:
  613. N = (available_memory - M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group))
  614. / (2 * (layer_group_size * page_size * cache_dtype + 2 * num_group) + M * (num_attention_masks * activation_dtype_size))
  615. """
  616. cache_memory = self.get_available_memory(max_memory_percent)
  617. # Compute numerator
  618. num = cache_memory
  619. num -= max_batch_tokens * self.peak_activation_per_token * self._activation_dtype.itemsize
  620. num -= max_batch_tokens * (28 + 4 * self.num_groups)
  621. # Compute denominator
  622. denum = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
  623. denum += max_batch_tokens * (self.num_attention_masks * self._activation_dtype.itemsize)
  624. denum += max_batch_tokens * self._activation_dtype.itemsize
  625. # Compute cache size and return number of blocks
  626. num_pages = floor(num / denum)
  627. num_blocks = num_pages // self.block_size
  628. if num_blocks > self._upper_bound_num_blocks:
  629. logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
  630. num_blocks = self._upper_bound_num_blocks
  631. return num_blocks
  632. def compute_memory_footprint(
  633. self,
  634. num_blocks: int,
  635. max_batch_tokens: int,
  636. cache_dtype: torch.dtype,
  637. ) -> int:
  638. """Calculate the memory footprint breakdown for a given number of blocks and maximum batch tokens. The memory
  639. footprint is given by:
  640. available_memory = sum([
  641. MN * num_attention_masks * activation_dtype_size,
  642. 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
  643. M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
  644. ])
  645. but is broken down below.
  646. """
  647. num_pages = num_blocks * self.block_size
  648. cache_memory_footprint = 2 * self.group_size * num_pages * self.page_size * cache_dtype.itemsize
  649. activation_memory_footprint = self.peak_activation_per_token * self._activation_dtype.itemsize
  650. activation_memory_footprint *= max_batch_tokens
  651. inputs_outputs_positions_and_logits_memory_footprint = 4 * max_batch_tokens * 4 # second 4 is for int32 size
  652. attention_memory_footprint = self.num_attention_masks * self._activation_dtype.itemsize
  653. attention_memory_footprint *= num_pages * max_batch_tokens
  654. cumulative_seqlens_memory_footprint = 3 * max_batch_tokens * 4 # 4 is for int32 size
  655. write_index_memory_footprint = self.num_groups * max_batch_tokens * 4 # 4 is for int32 size
  656. read_index_memory_footprint = self.num_groups * (num_pages + max_batch_tokens) * 4 # 4 is for int32 size
  657. total_memory_footprint = sum(
  658. [
  659. cache_memory_footprint,
  660. activation_memory_footprint,
  661. inputs_outputs_positions_and_logits_memory_footprint,
  662. attention_memory_footprint,
  663. cumulative_seqlens_memory_footprint,
  664. write_index_memory_footprint,
  665. read_index_memory_footprint,
  666. ]
  667. )
  668. return total_memory_footprint