| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692 |
- """
- PROTOTYPE!
- Flash Attention 3 implementation.
- For fp8: only supports forward pass right now.
- For fp16/bf16: supports forward and backward pass.
- """
- # mypy: allow-untyped-defs
- from __future__ import annotations
- import importlib
- import warnings
- from typing import TYPE_CHECKING
- if TYPE_CHECKING:
- from collections.abc import Callable
- from dataclasses import dataclass
- from functools import cache
- from typing_extensions import TypeVarTuple, Unpack
- import torch
- from torch.library import Library
- from . import _registry
__all__ = ["register_flash_attention_fa3"]

# Handles to the FA3 custom ops; filled in by _fa3_import_module() once both
# ops have been found, and None until then.
_FA3_CUDA_FWD: Callable | None = None  # torch.ops.flash_attn_3.fwd
_FA3_CUDA_BWD: Callable | None = None  # torch.ops.flash_attn_3.bwd
- @dataclass
- class _FA3Handle:
- library: Library | None
- def remove(self) -> None:
- self.library = None
- # Clear the C++ flag
- torch._C._set_sdp_use_fa3(False)
- @cache
- def _get_device_major(device: torch.device) -> int:
- major, _ = torch.cuda.get_device_capability(device)
- return major
def register_flash_attention_fa3(
    module_path: str = "flash_attn_interface",
) -> _FA3Handle:
    """
    Register FA3 flash attention kernels with the PyTorch dispatcher.

    Args:
        module_path: Python module path to the FA3 implementation. Importing
            it must make ``torch.ops.flash_attn_3.fwd`` and ``.bwd`` available.

    Returns:
        An ``_FA3Handle`` owning the registration library; its ``remove()``
        drops the library and clears the C++ FA3 flag.

    Raises:
        ModuleNotFoundError: If ``module_path`` cannot be imported.
        RuntimeError: If the imported module does not expose the FA3 ops.
    """
    _fa3_import_module(module_path)
    # Tell the C++ side that FA3 kernels are now available.
    torch._C._set_sdp_use_fa3(True)
    library = _fa3_register_kernels()
    return _FA3Handle(library)
- def _fa3_import_module(module_path: str) -> None:
- importlib.import_module(module_path)
- if not hasattr(torch.ops, "flash_attn_3"):
- raise RuntimeError(f"Module '{module_path}' does not expose FA3 kernels")
- if not hasattr(torch.ops.flash_attn_3, "fwd"):
- raise RuntimeError(
- f"Module '{module_path}' does not expose FA3 forward kernels"
- )
- if not hasattr(torch.ops.flash_attn_3, "bwd"):
- raise RuntimeError(
- f"Module '{module_path}' does not expose FA3 backward kernels"
- )
- global _FA3_CUDA_FWD, _FA3_CUDA_BWD
- _FA3_CUDA_FWD = torch.ops.flash_attn_3.fwd
- _FA3_CUDA_BWD = torch.ops.flash_attn_3.bwd
def _fa3_register_kernels() -> Library:
    """Install the FA3 wrappers as CUDA impls of the aten flash-attention ops.

    Returns:
        The ``Library`` that holds the registrations. Callers keep it in an
        ``_FA3Handle`` (presumably it must stay alive for the overrides to
        persist - confirm against torch.library semantics).
    """
    # (aten op name, FA3 wrapper), registered in this order.
    overrides = (
        ("_flash_attention_forward.quantized", _fa3_flash_attention_forward_impl),
        (
            "_scaled_dot_product_flash_attention.quantized",
            _fa3_scaled_dot_product_flash_attention_forward_impl,
        ),
        ("_flash_attention_forward", _fa3_flash_attention_forward_impl_default),
        (
            "_scaled_dot_product_flash_attention",
            _fa3_scaled_dot_product_flash_attention_forward_impl_default,
        ),
        ("_flash_attention_backward", _fa3_flash_attention_backward_impl),
        (
            "_scaled_dot_product_flash_attention_backward",
            _fa3_scaled_dot_product_flash_attention_backward_impl,
        ),
    )
    library = Library("aten", "IMPL", "CUDA")  # noqa: TOR901
    for op_name, impl_fn in overrides:
        library.impl(op_name, impl_fn, "CUDA")
    return library
