| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407 |
- """
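Variable-length attention implementation using Flash Attention.

This module provides a high-level Python interface (``varlen_attn``) for
attention over packed, variable-length sequences described by cumulative
sequence-position tensors. It registers the ``torch_attn::_varlen_attn`` and
``torch_attn::_varlen_attn_backward`` custom ops, together with fake (meta)
implementations for tracing and an autograd formula, and calls into the
optimized Flash Attention kernels (a cuDNN code path is taken only when
``_should_use_cudnn`` allows it).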
- """
- import logging
- from functools import lru_cache
- from typing import Any, NamedTuple
- import torch
# Names exported by ``from <module> import *``.
__all__ = ["varlen_attn", "AuxRequest"]

# Module logger; the backend-selection messages in this file go through it.
log = logging.getLogger(name=__name__)
- def _normalize_window_size(window_size: list[int] | None) -> list[int]:
- if window_size is None:
- window_size = [-1, -1]
- if len(window_size) != 2:
- raise ValueError(f"window_size must have length 2, got {len(window_size)}")
- return window_size
- @lru_cache(maxsize=8)
- def _should_use_cudnn(device_index: int) -> bool:
- """Cache device capability check to avoid repeated CUDA calls."""
- return False
class AuxRequest(NamedTuple):
    """
    Selects which auxiliary outputs ``varlen_attn`` should return.

    Every field is a flag; set it to ``True`` to have the corresponding
    auxiliary tensor returned alongside the attention output.
    """

    # When True, varlen_attn returns ``(output, lse)`` instead of ``output``.
    lse: bool = False
@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
def _varlen_attn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool = False,
    scale: float | None = None,
    window_size: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Private custom op for variable-length attention.

    This is the internal implementation. Users should use the public
    varlen_attn function instead.

    Returns:
        ``(output, softmax_lse, rng_state)``. ``rng_state`` is always a zero
        ``uint64`` tensor of shape ``(2,)``: dropout is fixed at 0.0, so no
        RNG state from the kernel is needed by the backward pass.

    Raises:
        ValueError: if ``window_size`` does not have length 2.
        RuntimeError: on the cuDNN path, if a window other than full attention
            ``(-1, -1)`` or the causal window ``(-1, 0)`` with
            ``is_causal=True`` is requested.
    """
    window_size = _normalize_window_size(window_size)
    left, right = window_size
    use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
    if use_cudnn:
        log.info("Using cuDNN backend for varlen_attn")
        # cuDNN has no sliding-window support, but the causal window (-1, 0)
        # is fully expressed by is_causal, which it does support. varlen_attn
        # passes exactly that combination for causal attention, so it must not
        # be rejected here.
        full_window = left == -1 and right == -1
        causal_window = is_causal and left == -1 and right == 0
        if not (full_window or causal_window):
            raise RuntimeError(
                "cuDNN backend does not support window attention. Please use Flash Attention backend."
            )
        result = torch.ops.aten._cudnn_attention_forward(
            query,
            key,
            value,
            None,  # attn_bias
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            True,  # compute_log_sumexp
            0.0,  # dropout_p hardcoded to 0.0
            is_causal,
            False,  # return_debug_mask
            scale=scale,
        )
        # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
        output, softmax_lse = result[0], result[1]
    else:
        log.info("Using Flash Attention backend for varlen_attn")
        output, softmax_lse, _, _, _ = torch.ops.aten._flash_attention_forward(
            query,
            key,
            value,
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            0.0,  # dropout_p hardcoded to 0.0
            is_causal,
            return_debug_mask=False,
            scale=scale,
            window_size_left=left,
            window_size_right=right,
        )
    # Dropout is hardcoded to 0, so the backend's philox seed is irrelevant;
    # return a fixed zero state instead (matches the fake kernel's shape/dtype).
    rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
    return output, softmax_lse, rng_state
@_varlen_attn.register_fake
def _varlen_attn_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool = False,
    scale: float | None = None,
    window_size: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Shape-only (fake) kernel used for meta tensors and tracing.

    Mirrors the 3D varlen branch of ``meta__flash_attention_forward``:
      - ``query`` is laid out as (total, num_heads, head_dim) and the output
        matches it exactly;
      - the logsumexp is (num_heads, total_q), or the batched
        (batch_size, num_heads, max_q) layout on ROCm;
      - the RNG state is an uninitialized ``uint64`` tensor of shape (2,).
    """
    # Same length validation as the real op; the value does not affect shapes.
    _normalize_window_size(window_size)

    output = torch.empty_like(query)

    total_q, num_heads = query.size(0), query.size(1)
    if torch.version.hip:
        # ROCm uses batched format: one row per sequence in cu_seq_q.
        lse_shape = (cu_seq_q.size(0) - 1, num_heads, max_q)
    else:
        lse_shape = (num_heads, total_q)
    logsumexp = torch.empty(lse_shape, dtype=torch.float, device=query.device)

    rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)
    return output, logsumexp, rng_state
def varlen_attn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    *,
    return_aux: AuxRequest | None = None,
    scale: float | None = None,
    window_size: tuple[int, int] = (-1, -1),
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Compute variable-length attention using Flash Attention.

    This function is similar to scaled_dot_product_attention but optimized for
    variable-length sequences using cumulative sequence position tensors.

    Args:
        query (Tensor): Query tensor; shape :math:`(T_q, H, D)`
        key (Tensor): Key tensor; shape :math:`(T_k, H, D)`
        value (Tensor): Value tensor; shape :math:`(T_k, H, D)`
        cu_seq_q (Tensor): Cumulative sequence positions for queries; shape :math:`(N+1,)`
        cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape :math:`(N+1,)`
        max_q (int): Maximum query sequence length in the batch.
        max_k (int): Maximum key/value sequence length in the batch.
        return_aux (Optional[AuxRequest]): If not None and ``return_aux.lse`` is True, also returns the logsumexp tensor.
        scale (float, optional): Scaling factor for attention scores
        window_size (tuple[int, int], optional): Window size for sliding window attention as (left, right).
            Use (-1, -1) for full attention (default), (-1, 0) for causal attention,
            or (W, 0) for causal attention with sliding window of size W.
            A two-element list is accepted as well.

    Returns:
        output (Tensor): Output tensor from attention computation; shape :math:`(T_q, H, D)`.

        If ``return_aux`` is not None and ``return_aux.lse`` is True:
            lse (Tensor): Log-sum-exp of attention scores; shape :math:`(H, T_q)`
            (on ROCm builds, the batched layout ``(N, H, max_q)``).
            Gradients are not propagated through ``lse``; the backward pass
            only consumes the gradient of ``output``.

    Raises:
        ValueError: if ``window_size`` does not have exactly two entries.

    Shape legend:
        - :math:`N`: Batch size
        - :math:`T_q`: Total number of query tokens in the batch (sum of all query sequence lengths)
        - :math:`T_k`: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)
        - :math:`H`: Number of attention heads
        - :math:`D`: Head dimension

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
        >>> head_dim = embed_dim // num_heads
        >>> seq_lengths = []
        >>> for _ in range(batch_size):
        ...     length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
        ...     seq_lengths.append(min(length, max_seq_len))
        >>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
        >>> total_tokens = seq_lengths.sum().item()
        >>>
        >>> # Create packed query, key, value tensors
        >>> query = torch.randn(
        ...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
        ... )
        >>> key = torch.randn(
        ...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
        ... )
        >>> value = torch.randn(
        ...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
        ... )
        >>>
        >>> # Build cumulative sequence tensor
        >>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
        >>> cu_seq[1:] = seq_lengths.cumsum(0)
        >>> max_len = seq_lengths.max().item()
        >>>
        >>> # Call varlen_attn
        >>> output = varlen_attn(
        ...     query, key, value, cu_seq, cu_seq, max_len, max_len
        ... )
    """
    # Normalize to a tuple before comparing: a list never equals a tuple in
    # Python, so window_size=[-1, 0] would otherwise not be treated as causal.
    window = tuple(window_size)
    is_causal = window == (-1, 0)
    out, lse, _ = torch.ops.torch_attn._varlen_attn(
        query,
        key,
        value,
        cu_seq_q,
        cu_seq_k,
        max_q,
        max_k,
        is_causal,
        scale,
        list(window),
    )
    if return_aux is not None and return_aux.lse:
        return out, lse
    return out
- def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
- (
- query,
- key,
- value,
- cu_seq_q,
- cu_seq_k,
- max_q,
- max_k,
- is_causal,
- scale,
- window_size,
- ) = inputs
- out, lse, rng_state = output
- ctx.save_for_backward(query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state)
- ctx.max_q = max_q
- ctx.max_k = max_k
- ctx.is_causal = is_causal
- ctx.scale = scale
- ctx.window_size = window_size
@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
def _varlen_attn_backward(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool,
    rng_state: torch.Tensor,
    scale: float | None = None,
    window_size: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Private custom op computing the gradients of ``torch_attn::_varlen_attn``.

    Takes the gradient of the attention output (``grad_out``), the forward
    inputs, and the forward's saved ``out``/``lse``/``rng_state``. Dropout is
    fixed at 0.0.

    Returns:
        ``(dq, dk, dv)``: gradients with respect to query, key and value.

    Raises:
        ValueError: if ``window_size`` does not have length 2.
        RuntimeError: on the cuDNN path, if a window other than full attention
            ``(-1, -1)`` or the causal window ``(-1, 0)`` with
            ``is_causal=True`` is requested.
    """
    window_size = _normalize_window_size(window_size)
    left, right = window_size
    # Empty placeholder passed in the argument slot after rng_state
    # (presumably the philox offset, irrelevant with dropout_p == 0.0).
    unused = torch.empty(0, device=query.device)
    use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
    if use_cudnn:
        log.info("Using cuDNN backend for varlen_attn")
        # The causal window (-1, 0) is expressed through is_causal, which cuDNN
        # supports; only genuine sliding windows must be rejected.
        full_window = left == -1 and right == -1
        causal_window = is_causal and left == -1 and right == 0
        if not (full_window or causal_window):
            raise RuntimeError(
                "cuDNN backend does not support window attention. Please use Flash Attention backend."
            )
        dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
            grad_out,
            query,
            key,
            value,
            out,
            lse,
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            0.0,  # dropout_p hardcoded to 0.0
            is_causal,
            rng_state,
            unused,
            scale=scale,
        )
    else:
        log.info("Using Flash Attention backend for varlen_attn")
        dq, dk, dv = torch.ops.aten._flash_attention_backward(
            grad_out,
            query,
            key,
            value,
            out,
            lse,
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            0.0,  # dropout_p hardcoded to 0.0
            is_causal,
            rng_state,
            unused,
            scale=scale,
            window_size_left=left,
            window_size_right=right,
        )
    return dq, dk, dv
@_varlen_attn_backward.register_fake
def _varlen_attn_backward_fake(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool,
    rng_state: torch.Tensor,
    scale: float | None = None,
    window_size: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Shape-only (fake) backward kernel for meta tensors and tracing.

    Each returned gradient is an uninitialized tensor shaped like the input
    it corresponds to: ``(dq, dk, dv)`` mirror ``(query, key, value)``.
    """
    # Same length validation as the real op; the value does not affect shapes.
    _normalize_window_size(window_size)
    return tuple(torch.empty_like(t) for t in (query, key, value))
def _backward(
    ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
) -> tuple[torch.Tensor | None, ...]:
    """
    Autograd backward formula for ``torch_attn::_varlen_attn``.

    Only ``grad_out`` is used: ``grad_lse`` and ``grad_rng`` are ignored, so
    any gradient flowing into the returned logsumexp is dropped.

    Returns:
        One entry per forward input: ``dq, dk, dv`` followed by ``None`` for
        cu_seq_q, cu_seq_k, max_q, max_k, is_causal, scale and window_size.
    """
    q, k, v, seq_q, seq_k, attn_out, logsumexp, rng = ctx.saved_tensors
    grad_q, grad_k, grad_v = torch.ops.torch_attn._varlen_attn_backward(
        grad_out,
        q,
        k,
        v,
        attn_out,
        logsumexp,
        seq_q,
        seq_k,
        ctx.max_q,
        ctx.max_k,
        ctx.is_causal,
        rng,
        ctx.scale,
        ctx.window_size,
    )
    # The seven trailing forward arguments are not differentiable.
    return (grad_q, grad_k, grad_v) + (None,) * 7
- _varlen_attn.register_autograd(_backward, setup_context=_setup_context)
|