| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376 |
- # mypy: allow-untyped-defs
- """Defines bias subclasses that work with scaled_dot_product_attention"""
- from enum import auto, IntEnum
- from warnings import warn
- import torch
- import torch.nn.functional as F
- from torch.backends.cuda import (
- can_use_efficient_attention,
- can_use_flash_attention,
- is_flash_attention_available,
- SDPAParams,
- )
- from torch.nn.attention import _raise_kernel_warnings
- from torch.nn.attention._utils import (
- _calculate_scale,
- _input_requires_grad,
- _postprocess_flash_output,
- _validate_sdpa_input,
- )
__all__ = [
    "causal_upper_left",
    "causal_lower_right",
    "CausalVariant",
    "CausalBias",
]

# Register the SDPA kernel-selection helpers used by CausalBias._dispatch with
# torch._dynamo.allow_in_graph, in the same order as before.
for _sdpa_helper in (
    is_flash_attention_available,
    can_use_flash_attention,
    can_use_efficient_attention,
    SDPAParams,
):
    torch._dynamo.allow_in_graph(_sdpa_helper)
# Keep the module namespace identical to a straight-line sequence of calls.
del _sdpa_helper
class CausalVariant(IntEnum):
    r"""
    Enum for causal variants used in attention mechanisms.

    Defines two types of causal biases:

    ``UPPER_LEFT``: Represents upper-left triangular bias for standard causal attention.
    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with ``shape=(3,4)``, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]


    ``LOWER_RIGHT``: Represents lower-right triangular bias; the included values are aligned to the lower
    right corner of the matrix.

    The equivalent pytorch code for constructing this bias is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with ``shape=(3,4)``, the materialized bias tensor will be:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Note that these variants are equivalent to each other when the sequence lengths of the query and key/value
    tensors are equal since the triangular matrix is square.

    .. warning:: This enum is a prototype and subject to change.
    """

    # The integer values matter: CausalBias passes ``int(variant)`` to the
    # memory-efficient attention op as ``custom_mask_type``. ``auto()`` on an
    # IntEnum numbers members from 1, so UPPER_LEFT == 1 and LOWER_RIGHT == 2;
    # do not reorder these members.
    UPPER_LEFT = auto()
    LOWER_RIGHT = auto()