- def _fa3_common_support_error(
- query: torch.Tensor,
- tensors: tuple[torch.Tensor, ...],
- dropout_p: float,
- cum_seq_q: torch.Tensor | None,
- q_descale: torch.Tensor | None,
- k_descale: torch.Tensor | None,
- v_descale: torch.Tensor | None,
- ) -> str | None:
- if dropout_p != 0.0:
- return "dropout_p must be 0"
- if not all(t.is_cuda for t in tensors):
- return "inputs must be CUDA tensors"
- if len({t.device for t in tensors}) != 1:
- return "inputs must share device"
- if query.dtype == torch.float8_e4m3fn and (
- q_descale is None or k_descale is None or v_descale is None
- ):
- warnings.warn(
- "When using SDPA with fp8, descale tensor should always be used"
- " for accurate dequantization. Please use "
- "_scaled_dot_product_attention_quantized and "
- "provide the descale tensors.",
- UserWarning,
- )
- if cum_seq_q is None and query.dim() != 4:
- return "dense query must be 4D"
- if cum_seq_q is not None and query.dim() != 3:
- return "ragged query must be 3D"
- if not torch.cuda.is_available():
- return "CUDA not available"
- if _get_device_major(query.device) != 9:
- return "FA3 requires compute capability 9.0"
- return None
- def _fa3_forward_support_error(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- dropout_p: float,
- return_debug_mask: bool,
- alibi_slopes: torch.Tensor | None,
- seqused_k: torch.Tensor | None,
- cum_seq_q: torch.Tensor | None,
- q_descale: torch.Tensor | None,
- k_descale: torch.Tensor | None,
- v_descale: torch.Tensor | None,
- ) -> str | None:
- if return_debug_mask:
- return "return_debug_mask must be False"
- if alibi_slopes is not None:
- return "alibi_slopes not supported"
- if seqused_k is not None:
- if seqused_k.dtype != torch.int32:
- return "seqused_k must be int32"
- if not seqused_k.is_cuda:
- return "seqused_k must be CUDA"
- supported_dtypes = (torch.float8_e4m3fn, torch.float16, torch.bfloat16)
- if not all(t.dtype in supported_dtypes for t in {query, key, value}):
- return f"inputs must be one of {supported_dtypes}"
- if len({t.dtype for t in {query, key, value}}) != 1:
- return "all inputs must have the same dtype"
- error = _fa3_common_support_error(
- query,
- (query, key, value),
- dropout_p,
- cum_seq_q,
- q_descale,
- k_descale,
- v_descale,
- )
- if error is not None:
- if error == "inputs must share device":
- return "query, key, value must be on same device"
- return error
- return None
- def _fa3_backward_support_error(
- grad_out: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- out: torch.Tensor,
- logsumexp: torch.Tensor,
- dropout_p: float,
- cum_seq_q: torch.Tensor | None,
- window_size_left: int | None,
- window_size_right: int | None,
- ) -> str | None:
- # FA3 backward ONLY supports fp16/bf16, NOT fp8
- if query.dtype == torch.float8_e4m3fn:
- return (
- "FA3 backward does not support fp8 - use inference only (torch.no_grad())"
- )
- if logsumexp.dtype != torch.float32:
- return "logsumexp dtype must be float32"
- supported_dtypes = (torch.float16, torch.bfloat16)
- if not all(t.dtype in supported_dtypes for t in {grad_out, query, key, value, out}):
- return f"inputs must be one of {supported_dtypes}"
- if len({t.dtype for t in {grad_out, query, key, value, out}}) != 1:
- return "all inputs must have the same dtype"
- error = _fa3_common_support_error(
- query,
- (grad_out, query, key, value, out, logsumexp),
- dropout_p,
- cum_seq_q,
- None,
- None,
- None,
- )
- if error is not None:
- return error
- return None
- Ts = TypeVarTuple("Ts")
- def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
- return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
- def _maybe_contiguous(x: torch.Tensor | None) -> torch.Tensor | None:
- """Ensure tensor is contiguous in the last dimension."""
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def _fa3_run_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor | None,
    cu_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    scale: float | None,
    is_causal: bool,
    window_size_left: int | None,
    window_size_right: int | None,
    seqused_k: torch.Tensor | None,
    out: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Run the FA3 forward pass by calling the C++ kernel directly.

    Args:
        query, key, value: Attention inputs. Dense calls from the SDPA
            wrapper pass BSHD views; ragged calls pass 3D tensors.
        cu_seq_q, cu_seq_k: Cumulative sequence lengths for ragged inputs,
            or ``None`` for dense inputs.
        max_q, max_k: Maximum query / key sequence lengths.
        scale: Softmax scale, forwarded to the kernel as-is (may be None).
        is_causal: Whether to apply causal masking.
        window_size_left, window_size_right: Window sizes; ``None`` is sent
            to the kernel as -1.
        seqused_k: Optional tensor forwarded as the kernel's ``seqused_k``.
        out: Optional pre-allocated output tensor for the kernel to fill.
        q_descale, k_descale, v_descale: Optional fp8 descale tensors.

    Returns:
        ``(out, softmax_lse)``, with ``softmax_lse`` made contiguous.

    Raises:
        RuntimeError: If FA3 has not been registered.
    """
    if _FA3_CUDA_FWD is None:
        raise RuntimeError("FA3 not registered")

    # Ensure contiguous in the last dimension
    q = _maybe_contiguous(query)
    k = _maybe_contiguous(key)
    if value.dtype == torch.float8_e4m3fn:
        # fp8 V may also arrive with a unit-stride sequence dim
        # (stride(-3) == 1). Keep that layout and copy only when neither the
        # last dim nor the sequence dim is unit-stride. Routing this case
        # through _maybe_contiguous would always copy and defeat the check.
        v = (
            value.contiguous()
            if value.stride(-1) != 1 and value.stride(-3) != 1
            else value
        )
    else:
        v = _maybe_contiguous(value)
    cu_seqlens_q = _maybe_contiguous(cu_seq_q)
    cu_seqlens_k = _maybe_contiguous(cu_seq_k)
    seqused_k = _maybe_contiguous(seqused_k)

    # The split-KV accumulators (3rd/4th outputs) are not needed here.
    out, softmax_lse, _out_accum, _softmax_lse_accum = _FA3_CUDA_FWD(
        q,
        k,
        v,
        None,  # k_new
        None,  # v_new
        None,  # qv
        out,  # out_ (pre-allocated output)
        cu_seqlens_q,  # cu_seqlens_q
        cu_seqlens_k,  # cu_seqlens_k
        None,  # cu_seqlens_k_new
        None,  # seqused_q
        seqused_k,  # seqused_k
        max_q,  # max_seqlen_q
        max_k,  # max_seqlen_k
        None,  # page_table,
        None,  # kv_batch_idx,
        None,  # leftpad_k,
        None,  # rotary_cos,
        None,  # rotary_sin,
        None,  # seqlens_rotary,
        q_descale,  # q_descale,
        k_descale,  # k_descale,
        v_descale,  # v_descale,
        scale,  # softmax_scale,
        is_causal,  # causal,
        window_size_left if window_size_left is not None else -1,  # window_size_left
        window_size_right if window_size_right is not None else -1,  # window_size_right
        0,  # attention_chunk,
        0.0,  # softcap,
        True,  # rotary_interleaved,
        None,  # scheduler_metadata,
        # A single split when deterministic algorithms are requested.
        1 if torch.are_deterministic_algorithms_enabled() else 0,  # num_splits,
        None,  # pack_gqa,
        torch._C._get_sm_carveout_experimental() or 0,  # sm_margin,
    )
    return out, softmax_lse.contiguous()
def _fa3_run_backward(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cu_seq_q: torch.Tensor | None,
    cu_seq_k: torch.Tensor | None,
    max_seqlen_q: int | None,
    max_seqlen_k: int | None,
    scale: float | None,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    deterministic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the FA3 backward kernel and return ``(dq, dk, dv)``.

    Inputs whose last dim is not unit-stride are copied. The gradient buffers
    come from ``torch.empty_like`` on q/k/v, so they follow those tensors'
    layouts, and the kernel fills them.

    Raises:
        RuntimeError: If FA3 has not been registered.
    """
    if _FA3_CUDA_BWD is None:
        raise RuntimeError("FA3 not registered")

    dout, q, k, v, o, lse = (
        _maybe_contiguous(t)
        for t in (grad_out, query, key, value, out, logsumexp)
    )
    # Output buffers handed to the kernel and returned afterwards.
    dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
    sm_margin = torch._C._get_sm_carveout_experimental() or 0

    _FA3_CUDA_BWD(
        dout,
        q,
        k,
        v,
        o,
        lse,
        dq,
        dk,
        dv,
        cu_seq_q,
        cu_seq_k,
        None,  # presumably seqused_q per FA3's bwd signature - confirm
        None,  # presumably seqused_k per FA3's bwd signature - confirm
        max_seqlen_q,
        max_seqlen_k,
        scale,
        is_causal,
        window_size_left,
        window_size_right,
        0.0,  # presumably softcap - confirm
        deterministic,
        sm_margin,
    )
    return dq, dk, dv
