flex_attention.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. """
  2. Partially inspired by torchtune's flex attention implementation
  3. Citation:
  4. @software{torchtune,
  5. title = {torchtune: PyTorch's finetuning library},
  6. author = {torchtune maintainers and contributors},
  7. url = {https://github.com/pytorch/torchtune},
  8. license = {BSD-3-Clause},
  9. month = apr,
  10. year = {2024}
  11. }
  12. """
  13. # coding=utf-8
  14. # Copyright 2025 The HuggingFace Inc. team.
  15. #
  16. # Licensed under the Apache License, Version 2.0 (the "License");
  17. # you may not use this file except in compliance with the License.
  18. # You may obtain a copy of the License at
  19. #
  20. # http://www.apache.org/licenses/LICENSE-2.0
  21. #
  22. # Unless required by applicable law or agreed to in writing, software
  23. # distributed under the License is distributed on an "AS IS" BASIS,
  24. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  25. # See the License for the specific language governing permissions and
  26. # limitations under the License.
  27. from typing import Optional, Union
  28. import torch
  29. from packaging import version
  30. from ..utils import is_torch_flex_attn_available, logging
  31. from ..utils.import_utils import (
  32. get_torch_version,
  33. is_torch_greater_or_equal,
  34. is_torch_less_or_equal,
  35. is_torchdynamo_compiling,
  36. )
# Starting with torch 2.9, flex_attention reports auxiliary outputs such as the LSE through an
# `AuxRequest` object instead of the boolean `return_lse` flag (see `get_flex_attention_lse_kwargs`).
_TORCH_FLEX_USE_AUX = is_torch_greater_or_equal("2.9.0")
if is_torch_flex_attn_available():
    # `flex_default_block_size` is used to pad masks to a block multiple in `make_flex_block_causal_mask`
    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

    # NOTE(review): the original indentation was lost; this if/else is assumed to be nested under the
    # availability check above — confirm against upstream.
    if _TORCH_FLEX_USE_AUX:
        from torch.nn.attention.flex_attention import AuxRequest
    else:
        # Placeholder so the name is defined on torch < 2.9; `get_flex_attention_lse_kwargs` only calls
        # `AuxRequest` when `_TORCH_FLEX_USE_AUX` is True.
        AuxRequest = None
