utils.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright 2026 The HuggingFace Inc. team
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections import OrderedDict
  15. from math import ceil
  16. from typing import Any
  17. import torch
  18. from transformers.configuration_utils import PretrainedConfig
  19. from .requests import FutureRequestState, RequestState, RequestStatus, logger
  20. class CudaGraphBuffer:
  21. """A fixed-size dict for CUDA graphs with LRU eviction when full."""
  22. def __init__(self, max_size: int) -> None:
  23. if max_size <= 0:
  24. raise ValueError(f"max_size must be positive, but got {max_size}")
  25. self.max_size = max_size
  26. self._storage: OrderedDict[tuple[int, int], torch.cuda.CUDAGraph] = OrderedDict()
  27. def __del__(self) -> None:
  28. original_max_size = self.max_size
  29. self.max_size = 1 # 0 would cause an infinite loop, 1 is enough to clear all graphs
  30. self.plan_for_new_graph(silent=True)
  31. self.max_size = original_max_size
  32. def get_graph(self, q_len: int, kv_len: int) -> torch.cuda.CUDAGraph | None:
  33. graph = self._storage.get((q_len, kv_len))
  34. if graph is not None:
  35. self._storage.move_to_end((q_len, kv_len))
  36. return graph
  37. def plan_for_new_graph(self, silent: bool = False) -> None:
  38. while len(self._storage) >= self.max_size:
  39. evicted_key, evicted_graph = self._storage.popitem(last=False)
  40. if not silent:
  41. logger.info(f"Evicting graph for {evicted_key = }")
  42. evicted_graph.reset()
  43. def set_graph(self, q_len: int, kv_len: int, graph: torch.cuda.CUDAGraph) -> None:
  44. # In our use case, this should not have any effect because we plan for a new graph before it is captured
  45. self.plan_for_new_graph()
  46. logger.info(f"Setting graph for {q_len = }, {kv_len = }")
  47. self._storage[(q_len, kv_len)] = graph
  48. def attn_mask_is_needed(config: PretrainedConfig) -> bool:
  49. """Checks if attention mask is needed for the given (config)."""
  50. return config._attn_implementation in ["paged|eager", "paged|sdpa"]
  51. def pad_to_interval(size: int, interval_size: int, max_value: int) -> int:
  52. """Return the smallest multiple of (interval_size) >= (size), capped at (max_value)."""
  53. if interval_size <= 0:
  54. return max_value
  55. padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
  56. return min(padded, max_value)
  57. def aligned_divide(x: int, divide_by: int, align_to: int) -> int:
  58. x = int(ceil(x / divide_by))
  59. if x % align_to:
  60. x += align_to - (x % align_to)
  61. return x
  62. def build_attention_mask(
  63. attention_mask: torch.Tensor,
  64. cumulative_seqlens_q: list[int],
  65. cumulative_seqlens_k: list[int],
  66. sliding_window: int = 1,
  67. ) -> None:
  68. """Builds an attention mask inplace using the cumulative seqlens of the query and key. If given a sliding window, it
  69. will also apply a sliding window mask on top. The attention mask is not boolean, it uses zeroes and -inf (or its
  70. equivalent) so it's more of an attention score bias tensor.
  71. The attention mask is a block-diagonal matrix, with each block an attention mask for a single query-key pair.
  72. Each of those block is built from a causal mask and, if there is a sliding window, a sliding window mask.
  73. An example is represented below, with seqlen_k = 8, seqlen_q = 4 and sliding_window = 6:
  74. CAUSAL MASK:
  75. █ █ █ █ █ ░ ░ ░
  76. █ █ █ █ █ █ ░ ░
  77. █ █ █ █ █ █ █ ░
  78. █ █ █ █ █ █ █ █
  79. SLIDING WINDOW MASK:
  80. ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 8 - 4 - 6 = -2 offset to the left
  81. <─┴─>
  82. ░ █ | █ █ █ █ █ █ █ █
  83. ░ ░ | █ █ █ █ █ █ █ █
  84. ░ ░ | ░ █ █ █ █ █ █ █
  85. ░ ░ | ░ ░ █ █ █ █ █ █
  86. ATTENTION MASK (sum of causal and sliding window masks):
  87. █ █ █ █ █ ░ ░ ░
  88. █ █ █ █ █ █ ░ ░
  89. ░ █ █ █ █ █ █ ░
  90. ░ ░ █ █ █ █ █ █
  91. Another example with seqlen_k = 5, seqlen_q = 3 and sliding_window = 2:
  92. CAUSAL MASK:
  93. █ █ █ ░ ░
  94. █ █ █ █ ░
  95. █ █ █ █ █
  96. SLIDING WINDOW MASK:
  97. ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 5 - 3 - 2 = 0 offset to the left
  98. <┴>
  99. | ░ █ █ █ █
  100. | ░ ░ █ █ █
  101. | ░ ░ ░ █ █
  102. ATTENTION MASK (sum of causal and sliding window masks):
  103. ░ █ █ ░ ░
  104. ░ ░ █ █ ░
  105. ░ ░ ░ █ █
  106. """
  107. min_value = torch.finfo(attention_mask.dtype).min
  108. for i in range(len(cumulative_seqlens_q) - 1):
  109. seqlen_q = cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
  110. seqlen_k = cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i]
  111. if seqlen_q < seqlen_k and seqlen_q >= 1:
  112. causal_diagonal = seqlen_k - seqlen_q + 1
  113. else:
  114. causal_diagonal = 1
  115. query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1])
  116. key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1])
  117. # Apply causal mask
  118. minus_inf = torch.full(
  119. attention_mask[..., query_range, key_range].shape,
  120. min_value,
  121. dtype=attention_mask.dtype,
  122. device=attention_mask.device,
  123. )
  124. masked = torch.triu(minus_inf, diagonal=causal_diagonal)
  125. # Apply sliding window mask if needed
  126. if sliding_window > 1:
  127. sliding_diagonal = seqlen_k - seqlen_q - sliding_window
  128. masked += torch.tril(minus_inf, diagonal=sliding_diagonal)
  129. # Replace in attention mask
  130. attention_mask[..., query_range, key_range] = masked
  131. def create_warmup_future_states(
  132. num: int,
  133. status: RequestStatus,
  134. num_query_tokens: int,
  135. num_cache_tokens: int,
  136. cache: Any, # not annotated to avoid circular import
  137. ) -> list[FutureRequestState]:
  138. """An utility function to create a list of FutureRequestStates for the warmup of CB."""
  139. # Setup
  140. request_ids = [f"__warmup_{status.name}_{i}__" for i in range(num)]
  141. total_tokens = num_query_tokens + num_cache_tokens
  142. blocks_needed = ceil(total_tokens / cache.block_size)
  143. # Main loop
  144. future_states = []
  145. for req_id in request_ids:
  146. state = RequestState(request_id=req_id, initial_tokens=[0] * total_tokens, max_new_tokens=1)
  147. state._status = status # bypass the property setter to avoid the lifecycle side effects
  148. state.tokens_to_process = [0] * num_query_tokens
  149. state.position_offset = num_cache_tokens
  150. # Stop if allocation fails for any request
  151. allocated = cache.allocate_blocks(blocks_needed, state.request_id, 0)
  152. if allocated is None:
  153. return future_states
  154. future_states.append(FutureRequestState(state, has_new_token=True, complete_blocks=0))
  155. return future_states