def _fa3_flash_attention_forward_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    return_debug_mask: bool,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    *,
    scale: float | None = None,
    window_size_left: int = -1,
    window_size_right: int = -1,
    seqused_k: torch.Tensor | None = None,
    alibi_slopes: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
):
    """FA3 implementation of ``_flash_attention_forward`` (incl. ``.quantized``).

    Validates the inputs, runs the FA3 forward kernel, and returns
    ``(out, logsumexp, rng_state, philox_offset, debug_mask)``. Dropout must
    be 0, so the RNG outputs are zero placeholders and the debug mask is
    empty.

    Raises:
        RuntimeError: If FA3 does not support the given inputs.
    """
    reason = _fa3_forward_support_error(
        query,
        key,
        value,
        dropout_p,
        return_debug_mask,
        alibi_slopes,
        seqused_k,
        cum_seq_q,
        q_descale,
        k_descale,
        v_descale,
    )
    if reason is not None:
        raise RuntimeError(f"FA3 flash_attention forward unsupported: {reason}")

    result, lse = _fa3_run_forward(
        query,
        key,
        value,
        cum_seq_q,
        cum_seq_k,
        max_q,
        max_k,
        scale,
        is_causal,
        window_size_left,
        window_size_right,
        seqused_k,
        out=out,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )

    device = query.device
    rng_state = torch.zeros((2,), dtype=torch.uint64, device=device)
    philox_offset = torch.zeros((), dtype=torch.uint64, device=device)
    debug_mask = torch.empty(0, dtype=query.dtype, device=device)
    return result, lse, rng_state, philox_offset, debug_mask
