modeling_flash_attention_utils.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. # Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import inspect
  16. import os
  17. from collections.abc import Callable
  18. from functools import partial
  19. from typing import TypedDict
  20. import torch
  21. import torch.nn.functional as F
  22. from .utils import (
  23. is_flash_attn_2_available,
  24. is_flash_attn_3_available,
  25. is_flash_attn_4_available,
  26. is_torch_cuda_available,
  27. is_torch_mlu_available,
  28. is_torch_npu_available,
  29. is_torch_xpu_available,
  30. logging,
  31. )
  32. from .utils.import_utils import PACKAGE_DISTRIBUTION_MAPPING, is_tracing
# Module-level logger, obtained through the package's own `logging` utilities.
logger = logging.get_logger(__name__)
  34. # TODO Deprecate when all models have the attention interface
  35. def flash_attn_supports_top_left_mask():
  36. if is_flash_attn_2_available() or is_flash_attn_3_available() or is_flash_attn_4_available():
  37. return False
  38. from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
  39. return is_npu_fa2_top_left_aligned_causal_mask()
  40. # TODO Deprecate when all models have the attention interface
  41. def is_flash_attn_available():
  42. return (
  43. is_flash_attn_4_available()
  44. or is_flash_attn_3_available()
  45. or is_flash_attn_2_available()
  46. or is_torch_npu_available()
  47. or is_torch_xpu_available()
  48. )