- class CausalBias(torch.Tensor):
- """
- A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.
- This class is used for defining causal (triangular) attention biases. For construing the bias, there exist
- two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.
- Example:
- .. code-block:: python
- from torch.nn.attention.bias import causal_lower_right
- bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8
- # Create a lower-right causal bias
- attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
- q = torch.randn(
- bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16
- )
- k = torch.randn(
- bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
- )
- v = torch.randn(
- bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
- )
- out = F.scaled_dot_product_attention(q, k, v, attn_bias)
- .. warning:: This class is a prototype and subject to change.
- """
- def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int) -> None:
- """
- Initializes the CausalBias instance with a specified variant and sequence lengths.
- Args:
- variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
- seq_len_q (int): The sequence length of the query tensor.
- seq_len_kv (int): The sequence length of the key/value tensor.
- Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
- """
- if not isinstance(variant, CausalVariant):
- raise AssertionError(
- f"variant must be a CausalVariant, got {type(variant).__name__}"
- )
- super().__init__()
- self.variant = variant
- self.seq_len_q = seq_len_q
- self.seq_len_kv = seq_len_kv
- if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
- warn(
- "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!",
- stacklevel=2,
- )
- def _upper_left(self, device: torch.device) -> torch.Tensor:
- """Upper left causal bias"""
- return torch.tril(
- torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
- )
- def _lower_right(self, device: torch.device) -> torch.Tensor:
- """Lower right causal bias"""
- diagonal_offset = self.seq_len_kv - self.seq_len_q
- return torch.tril(
- torch.ones(
- self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
- ),
- diagonal=diagonal_offset,
- )
- # pyrefly: ignore [bad-return]
- def _materialize(self, device: torch.device | None = None) -> torch.Tensor:
- """
- Materializes the causal bias into a tensor form.
- Depending on the variant, this method generates either an upper-left or lower-right
- triangular matrix to represent the causal bias.
- Args:
- device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.
- Returns:
- torch.Tensor: The materialized bias tensor.
- """
- if device is None:
- device = torch.device("cpu")
- if self.variant == CausalVariant.UPPER_LEFT:
- return self._upper_left(device)
- elif self.variant == CausalVariant.LOWER_RIGHT:
- return self._lower_right(device)
- @staticmethod
- def _dispatch(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: "CausalBias",
- dropout_p: float = 0.0,
- is_causal: bool = False,
- scale: float | None = None,
- enable_gqa: bool = False,
- ) -> torch.Tensor:
- r"""
- Handles the logic for computing attention with the specified causal bias.
- Args:
- query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
- key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
- value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
- attn_mask (CausalBias): The type of causal attention to apply.
- A boolean mask where a value of True indicates that the element *should* take part in attention.
- A float mask of the same type as query, key, value that is added to the attention score.
- dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
- is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal
- are set.
- scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
- to :math:`\frac{1}{\sqrt{E}}`.
- enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
- Returns:
- output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
- Raises:
- ValueError: If the causal bias variant is not a CausalVariant type.
- """
- if is_causal:
- raise ValueError("CausalBias should not be used with causal=True")
- if (
- attn_mask.seq_len_q == attn_mask.seq_len_kv
- or attn_mask.variant == CausalVariant.UPPER_LEFT
- ):
- return F.scaled_dot_product_attention(
- query,
- key,
- value,
- attn_mask=None,
- dropout_p=dropout_p,
- is_causal=True,
- scale=scale,
- enable_gqa=enable_gqa,
- )
- elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
- _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
- sdpa_params = SDPAParams(
- query, key, value, None, dropout_p, is_causal, enable_gqa
- )
- if can_use_flash_attention(sdpa_params):
- alignment = 64 if query.device.type == "xpu" else 8
- og_head_size = query.size(-1)
- og_scale = _calculate_scale(og_head_size, scale)
- needs_padding = og_head_size % alignment != 0
- if needs_padding:
- pad_len = alignment - (og_head_size % alignment)
- query = torch.nn.functional.pad(query, (0, pad_len))
- key = torch.nn.functional.pad(key, (0, pad_len))
- value = torch.nn.functional.pad(value, (0, pad_len))
- out = torch.ops.aten._scaled_dot_product_flash_attention(
- query,
- key,
- value,
- dropout_p,
- is_causal=True, # TODO: Flash accepts causal = True and for this particular op it means lower right
- return_debug_mask=False,
- scale=og_scale,
- )[0]
- return _postprocess_flash_output(out, og_head_size)
- if can_use_efficient_attention(sdpa_params):
- compute_log_sumexp = False
- if _input_requires_grad(query, key, value):
- compute_log_sumexp = True
- return torch.ops.aten._efficient_attention_forward(
- query.transpose(1, 2),
- key.transpose(1, 2),
- value.transpose(1, 2),
- bias=None,
- cu_seqlens_q=None,
- cu_seqlens_k=None,
- max_seqlen_q=None,
- max_seqlen_k=None,
- dropout_p=dropout_p,
- custom_mask_type=int(attn_mask.variant),
- compute_log_sumexp=compute_log_sumexp,
- scale=scale,
- seqlen_k=None,
- )[0].transpose(1, 2)
- else:
- _raise_kernel_warnings(sdpa_params)
- # We can't use efficient attention the only support for lower right is via materialization
- return F.scaled_dot_product_attention(
- query,
- key,
- value,
- attn_mask=attn_mask._materialize(query.device),
- dropout_p=dropout_p,
- is_causal=False,
- scale=scale,
- enable_gqa=enable_gqa,
- )
- else:
- raise ValueError(
- f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
- )
- @classmethod
- def __torch_function__(cls, func, types, args=(), kwargs=None):
- """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias"""
- if kwargs is None:
- kwargs = {}
- if func is torch.nn.functional.scaled_dot_product_attention:
- return cls._dispatch(*args, **kwargs)
- return super().__torch_function__(func, types, args, kwargs)
- def __repr__(self) -> str: # type:ignore[override]
- return self._materialize().__repr__()
- def causal_upper_left(*size) -> CausalBias:
- """
- Creates an upper-left triangular causal bias.
- This function generates a upper-left triangular matrix to represent causal attention bias with a
- diagonal offset set so that the inclusive values are aligned to the upper left corner of the matrix.
- This equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`.
- The equivalent pytorch code for constructing this bias is:
- .. code-block:: python
- torch.tril(torch.ones(size, dtype=torch.bool))
- For instance, with `shape=(3,4)`, the materialized bias tensor will be:
- .. code-block:: text
- [[1, 0, 0, 0],
- [1, 1, 0, 0],
- [1, 1, 1, 0]]
- Args:
- size: The size of the bias matrix.
- Returns:
- CausalBias: The UPPER_LEFT triangular causal bias variant.
- """
- if len(size) != 2:
- raise AssertionError("causal_upper_left only supports 2D tensors")
- seq_len_q, seq_len_kv = size
- return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)
- def causal_lower_right(*size) -> CausalBias:
- """
- Creates a lower-right triangular causal bias.
- This function generates a lower-right triangular matrix to represent causal attention bias with a
- diagonal offset set so that the inclusive values are aligned to the lower right corner of the matrix.
- The equivalent pytorch code for constructing this bias is:
- .. code-block:: python
- diagonal_offset = size[1] - size[0]
- torch.tril(
- torch.ones(size, dtype=torch.bool),
- diagonal=diagonal_offset,
- )
- For instance, with `shape=(3,4)`, the materialized bias tensor will be:
- .. code-block:: text
- [[1, 1, 0, 0],
- [1, 1, 1, 0],
- [1, 1, 1, 1]]
- Args:
- size: The size of the bias matrix.
- Returns:
- CausalBias: The LOWER_RIGHT triangular causal bias variant.
- """
- if len(size) != 2:
- raise AssertionError("causal_lower_right only supports 2D tensors")
- seq_len_q, seq_len_kv = size
- return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)
|