def _fa3_flash_attention_forward_impl_default(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    return_debug_mask: bool,
    *,
    scale: float | None = None,
    window_size_left: int = -1,
    window_size_right: int = -1,
    seqused_k: torch.Tensor | None = None,
    alibi_slopes: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
):
    """FA3 implementation of the non-quantized ``_flash_attention_forward``.

    Delegates to the quantized implementation with every descale tensor set
    to ``None``.
    """
    return _fa3_flash_attention_forward_impl(
        query,
        key,
        value,
        cum_seq_q,
        cum_seq_k,
        max_q,
        max_k,
        dropout_p,
        is_causal,
        return_debug_mask,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        scale=scale,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        seqused_k=seqused_k,
        alibi_slopes=alibi_slopes,
        out=out,
    )
def _fa3_flash_attention_backward_impl(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    rng_state: torch.Tensor,
    unused: torch.Tensor,
    *,
    scale: float | None = None,
    window_size_left: int | None = None,
    window_size_right: int | None = None,
):
    """FA3 implementation of _flash_attention_backward.

    ``rng_state`` and ``unused`` belong to the aten signature and are not
    read. ``None`` window sizes are passed to the kernel as -1. The kernel
    runs in deterministic mode when
    ``torch.are_deterministic_algorithms_enabled()`` is True.

    Returns:
        ``(dq, dk, dv)``.

    Raises:
        RuntimeError: If FA3 does not support the given inputs.
    """
    reason = _fa3_backward_support_error(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        dropout_p,
        cum_seq_q,
        window_size_left,
        window_size_right,
    )
    if reason is not None:
        raise RuntimeError(f"FA3 flash_attention backward unsupported: {reason}")

    left = -1 if window_size_left is None else window_size_left
    right = -1 if window_size_right is None else window_size_right
    return _fa3_run_backward(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        cum_seq_q,
        cum_seq_k,
        max_q,
        max_k,
        scale,
        is_causal,
        left,
        right,
        deterministic=torch.are_deterministic_algorithms_enabled(),
    )
