| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364 |
- """
- Partially inspired by torchtune's flex attention implementation
- Citation:
- @software{torchtune,
- title = {torchtune: PyTorch's finetuning library},
- author = {torchtune maintainers and contributors},
- url = {https://github.com/pytorch/torchtune},
- license = {BSD-3-Clause},
- month = apr,
- year = {2024}
- }
- """
- # coding=utf-8
- # Copyright 2025 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Optional, Union
- import torch
- from packaging import version
- from ..utils import is_torch_flex_attn_available, logging
- from ..utils.import_utils import (
- get_torch_version,
- is_torch_greater_or_equal,
- is_torch_less_or_equal,
- is_torchdynamo_compiling,
- )
# Starting with torch 2.9 the LSE is requested from `flex_attention` through an `AuxRequest` object
# rather than the boolean `return_lse` flag; this flag records which form applies.
_TORCH_FLEX_USE_AUX = is_torch_greater_or_equal("2.9.0")

# The flex attention symbols are only imported when the installed torch provides them.
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

    if _TORCH_FLEX_USE_AUX:
        from torch.nn.attention.flex_attention import AuxRequest
    else:
        # Placeholder so that the name is still bound when `AuxRequest` is not imported
        AuxRequest = None


# Module-level logger
logger = logging.get_logger(__name__)
class WrappedFlexAttention:
    """
    Singleton wrapper around the compiled `flex_attention` callable.

    Every construction returns the same object, so that `flex_attention` is compiled the first time the
    wrapper is built and only compiled again when the `training` flag given to the constructor changes.
    """

    _instance = None
    _is_flex_compiled = False
    _compiled_flex_attention = None

    def __new__(cls, *args, **kwargs):
        # Build the shared instance lazily; every later construction hands back that same object
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @torch.compiler.disable(recursive=False)
    def __init__(self, training):
        """
        Compile `flex_attention` on first use, or again when `training` differs from the stored flag.
        """
        # Nothing to do when a compiled callable already exists for the requested mode
        if self._is_flex_compiled and training == self.training:
            return

        self.training = training

        # Select the `torch.compile` options according to the installed torch version
        if is_torch_less_or_equal("2.5.1"):
            compile_options = {"dynamic": False}
        elif version.parse(get_torch_version()).base_version == "2.6.0" and training:
            # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
            # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
            # see https://github.com/pytorch/pytorch/issues/146260 for training
            compile_options = {"dynamic": False, "mode": "max-autotune-no-cudagraphs"}
        else:
            # Fallback, usually the most recent torch 2.7.x+ versions
            compile_options = {}

        self._compiled_flex_attention = torch.compile(flex_attention, **compile_options)
        self._is_flex_compiled = True

    def __call__(self):
        # Hand back the compiled callable itself (not the result of running attention)
        return self._compiled_flex_attention
def get_flex_attention_lse_kwargs(return_lse: bool) -> dict[str, bool | Optional["AuxRequest"]]:
    """
    Build the keyword argument that requests the LSE from `flex_attention`, independent of the torch version.

    Before torch 2.9, the LSE is requested with the boolean `return_lse` argument. Starting with torch 2.9,
    an `AuxRequest` object is passed through the `return_aux` argument instead. Which form is produced
    depends on the installed torch version (`_TORCH_FLEX_USE_AUX`).

    Args:
        return_lse (`bool`):
            Whether the LSE is wanted.

    Returns:
        `dict`: A single-entry dict, either `{"return_aux": AuxRequest(lse=True) or None}` or
        `{"return_lse": return_lse}`, meant to be unpacked into the `flex_attention` call.
    """
    if not _TORCH_FLEX_USE_AUX:
        return {"return_lse": return_lse}

    aux_request = AuxRequest(lse=True) if return_lse else None
    return {"return_aux": aux_request}