# Mapping from flash attention implementations to their kernel fallback repositories on the Hub.
# In the kernels-fallback branch of `_lazy_imports`, names that are not keys here are used unchanged,
# i.e. they are treated directly as a kernel repository.
FLASH_ATTN_KERNEL_FALLBACK = {
    "flash_attention_2": "kernels-community/flash-attn2",
    "flash_attention_3": "kernels-community/vllm-flash-attn3",
    "flash_attention_4": "kernels-community/flash-attn4",
}
  55. # Meta information on each mainline FA compatibility:
  56. # 1. The import structure and availability
  57. # 2. Device support (with custom ones that use other workarounds, e.g. kernels)
  58. # 3. Supported major cuda devices, e.g. Hopper, Blackwell. Mostly found in the newest FA versions
  59. FLASH_ATTENTION_COMPATIBILITY_MATRIX = {
  60. 2: {
  61. "flash_attn_version": 2,
  62. "general_availability_check": is_flash_attn_2_available,
  63. "pkg_availability_check": lambda *args, **kwargs: importlib.util.find_spec("flash_attn") is not None
  64. and "flash-attn" in [pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"]],
  65. "supported_devices": (
  66. (is_torch_cuda_available, "cuda"),
  67. (is_torch_mlu_available, "mlu"),
  68. (is_torch_npu_available, "npu"),
  69. (is_torch_xpu_available, "xpu"),
  70. ),
  71. "custom_supported_devices": (
  72. (is_torch_npu_available, "Detect using FlashAttention2 on Ascend NPU."),
  73. (
  74. is_torch_xpu_available,
  75. f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU.",
  76. ),
  77. ),
  78. },
  79. 3: {
  80. "flash_attn_version": 3,
  81. "general_availability_check": is_flash_attn_3_available,
  82. "pkg_availability_check": lambda *args, **kwargs: importlib.util.find_spec("flash_attn_interface") is not None
  83. and "flash-attn-3" in [pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn_interface"]],
  84. "supported_devices": ((is_torch_cuda_available, "cuda"),),
  85. "cuda_min_major_version": 8, # Ampere
  86. },
  87. 4: {
  88. "flash_attn_version": 4,
  89. "general_availability_check": is_flash_attn_4_available,
  90. "pkg_availability_check": lambda *args, **kwargs: importlib.util.find_spec("flash_attn") is not None
  91. and "flash-attn-4" in [pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"]],
  92. "supported_devices": ((is_torch_cuda_available, "cuda"),),
  93. "cuda_min_major_version": 9, # Hopper
  94. },
  95. }
# `globals()` is not compatible with dynamo, hence we have to define them in global scope ourselves.
# They cache the currently loaded flash attention implementation and its functions, and are
# (re)populated by `lazy_import_flash_attention`.
_loaded_implementation = None
_flash_fn = None
_flash_varlen_fn = None
_flash_with_kvcache_fn = None
_pad_fn = None
_unpad_fn = None
# function that processes kwargs, generalized to handle any supported kwarg within the function
_process_flash_kwargs_fn = None

# exceptions where hf API doesn't match the original flash attention API
# (parameter name of `_process_flash_attention_kwargs` -> parameter name of the flash attention function)
_hf_api_to_flash_mapping = {
    "dropout": "dropout_p",
    "sliding_window": "window_size",
}

# alternative names within the different flash attention APIs, e.g. for attention sinks
_flash_api_alternative_names = {"s_aux": "learnable_sink"}
  112. def _lazy_imports(
  113. implementation: str | None, attention_wrapper: Callable | None = None, allow_all_kernels: bool = False
  114. ):
  115. """
  116. Lazy loads the respective flash attention implementations.
  117. Return:
  118. flash_attn_func: The base flash attention function.
  119. flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
  120. e.g. for padding-free training.
  121. pad_input: The function to pad inputs into one sequence and returning the respective kwargs.
  122. unpad_input: The function to unpad outputs based on the kwargs (from pad_input).
  123. """
  124. is_fa2 = is_flash_attn_2_available()
  125. is_fa3 = is_flash_attn_3_available()
  126. is_fa4 = is_flash_attn_4_available()
  127. pad_input, unpad_input = _pad_input, _unpad_input
  128. is_paged = implementation.startswith("paged|")
  129. implementation = implementation.split("|")[1] if is_paged else implementation
  130. if (implementation == "flash_attention_2" and is_fa2) or (
  131. implementation is None and is_fa2 and not is_fa3 and not is_fa4
  132. ):
  133. from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
  134. from flash_attn.bert_padding import pad_input, unpad_input
  135. elif is_torch_npu_available():
  136. # Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
  137. # Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
  138. from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
  139. from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
  140. from .integrations.npu_flash_attention import npu_flash_attn_with_kvcache as flash_attn_with_kvcache
  141. else:
  142. if implementation == "flash_attention_3" or (implementation is None and is_fa3 and not is_fa4):
  143. from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
  144. elif implementation == "flash_attention_4" or (implementation is None and is_fa4):
  145. from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
  146. flash_attn_with_kvcache = None # not supported yet
  147. # Kernels fallback
  148. else:
  149. from .integrations.hub_kernels import load_and_register_attn_kernel
  150. # Map standard attention names to hub kernel repos
  151. kernel_repo = FLASH_ATTN_KERNEL_FALLBACK.get(implementation, implementation)
  152. # We want to explicitly register the name with `paged|` if found
  153. kernel_implementation = f"paged|{implementation}" if is_paged else kernel_repo
  154. kernel = load_and_register_attn_kernel(
  155. kernel_implementation, attention_wrapper, allow_all_kernels=allow_all_kernels
  156. )
  157. flash_attn_func = getattr(kernel, "flash_attn_func", None)
  158. flash_attn_varlen_func = getattr(kernel, "flash_attn_varlen_func", None)
  159. flash_attn_with_kvcache = getattr(kernel, "flash_attn_with_kvcache", None)
  160. if flash_attn_varlen_func is None:
  161. raise ValueError(
  162. f"Could not find the currently requested flash attention implementation at `{implementation}`."
  163. "Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn2`."
  164. )
  165. if flash_attn_func is None:
  166. logger.warning(
  167. f"The loaded flash attention implementation at `{implementation}` only supports varlen, i.e. "
  168. "it can only be used with continuous batching and does not support the full functionality for "
  169. "the base transformers generation methods."
  170. )
  171. if flash_attn_with_kvcache is None:
  172. logger.warning(
  173. f"The loaded flash attention implementation at `{implementation}` does not support block tables, so"
  174. " the full performances of continuous batching will not be achieved, only the varlen path will be "
  175. "used."
  176. )
  177. return flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache, pad_input, unpad_input
  178. def _lazy_define_process_function(flash_function):
  179. """
  180. Depending on the version and kernel some features are not supported. Due to limitations in
  181. `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
  182. within `_process_flash_attention_kwargs`.
  183. NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
  184. This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
  185. """
  186. flash_parameters = inspect.signature(flash_function).parameters
  187. process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters
  188. supports_mapping = {}
  189. for param in process_parameters:
  190. fa_param = _hf_api_to_flash_mapping.get(param, param)
  191. supports_mapping[fa_param] = fa_param in flash_parameters
  192. if (fa_alternative_name := _flash_api_alternative_names.get(param, param)) != fa_param:
  193. supports_mapping[fa_alternative_name] = fa_alternative_name in flash_parameters
  194. return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
  195. def lazy_import_flash_attention(
  196. implementation: str | None, attention_wrapper: Callable | None = None, allow_all_kernels: bool = False
  197. ):
  198. """
  199. Lazily import flash attention and return the respective functions + flags.
  200. NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
  201. work without preloading. See `load_and_register_attn_kernel` in `integrations.hub_kernels`.
  202. """
  203. global _loaded_implementation
  204. if implementation is None and _loaded_implementation is None:
  205. raise ValueError("Could not find any flash attn implementation based on your environment.")
  206. global _flash_fn, _flash_varlen_fn, _flash_with_kvcache_fn, _pad_fn, _unpad_fn, _process_flash_kwargs_fn
  207. if implementation is not None and _loaded_implementation != implementation:
  208. _loaded_implementation = implementation
  209. _flash_fn, _flash_varlen_fn, _flash_with_kvcache_fn, _pad_fn, _unpad_fn = _lazy_imports(
  210. implementation, attention_wrapper, allow_all_kernels=allow_all_kernels
  211. )
  212. _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
  213. return (_flash_fn, _flash_varlen_fn, _flash_with_kvcache_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
  214. def lazy_import_paged_flash_attention(implementation: str | None, allow_all_kernels: bool = False):
  215. """
  216. Same as `lazy_import_flash_attention` but explicitly wrapping it with the paged implementation.
  217. """
  218. from .integrations.flash_paged import paged_attention_forward
  219. (_, flash_attn_varlen_func, flash_attn_with_kvcache_fn, _, _), _ = lazy_import_flash_attention(
  220. implementation, attention_wrapper=paged_attention_forward, allow_all_kernels=allow_all_kernels
  221. )
  222. return flash_attn_varlen_func, flash_attn_with_kvcache_fn
  223. def _index_first_axis(tensor, indices):
  224. """
  225. A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
  226. after flattening the first two dimensions of the tensor. This is functionally equivalent to
  227. FA2's `index_first_axis` and replaces the need to import it.
  228. """
  229. # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
  230. # two dimensions to get (total_tokens, ...) before indexing.
  231. reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
  232. return reshaped_tensor[indices]
  233. def _unpad_input(hidden_states, attention_mask, unused_mask=None):
  234. """
  235. unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
  236. Arguments:
  237. hidden_states: (batch, seqlen, ...)
  238. attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
  239. unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
  240. Return:
  241. hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
  242. indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
  243. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
  244. max_seqlen_in_batch: int
  245. seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
  246. """
  247. all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
  248. seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
  249. used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  250. indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
  251. max_seqlen_in_batch = seqlens_in_batch.max()
  252. cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  253. return (
  254. _index_first_axis(hidden_states, indices),
  255. indices,
  256. cu_seqlens,
  257. max_seqlen_in_batch,
  258. used_seqlens_in_batch,
  259. )
  260. def _pad_input(hidden_states, indices, batch, seqlen):
  261. """
  262. pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
  263. Arguments:
  264. hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
  265. indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
  266. batch: int, batch size for the padded sequence.
  267. seqlen: int, maximum sequence length for the padded sequence.
  268. Return:
  269. hidden_states: (batch, seqlen, ...)
  270. """
  271. dim = hidden_states.shape[1:]
  272. output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
  273. output[indices] = hidden_states
  274. return output.view(batch, seqlen, *dim)
  275. def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
  276. """
  277. Retrieves indexing data required to repad unpadded (ragged) tensors.
  278. Arguments:
  279. attention_mask (`torch.Tensor`):
  280. Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
  281. Return:
  282. indices (`torch.Tensor`):
  283. The indices of non-masked tokens from the flattened input sequence.
  284. cu_seqlens (`torch.Tensor`):
  285. The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
  286. max_seqlen_in_batch (`int`):
  287. Maximum sequence length in batch.
  288. """
  289. seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  290. indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  291. max_seqlen_in_batch = seqlens_in_batch.max()
  292. cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  293. return (
  294. indices,
  295. cu_seqlens,
  296. max_seqlen_in_batch,
  297. )
  298. def _upad_input(
  299. query_layer: torch.Tensor,
  300. key_layer: torch.Tensor,
  301. value_layer: torch.Tensor,
  302. attention_mask: torch.Tensor,
  303. query_length: int,
  304. unpad_input_func,
  305. ):
  306. """
  307. Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
  308. This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
  309. tensors for query, key, value tensors.
  310. Arguments:
  311. query_layer (`torch.Tensor`):
  312. Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
  313. key_layer (`torch.Tensor`):
  314. Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
  315. value_layer (`torch.Tensor`):
  316. Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
  317. attention_mask (`torch.Tensor`):
  318. Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
  319. query_length (`int`):
  320. Target length.
  321. unpad_input_func:
  322. The function to use for unpadding the input tensors.
  323. Return:
  324. query_layer (`torch.Tensor`):
  325. Query state without padding. Shape: (total_target_length, num_heads, head_dim).
  326. key_layer (`torch.Tensor`):
  327. Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
  328. value_layer (`torch.Tensor`):
  329. Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
  330. indices_q (`torch.Tensor`):
  331. The indices of non-masked tokens from the flattened input target sequence.
  332. (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
  333. The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
  334. (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
  335. Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
  336. """
  337. indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
  338. # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage
  339. # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores
  340. if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
  341. key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
  342. batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
  343. key_layer = _index_first_axis(key_layer, indices_k)
  344. value_layer = _index_first_axis(value_layer, indices_k)
  345. if query_length == kv_seq_len:
  346. query_layer = _index_first_axis(query_layer, indices_k)
  347. cu_seqlens_q = cu_seqlens_k
  348. max_seqlen_in_batch_q = max_seqlen_in_batch_k
  349. indices_q = indices_k
  350. elif query_length == 1:
  351. max_seqlen_in_batch_q = 1
  352. cu_seqlens_q = torch.arange(
  353. batch_size + 1, dtype=torch.int32, device=query_layer.device
  354. ) # There is a memcpy here, that is very bad.
  355. indices_q = cu_seqlens_q[:-1]
  356. query_layer = query_layer.squeeze(1)
  357. else:
  358. # The -q_len: slice assumes left padding.
  359. attention_mask = attention_mask[:, -query_length:]
  360. query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
  361. return (
  362. query_layer,
  363. key_layer,
  364. value_layer,
  365. indices_q,
  366. (cu_seqlens_q, cu_seqlens_k),
  367. (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
  368. )
  369. def prepare_fa_kwargs_from_position_ids(position_ids):
  370. """
  371. This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids.
  372. Arguments:
  373. position_ids (`torch.Tensor`):
  374. Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
  375. Return:
  376. (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
  377. The cumulative sequence lengths for the target (query) and source (key, value), used to index into
  378. ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
  379. (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
  380. Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
  381. `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
  382. """
  383. tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
  384. position_ids = position_ids.reshape(-1)
  385. indices_q = (position_ids == 0).nonzero().view(-1)
  386. cu_seq_lens_q = torch.cat(
  387. (
  388. indices_q.to(**tensor_kwargs),
  389. torch.tensor(position_ids.size(), **tensor_kwargs),
  390. )
  391. )
  392. cu_seq_lens_k = cu_seq_lens_q
  393. # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
  394. # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
  395. # for some models (e.g. qwen2-vl).
  396. max_length_q = cu_seq_lens_q.diff().max()
  397. max_length_k = max_length_q
  398. return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
  399. def _prepare_from_posids(query, key, value, position_ids):
  400. """
  401. This function returns necessary arguments to call `flash_attn_varlen_func`.
  402. All three query, key, value states will be flattened.
  403. Cumulative lengths of each examples in the batch will be extracted from position_ids.
  404. NOTE: ideally cumulative lengths should be prepared at the data collator stage
  405. Arguments:
  406. query (`torch.Tensor`):
  407. Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
  408. key (`torch.Tensor`):
  409. Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
  410. value (`torch.Tensor`):
  411. Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
  412. position_ids (`torch.Tensor`):
  413. Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
  414. Return:
  415. query (`torch.Tensor`):
  416. Query state without padding. Shape: (total_target_length, num_heads, head_dim).
  417. key (`torch.Tensor`):
  418. Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
  419. value (`torch.Tensor`):
  420. Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
  421. (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
  422. The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
  423. (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
  424. Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
  425. """
  426. query = query.contiguous().view(-1, query.size(-2), query.size(-1))
  427. key = key.contiguous().view(-1, key.size(-2), key.size(-1))
  428. value = value.contiguous().view(-1, value.size(-2), value.size(-1))
  429. (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(position_ids)
  430. return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))
  431. def _is_packed_sequence(position_ids, batch_size):
  432. """
  433. Check the position ids whether packed sequences are indicated or not
  434. 1. Position ids exist
  435. 2. Flattened sequences only are supported
  436. 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
  437. """
  438. if position_ids is None:
  439. return False
  440. increasing_position_sequences = (
  441. torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
  442. )
  443. return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
  444. def fa_peft_integration_check(
  445. q: torch.Tensor,
  446. k: torch.Tensor,
  447. v: torch.Tensor,
  448. target_dtype: torch.dtype | None = None,
  449. ):
  450. """
  451. PEFT usually casts the layer norms in float32 for training stability reasons
  452. therefore the input hidden states gets silently casted in float32. Hence, we need
  453. cast them back in float16 / bfloat16 just to be sure everything works as expected.
  454. This might slowdown training & inference so it is recommended to not cast the LayerNorms!
  455. """
  456. if target_dtype and q.dtype == torch.float32:
  457. logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
  458. q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
  459. return q, k, v
class FlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for Flash Attention with Compile.

    They carry pre-computed varlen metadata (cumulative / maximum sequence lengths), presumably so that it can
    be passed straight to the varlen flash attention path. TODO confirm against the callers.

    NOTE(review): The cumulative lengths are annotated as `torch.LongTensor`, but
    `prepare_fa_kwargs_from_position_ids` in this module builds them as int32 tensors. It also returns the
    maximum lengths as 0-dim tensors rather than `int`.

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for query state.
        cu_seq_lens_k (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for key state.
        max_length_q (`int`, *optional*):
            Maximum sequence length for query state.
        max_length_k (`int`, *optional*):
            Maximum sequence length for key state.
    """

    cu_seq_lens_q: torch.LongTensor | None
    cu_seq_lens_k: torch.LongTensor | None
    max_length_q: int | None
    max_length_k: int | None
  477. def _process_flash_attention_kwargs(
  478. query_length: int,
  479. key_length: int,
  480. is_causal: bool,
  481. dropout: float = 0.0,
  482. softmax_scale: float | None = None,
  483. sliding_window: int | None = None,
  484. use_top_left_mask: bool = False,
  485. softcap: float | None = None,
  486. deterministic: bool | None = None,
  487. s_aux: torch.Tensor | None = None,
  488. max_seqlen_q: int | torch.IntTensor | None = None,
  489. max_seqlen_k: int | torch.IntTensor | None = None,
  490. supports_mapping: dict[str, bool] | None = None,
  491. **kwargs,
  492. ):
  493. """
  494. Returns a set of kwargs that are passed down to the according flash attention function based on
  495. requested features and whether it is supported - depends on the version and kernel implementation
  496. which is dynamically configured at `lazy_import_flash_attention`. The (un)supported features can be
  497. inspected in `supports_mapping`, see `_lazy_define_process_function` for more details.
  498. Args:
  499. query_length (`int`):
  500. Length of the query states
  501. key_length (`int`):
  502. Length of the key states
  503. is_causal (`bool`):
  504. Whether we perform causal (decoder) attention or full attention.
  505. dropout (`float`):
  506. Attention dropout.
  507. softmax_scale (`float`, *optional*):
  508. The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`.
  509. sliding_window (`int`, *optional*):
  510. The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
  511. use_top_left_mask (`bool`):
  512. Deprecated behavior of older versions of flash attention requiring different masking.
  513. softcap (`float`, *optional*):
  514. Softcap for the attention logits, used e.g. in gemma2.
  515. deterministic (`bool`, *optional*):
  516. Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
  517. s_aux (`torch.Tensor`, *optional*):
  518. Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
  519. max_seqlen_q (`Union[int, torch.IntTensor]`, *optional*):
  520. The maximum sequence length in the query tensor during a varlen forward.
  521. max_seqlen_k (`Union[int, torch.IntTensor]`, *optional*):
  522. The maximum sequence length in the key/value tensor during a varlen forward.
  523. Return:
  524. flash_kwargs (`dict`):
  525. A dict of kwargs that are requested and supported.
  526. """
  527. flash_kwargs = {
  528. "causal": is_causal and not (use_top_left_mask and query_length == 1),
  529. "softmax_scale": softmax_scale,
  530. }
  531. if supports_mapping["dropout_p"]:
  532. flash_kwargs["dropout_p"] = dropout
  533. if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window:
  534. # The flash attention API sets inclusive boundaries, i.e. (4, 0) would take 4 tokens to the left
  535. # and the current token for a total size of 5. However, we usually define our window sizes by
  536. # their total window size (when causal). Encoder models as of now seldom use SWA and when they
  537. # do, they must align with this symmetric logic, i.e. for a total of `2*sliding_window + 1`.
  538. flash_kwargs["window_size"] = (sliding_window - 1, sliding_window - 1)
  539. if supports_mapping["deterministic"]:
  540. flash_kwargs["deterministic"] = (
  541. deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
  542. )
  543. if supports_mapping["softcap"] and softcap is not None:
  544. flash_kwargs["softcap"] = softcap
  545. if ((legacy_sink_param := supports_mapping["s_aux"]) or supports_mapping["learnable_sink"]) and s_aux is not None:
  546. if legacy_sink_param:
  547. flash_kwargs["s_aux"] = s_aux # e.g. FA3 (vllm)
  548. else:
  549. flash_kwargs["learnable_sink"] = s_aux # FA4
  550. # There is a limitation of the flash attention API, as the function `flash_attn_varlen_func`
  551. # may require `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
  552. #
  553. # You can either set
  554. # - Env: `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`
  555. # - Before compiling: `torch._dynamo.config.capture_scalar_outputs = True`
  556. # to allow torch compile to handle scalar outputs in those cases.
  557. same_max_seqlen = max_seqlen_q is max_seqlen_k # to avoid 2x device syncs
  558. if supports_mapping["max_seqlen_q"] and max_seqlen_q is not None:
  559. if not isinstance(max_seqlen_q, int) and is_tracing(max_seqlen_q):
  560. max_seqlen_q = max_seqlen_q.item()
  561. flash_kwargs["max_seqlen_q"] = max_seqlen_q
  562. if supports_mapping["max_seqlen_k"] and max_seqlen_k is not None:
  563. if same_max_seqlen and flash_kwargs["max_seqlen_q"] is not None:
  564. max_seqlen_k = flash_kwargs["max_seqlen_q"]
  565. elif not isinstance(max_seqlen_k, int) and is_tracing(max_seqlen_k):
  566. max_seqlen_k = max_seqlen_k.item()
  567. flash_kwargs["max_seqlen_k"] = max_seqlen_k
  568. return flash_kwargs
  569. def _flash_attention_forward(
  570. query_states: torch.Tensor,
  571. key_states: torch.Tensor,
  572. value_states: torch.Tensor,
  573. attention_mask: torch.Tensor | None,
  574. query_length: int,
  575. is_causal: bool,
  576. dropout: float = 0.0,
  577. position_ids: torch.Tensor | None = None,
  578. softmax_scale: float | None = None,
  579. sliding_window: int | None = None,
  580. use_top_left_mask: bool = False,
  581. softcap: float | None = None,
  582. deterministic: bool | None = None,
  583. cu_seq_lens_q: torch.LongTensor | None = None,
  584. cu_seq_lens_k: torch.LongTensor | None = None,
  585. max_length_q: int | None = None,
  586. max_length_k: int | None = None,
  587. target_dtype: torch.dtype | None = None,
  588. attn_implementation: str | None = None,
  589. **kwargs,
  590. ):
  591. """
  592. Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
  593. first unpad the input, then computes the attention scores and pad the final attention scores.
  594. (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`.
  595. Args:
  596. query_states (`torch.Tensor`):
  597. Input query states to be passed to Flash Attention API
  598. key_states (`torch.Tensor`):
  599. Input key states to be passed to Flash Attention API
  600. value_states (`torch.Tensor`):
  601. Input value states to be passed to Flash Attention API
  602. attention_mask (`torch.Tensor`, *optional*):
  603. The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
  604. position of padding tokens and 1 for the position of non-padding tokens.
  605. attn_implementation (`str`, *optional*):
  606. The attention implementation to use. If None, will default to the one based on the environment.
  607. """
  608. (flash_fn, flash_varlen_fn, _, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
  609. attn_implementation
  610. )
  611. # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
  612. query_states, key_states, value_states = fa_peft_integration_check(
  613. query_states, key_states, value_states, target_dtype
  614. )
  615. # Extract the flash attention kwargs that have been requested (and are supported by the implementation)
  616. flash_kwargs = partial(
  617. process_flash_kwargs_fn,
  618. query_length=query_length,
  619. key_length=key_states.size(1),
  620. is_causal=is_causal,
  621. dropout=dropout,
  622. softmax_scale=softmax_scale,
  623. sliding_window=sliding_window,
  624. use_top_left_mask=use_top_left_mask,
  625. softcap=softcap,
  626. deterministic=deterministic,
  627. **kwargs,
  628. )
  629. # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
  630. # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
  631. # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
  632. # use `flash_varlen_fn` knowing we already have all necessary the kwargs.
  633. #
  634. # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
  635. # See #39121 for more information.
  636. is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
  637. is_fa_with_varlen_kwargs = all(
  638. kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
  639. )
  640. # Contains at least one padding token in the sequence
  641. if attention_mask is not None:
  642. q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
  643. query_states, key_states, value_states, attention_mask, query_length, unpad_fn
  644. )
  645. # TODO for now this is required to work with
  646. # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
  647. if "mps" in str(q.device):
  648. cu_seq_lens_k = cu_seq_lens_k.clone()
  649. out_unpad = flash_varlen_fn(
  650. q,
  651. k,
  652. v,
  653. cu_seqlens_q=cu_seq_lens_q,
  654. cu_seqlens_k=cu_seq_lens_k,
  655. **flash_kwargs(max_seqlen_q=max_length_q, max_seqlen_k=max_length_k),
  656. )
  657. if isinstance(out_unpad, tuple):
  658. out_unpad = out_unpad[0]
  659. out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
  660. # Padding free, i.e. sequences flattened into one total sequence
  661. elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
  662. if cu_seq_lens_q is None or cu_seq_lens_k is None:
  663. q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
  664. query_states, key_states, value_states, position_ids
  665. )
  666. else:
  667. q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
  668. k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
  669. v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
  670. # TODO for now this is required to work with
  671. # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
  672. if "mps" in str(q.device):
  673. cu_seq_lens_k = cu_seq_lens_k.clone()
  674. out = flash_varlen_fn(
  675. q,
  676. k,
  677. v,
  678. cu_seqlens_q=cu_seq_lens_q,
  679. cu_seqlens_k=cu_seq_lens_k,
  680. **flash_kwargs(max_seqlen_q=max_length_q, max_seqlen_k=max_length_k),
  681. )
  682. if isinstance(out, tuple):
  683. out = out[0]
  684. out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1))
  685. # No padding
  686. else:
  687. out = flash_fn(query_states, key_states, value_states, **flash_kwargs())
  688. if isinstance(out, tuple):
  689. out = out[0]
  690. return out