def _fa3_scaled_dot_product_flash_attention_forward_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: float | None = None,
):
    """FA3 implementation of ``_scaled_dot_product_flash_attention.quantized``.

    Takes BHSD inputs and runs FA3 on BSHD views. Returns the 9-tuple
    ``(out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, rng_state,
    philox_offset, debug_mask)`` with both ``cum_seq_*`` set to ``None``.
    fp8 inputs produce a bfloat16 output.

    Raises:
        RuntimeError: If FA3 does not support the given inputs.
    """
    reason = _fa3_forward_support_error(
        query,
        key,
        value,
        dropout_p,
        return_debug_mask,
        None,
        None,
        None,
        q_descale,
        k_descale,
        v_descale,
    )
    if reason is not None:
        raise RuntimeError(f"FA3 SDPA forward unsupported: {reason}")

    q_bshd, k_bshd, v_bshd = _transpose_dense(query, key, value)

    # Allocate the output with query's (BHSD) strides and give the kernel a
    # BSHD view of it. The returned tensor then keeps the input's layout.
    if query.dtype == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = query.dtype
    out_bhsd = torch.empty_like(query, dtype=out_dtype)

    _, lse, rng_state, philox_offset, debug_mask = _fa3_flash_attention_forward_impl(
        q_bshd,
        k_bshd,
        v_bshd,
        None,
        None,
        q_bshd.size(1),
        k_bshd.size(1),
        dropout_p,
        is_causal,
        return_debug_mask,
        scale=scale,
        out=out_bhsd.transpose(1, 2),
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    return (
        out_bhsd,
        lse,
        None,  # cum_seq_q
        None,  # cum_seq_k
        query.size(2),  # max_q
        key.size(2),  # max_k
        rng_state,
        philox_offset,
        debug_mask,
    )
def _fa3_scaled_dot_product_flash_attention_forward_impl_default(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: float | None = None,
):
    """FA3 implementation of the non-quantized ``_scaled_dot_product_flash_attention``.

    Delegates to the quantized implementation without descale tensors.
    """
    return _fa3_scaled_dot_product_flash_attention_forward_impl(
        query,
        key,
        value,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        dropout_p=dropout_p,
        is_causal=is_causal,
        return_debug_mask=return_debug_mask,
        scale=scale,
    )
def _fa3_scaled_dot_product_flash_attention_backward_impl(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    *,
    scale: float | None = None,
):
    """FA3 implementation of _scaled_dot_product_flash_attention_backward.

    SDPA tensors are BHSD and FA3 expects BSHD. This transposes grad_out, q,
    k, v and out, calls the FA3 ``_flash_attention_backward`` impl as a dense
    (non-ragged) call, and returns ``(dq, dk, dv)`` transposed back to BHSD.
    ``logsumexp`` is passed through untransposed. The incoming ``cum_seq_q``
    and ``cum_seq_k`` are not used.

    Raises:
        RuntimeError: If FA3 does not support the given inputs.
    """
    reason = _fa3_backward_support_error(
        grad_out, query, key, value, out, logsumexp, dropout_p, None, None, None
    )
    if reason is not None:
        raise RuntimeError(f"FA3 SDPA backward unsupported: {reason}")

    bshd_inputs = _transpose_dense(grad_out, query, key, value, out)
    bshd_grads = _fa3_flash_attention_backward_impl(
        *bshd_inputs,
        logsumexp,
        None,  # cum_seq_q: dense attention
        None,  # cum_seq_k
        max_q,
        max_k,
        dropout_p,
        is_causal,
        philox_seed,
        philox_offset,
        scale=scale,
    )
    return _transpose_dense(*bshd_grads)
# Make FA3 selectable through the flash-attention implementation registry.
_registry.register_flash_attention_impl(
    "FA3",
    register_fn=register_flash_attention_fa3,
)
|