def compile_friendly_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    training=False,
    **kwargs,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Run flex attention, using the compiled singleton unless a compilation is already in progress.

    Args:
        query, key, value (`torch.Tensor`):
            Forwarded positionally to the flex attention callable.
        training (`bool`, defaults to `False`):
            Forwarded to `WrappedFlexAttention` to select (or refresh) the compiled callable.
        **kwargs:
            Forwarded unchanged to the flex attention callable.

    Returns:
        Whatever the flex attention callable returns for the given arguments.
    """
    if is_torchdynamo_compiling():
        # Do not use the compiled version if already compiling the forward (it raises issues)
        attention_fn = flex_attention
    else:
        # Building the singleton (re)compiles if needed; calling it returns the compiled callable
        attention_fn = WrappedFlexAttention(training)()

    return attention_fn(query, key, value, **kwargs)
# Type alias for an index offset: either a tensor or a plain Python int
Offset = torch.Tensor | int
- # TODO: deprecate / rename to make_flex_block_mask for clarity as it's not only causal anymore
def make_flex_block_causal_mask(
    attention_mask_2d: torch.Tensor,
    attention_chunk_size: int | None = None,
    query_length=None,
    key_length=None,
    offsets: tuple[Offset, Offset] | None = None,
    is_causal: bool | None = True,
) -> "BlockMask":
    """
    IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`,
    and will be removed in a future version without warnings. New code should not use it. It is only kept here
    for BC for now, while models using it are being patched accordingly.

    Create a block (causal) document mask for a batch of sequences, both packed and unpacked.
    Create Block (causal) logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
    The resultant BlockMask is a compressed representation of the full (causal) block
    mask. BlockMask is essential for performant computation of flex attention.
    See: https://pytorch.org/blog/flexattention/

    Args:
        attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
            of shape (batch_size, total_seq_len). e.g.

            For unpacked sequence:
            [[1, 1, 1, 1, 0, 0, 0],
             [1, 1, 1, 1, 1, 0, 0]]

            For packed sequence:
            [[1, 1, 1, 2, 2, 2, 0],
             [1, 1, 2, 2, 2, 3, 3]]
        attention_chunk_size (int, *optional*): When given (and `is_causal` is truthy), a query may only
            attend to keys whose position falls in the same chunk of this size.
        query_length (int, *optional*): Query length of the block mask. Falls back to `total_seq_len`
            when not set.
        key_length (int, *optional*): Key/value length of the block mask. Falls back to `total_seq_len`
            when not set.
        offsets (tuple, *optional*): `(q_offset, kv_offset)` added to the query and key/value indices
            before the mask is evaluated. Each entry may be a tensor or a plain int.
        is_causal (bool, *optional*, defaults to `True`): When falsy, only the padding and document
            constraints are applied (no causal constraint, no chunking).

    Returns:
        BlockMask
    """
    batch_size, total_seq_len = attention_mask_2d.shape
    if not key_length:
        key_length = total_seq_len
    if not query_length:
        query_length = total_seq_len
    # older torch (2.5.x) cannot handle sequences not in multiples of 128 (default block size)
    # NOTE: this rounds up to the next multiple strictly greater than `key_length`
    pad_len = ((key_length // flex_default_block_size) + 1) * flex_default_block_size
    attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, pad_len - key_length))
    device = attention_mask_2d.device
    document_ids = attention_mask_2d.clone()

    if attention_chunk_size is not None:
        # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
        chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (attention_chunk_size)

    # Instead of passing a tensor mask, flex attention requires a mask_mod function
    # that determines which elements of QK^T should be included in the attention
    # computation prior to the softmax. For sample packing, we need both the
    # logic for both causal mask and document mask. See PyTorch's official
    # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
    def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        """
        Defines the logic of a block causal mask by combining both a standard causal mask
        and a block diagonal document mask.
        See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
        for an illustration.
        """
        causal_mask = q_idx >= kv_idx  # not valid when decoding
        document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
        padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
        final_mask = causal_mask & padding_mask & document_mask
        return final_mask

    def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        """
        Combines the chunk mask with the causal mask for chunked attention.
        """
        chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
        causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
        return chunk_mask & causal_doc_mask

    def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        """
        Utilizes default attention mask to enable encoder and encoder-decoder
        attention masks.
        """
        document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
        # kv indexing is crucial in order to work correctly
        padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0
        final_mask = padding_mask & document_mask
        return final_mask

    if not is_causal:
        mask_mod_maybe_combined = default_mask_mod
    else:
        mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod

    if offsets is not None:
        # `Offset` may be a plain int, which has no `.to`: only tensors are moved to the mask device
        q_offset = offsets[0].to(device) if isinstance(offsets[0], torch.Tensor) else offsets[0]
        kv_offset = offsets[1].to(device) if isinstance(offsets[1], torch.Tensor) else offsets[1]

        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            # Shift both indices before delegating to the selected mask logic
            offset_q = q_idx + q_offset
            offset_kv = kv_idx + kv_offset
            return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)

    else:
        mask_mod = mask_mod_maybe_combined

    return create_block_mask(
        mask_mod=mask_mod,
        B=batch_size,
        H=None,  # attention head
        Q_LEN=query_length,
        KV_LEN=key_length,
        device=device,
        # compiling the mask is not BC with older torch
        _compile=not is_torch_less_or_equal("2.5.1"),
    )
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis
    repeated = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return repeated.reshape(bsz, kv_heads * n_rep, seq_len, dim)
def flex_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Union[torch.Tensor, "BlockMask"],
    scaling: float | None = None,
    softcap: float | None = None,
    s_aux: torch.Tensor | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Attention forward pass backed by flex attention.

    Args:
        module (`torch.nn.Module`):
            Calling module; only its `training` flag is read.
        query, key, value (`torch.Tensor`):
            Attention inputs. Dimension 1 is treated as the head dimension.
        attention_mask (`torch.Tensor` or `BlockMask`):
            A `BlockMask` is passed to flex attention as `block_mask`. Anything else (if not `None`) is
            cropped to the key length on its last dimension and added to the scores, read at head index 0.
        scaling (`float`, *optional*):
            Forwarded to flex attention as `scale`.
        softcap (`float`, *optional*):
            When set, scores become `softcap * tanh(score / softcap)` before the mask is added.
        s_aux (`torch.Tensor`, *optional*):
            Attention sinks; applied after flex attention by renormalizing the output with the LSE.
        **kwargs:
            Only `dropout` (must not be > 0) and `kernel_options` are read.

    Returns:
        `tuple`: the attention output with dimensions 1 and 2 swapped (made contiguous), and the LSE cast
        to `value.dtype` (`None` on CPU).

    Raises:
        ValueError: If a positive `dropout` is requested, or if `s_aux` is given while running on CPU.
    """
    if kwargs.get("dropout", 0.0) > 0:
        raise ValueError(
            "`flex_attention` does not support `dropout`. Please use it with inference"
            " only (`model.eval()`) or turn off the attention dropout in the respective config."
        )

    # A `BlockMask` is handed to flex attention as-is; any other mask is applied inside `score_mod`
    if isinstance(attention_mask, BlockMask):
        block_mask, score_mask = attention_mask, None
    else:
        block_mask, score_mask = None, attention_mask
        if score_mask is not None:
            score_mask = score_mask[:, :, :, : key.shape[-2]]

    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        capped = score if softcap is None else softcap * torch.tanh(score / softcap)
        # Note: attention sinks cannot be correctly implemented in score_mod
        # because it requires operating on the full attention matrix before softmax.
        # ==> this is done after flex attention
        if score_mask is None:
            return capped
        return capped + score_mask[batch_idx][0][q_idx][kv_idx]

    # When running TP this helps: with a head count that is not a power of two, the key/value heads
    # are repeated explicitly instead of relying on flex attention's GQA support
    num_local_query_heads = query.shape[1]
    enable_gqa = (num_local_query_heads & (num_local_query_heads - 1)) == 0
    if not enable_gqa:
        key = repeat_kv(key, num_local_query_heads // key.shape[1])
        value = repeat_kv(value, num_local_query_heads // value.shape[1])

    kernel_options = kwargs.get("kernel_options")

    # On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
    return_lse = query.device.type != "cpu"
    if s_aux is not None and not return_lse:
        raise ValueError(
            "Attention sinks cannot be run on CPU with flex attention. Please switch to a different device, e.g. CUDA"
        )

    # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
    # For simplification, we thus always return it as no additional computations are introduced.
    lse_kwargs = get_flex_attention_lse_kwargs(return_lse)
    flex_attention_output = compile_friendly_flex_attention(
        query,
        key,
        value,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=enable_gqa,
        scale=scaling,
        kernel_options=kernel_options,
        training=module.training,
        **lse_kwargs,
    )

    if not return_lse:
        attention_output, lse = flex_attention_output, None  # type: ignore[assignment]
    else:
        # before torch 2.9, the second tuple element is the LSE itself;
        # in torch 2.9 and later it is an aux output object from which the LSE must be extracted
        attention_output, lse_or_aux = flex_attention_output  # type: ignore[misc]
        lse = lse_or_aux.lse if _TORCH_FLEX_USE_AUX else lse_or_aux
        # lse is returned in float32
        lse = lse.to(value.dtype)

        if s_aux is not None:
            # Apply attention sinks by renormalizing using LSE
            batch_size, num_heads, seq_len_q, _ = attention_output.shape  # batch, num_heads, seq_len, head_dim
            sink_logits = s_aux.view(1, -1, 1, 1).expand(batch_size, num_heads, seq_len_q, 1)

            # The normalizer including the sinks is
            # log(sum(exp(scores)) + exp(sink)) = log(exp(lse) + exp(sink))
            lse_col = lse.unsqueeze(-1)  # [batch, num_heads, seq_len, 1]
            lse_with_sinks = torch.logsumexp(torch.cat([lse_col, sink_logits], dim=-1), dim=-1, keepdim=True)

            # Scale by old_norm / new_norm = exp(lse - lse_with_sinks)
            attention_output = attention_output * torch.exp(lse_col - lse_with_sinks)

    return attention_output.transpose(1, 2).contiguous(), lse
|