logger = logging.get_logger(__name__)
  46. class WrappedFlexAttention:
  47. """
  48. We are doing a singleton class so that flex attention is compiled once when it's first called.
  49. """
  50. _instance = None
  51. _is_flex_compiled = False
  52. _compiled_flex_attention = None
  53. def __new__(cls, *args, **kwargs):
  54. if cls._instance is None:
  55. # Create a new instance if one doesn't already exist
  56. cls._instance = super().__new__(cls)
  57. return cls._instance
  58. @torch.compiler.disable(recursive=False)
  59. def __init__(self, training):
  60. """
  61. Initialize or update the singleton instance.
  62. """
  63. if not self._is_flex_compiled or training != self.training:
  64. self.training = training
  65. if is_torch_less_or_equal("2.5.1"):
  66. self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
  67. # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
  68. # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
  69. # see https://github.com/pytorch/pytorch/issues/146260 for training
  70. elif version.parse(get_torch_version()).base_version == "2.6.0" and training:
  71. self._compiled_flex_attention = torch.compile(
  72. flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
  73. )
  74. # Fallback, usually the most recent torch 2.7.x+ versions
  75. else:
  76. self._compiled_flex_attention = torch.compile(flex_attention)
  77. self._is_flex_compiled = True
  78. def __call__(self):
  79. return self._compiled_flex_attention
  80. def get_flex_attention_lse_kwargs(return_lse: bool) -> dict[str, bool | Optional["AuxRequest"]]:
  81. """
  82. Requests the LSE from flex_attention in a version-agnostic fashion.
  83. Before torch 2.9, the LSE was requested via the boolean return_lse field. However, starting with
  84. torch 2.9, an AuxRequest object must be passed via the aux_request field. This method conditionally
  85. returns the correct form based on the python version.
  86. """
  87. if _TORCH_FLEX_USE_AUX:
  88. return {"return_aux": AuxRequest(lse=True) if return_lse else None}
  89. return {"return_lse": return_lse}
  90. def compile_friendly_flex_attention(
  91. query: torch.Tensor,
  92. key: torch.Tensor,
  93. value: torch.Tensor,
  94. training=False,
  95. **kwargs,
  96. ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
  97. # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
  98. # Do not use compiled version if already compiling forward (it raises issues)
  99. flex_attention_compiled = WrappedFlexAttention(training)() if not is_torchdynamo_compiling() else flex_attention
  100. return flex_attention_compiled(
  101. query,
  102. key,
  103. value,
  104. **kwargs,
  105. )
  106. Offset = torch.Tensor | int
  107. # TODO: deprecate / rename to make_flex_block_mask for clarity as it's not only causal anymore
  108. def make_flex_block_causal_mask(
  109. attention_mask_2d: torch.Tensor,
  110. attention_chunk_size: int | None = None,
  111. query_length=None,
  112. key_length=None,
  113. offsets: tuple[Offset, Offset] | None = None,
  114. is_causal: bool | None = True,
  115. ) -> "BlockMask":
  116. """
  117. IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`,
  118. and will be removed in a future version without warnings. New code should not use it. It is only kept here
  119. for BC for now, while models using it are being patched accordingly.
  120. Create a block (causal) document mask for a batch of sequences, both packed and unpacked.
  121. Create Block (causal) logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
  122. The resultant BlockMask is a compressed representation of the full (causal) block
  123. mask. BlockMask is essential for performant computation of flex attention.
  124. See: https://pytorch.org/blog/flexattention/
  125. Args:
  126. attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
  127. of shape (batch_size, total_seq_len). e.g.
  128. For unpacked sequence:
  129. [[1, 1, 1, 1, 0, 0, 0],
  130. [1, 1, 1, 1, 1, 0, 0]]
  131. For packed sequence:
  132. [[1, 1, 1, 2, 2, 2, 0],
  133. [1, 1, 2, 2, 2, 3, 3]]
  134. Returns:
  135. BlockMask
  136. """
  137. batch_size, total_seq_len = attention_mask_2d.shape
  138. if not key_length:
  139. key_length = total_seq_len
  140. if not query_length:
  141. query_length = total_seq_len
  142. # older torch (2.5.x) cannot handle sequences not in multiples of 128 (default block size)
  143. pad_len = ((key_length // flex_default_block_size) + 1) * flex_default_block_size
  144. attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, pad_len - key_length))
  145. device = attention_mask_2d.device
  146. document_ids = attention_mask_2d.clone()
  147. if attention_chunk_size is not None:
  148. # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
  149. chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
  150. # Instead of passing a tensor mask, flex attention requires a mask_mod function
  151. # that determines which elements of QK^T should be included in the attention
  152. # computation prior to the softmax. For sample packing, we need both the
  153. # logic for both causal mask and document mask. See PyTorch's official
  154. # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
  155. def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
  156. """
  157. Defines the logic of a block causal mask by combining both a standard causal mask
  158. and a block diagonal document mask.
  159. See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
  160. for an illustration.
  161. """
  162. causal_mask = q_idx >= kv_idx # not valid when decoding
  163. document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
  164. padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
  165. final_mask = causal_mask & padding_mask & document_mask
  166. return final_mask
  167. def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
  168. """
  169. Combines the chunk mask with the causal mask for chunked attention.
  170. """
  171. chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
  172. causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
  173. return chunk_mask & causal_doc_mask
  174. def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
  175. """
  176. Utilizes default attention mask to enable encoder and encoder-decoder
  177. attention masks.
  178. """
  179. document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
  180. # kv indexing is crucial in order to work correctly
  181. padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0
  182. final_mask = padding_mask & document_mask
  183. return final_mask
  184. if not is_causal:
  185. mask_mod_maybe_combined = default_mask_mod
  186. else:
  187. mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
  188. if offsets is not None:
  189. q_offset = offsets[0].to(device)
  190. kv_offset = offsets[1].to(device)
  191. def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
  192. offset_q = q_idx + q_offset
  193. offset_kv = kv_idx + kv_offset
  194. return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
  195. else:
  196. mask_mod = mask_mod_maybe_combined
  197. return create_block_mask(
  198. mask_mod=mask_mod,
  199. B=batch_size,
  200. H=None, # attention head
  201. Q_LEN=query_length,
  202. KV_LEN=key_length,
  203. device=device,
  204. # compiling the mask is not BC with older torch
  205. _compile=not is_torch_less_or_equal("2.5.1"),
  206. )
  207. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  208. """
  209. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  210. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  211. """
  212. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  213. if n_rep == 1:
  214. return hidden_states
  215. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  216. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  217. def flex_attention_forward(
  218. module: torch.nn.Module,
  219. query: torch.Tensor,
  220. key: torch.Tensor,
  221. value: torch.Tensor,
  222. attention_mask: Union[torch.Tensor, "BlockMask"],
  223. scaling: float | None = None,
  224. softcap: float | None = None,
  225. s_aux: torch.Tensor | None = None,
  226. **kwargs,
  227. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  228. if kwargs.get("dropout", 0.0) > 0:
  229. raise ValueError(
  230. "`flex_attention` does not support `dropout`. Please use it with inference"
  231. " only (`model.eval()`) or turn off the attention dropout in the respective config."
  232. )
  233. block_mask = None
  234. score_mask = None
  235. if isinstance(attention_mask, BlockMask):
  236. block_mask = attention_mask
  237. else:
  238. score_mask = attention_mask
  239. if score_mask is not None:
  240. score_mask = score_mask[:, :, :, : key.shape[-2]]
  241. def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
  242. if softcap is not None:
  243. score = softcap * torch.tanh(score / softcap)
  244. if score_mask is not None:
  245. score = score + score_mask[batch_idx][0][q_idx][kv_idx]
  246. # Note: attention sinks cannot be correctly implemented in score_mod
  247. # because it requires operating on the full attention matrix before softmax.
  248. # ==> this is done after flex attention
  249. return score
  250. enable_gqa = True
  251. num_local_query_heads = query.shape[1]
  252. # When running TP this helps:
  253. if (num_local_query_heads & (num_local_query_heads - 1)) != 0:
  254. key = repeat_kv(key, query.shape[1] // key.shape[1])
  255. value = repeat_kv(value, query.shape[1] // value.shape[1])
  256. enable_gqa = False
  257. kernel_options = kwargs.get("kernel_options")
  258. # On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
  259. return_lse = query.device.type != "cpu"
  260. if not return_lse and s_aux is not None:
  261. raise ValueError(
  262. "Attention sinks cannot be run on CPU with flex attention. Please switch to a different device, e.g. CUDA"
  263. )
  264. flex_attention_output = compile_friendly_flex_attention(
  265. query,
  266. key,
  267. value,
  268. score_mod=score_mod,
  269. block_mask=block_mask,
  270. enable_gqa=enable_gqa,
  271. scale=scaling,
  272. kernel_options=kernel_options,
  273. # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
  274. # For simplification, we thus always return it as no additional computations are introduced.
  275. training=module.training,
  276. # inject the lse args
  277. **get_flex_attention_lse_kwargs(return_lse),
  278. )
  279. if return_lse:
  280. # before torch 2.9, return_lse returns the LSE directly as a second tuple element
  281. # in torch 2.9 and later, return_aux returns AuxOutput as a second tuple element -- the LSE must be extracted
  282. if _TORCH_FLEX_USE_AUX:
  283. attention_output, aux = flex_attention_output # type: ignore[misc]
  284. lse = aux.lse
  285. else:
  286. attention_output, lse = flex_attention_output # type: ignore[misc]
  287. # lse is returned in float32
  288. lse = lse.to(value.dtype)
  289. if s_aux is not None:
  290. # Apply attention sinks by renormalizing using LSE
  291. batch_size, num_heads, seq_len_q, _ = attention_output.shape # batch, num_heads, seq_len, head_dim
  292. sinks = s_aux.view(1, -1, 1, 1).expand(batch_size, num_heads, seq_len_q, 1)
  293. # We need to compute the normalization that includes the sinks
  294. # since log(sum(exp(scores))) = lse, exp(log(sum(exp(scores)))) = exp(lse)
  295. # NB: log(sum(exp(scores)) + exp(sink)) = log(exp(lse) + exp(sink))
  296. lse_expanded = lse.unsqueeze(-1) # [batch, num_heads, seq_len, 1]
  297. combined_lse = torch.logsumexp(torch.cat([lse_expanded, sinks], dim=-1), dim=-1, keepdim=True)
  298. # Use new_norm / old_norm = exp(combined_lse - lse) to compute renorm and apply
  299. renorm_factor = torch.exp(lse_expanded - combined_lse)
  300. attention_output = attention_output * renorm_factor
  301. else:
  302. attention_output = flex_attention_output # type: ignore[assignment]
  303. lse = None
  304. attention_output = attention_output.transpose(1, 2).contiguous()
  305. return attention_output, lse