| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460 |
- """UBER PROTOTYPE!!!"""
- # mypy: allow-untyped-defs
- from __future__ import annotations
- import importlib
- from dataclasses import dataclass
- from functools import cache
- from typing import Any, TYPE_CHECKING
- from typing_extensions import TypeVarTuple, Unpack
- from . import _registry
- if TYPE_CHECKING:
- from types import ModuleType
- import torch
- from torch.library import Library
# Names exported by ``from <module> import *``.
__all__ = ["register_flash_attention_fa4"]

# Import path of the FA4 module most recently handed to
# register_flash_attention_fa4(); remains None until that function succeeds.
_FA4_MODULE_PATH: str | None = None
@dataclass
class _FA4Handle:
    """Handle returned by FA4 registration; holds the kernel ``Library``."""

    # Library object carrying the registered overrides, or None once
    # remove() has been called.
    library: Library | None

    def remove(self) -> None:
        """Drop this handle's reference to the library."""
        # NOTE(review): presumably releasing the reference lets torch tear
        # down the overrides -- confirm against torch.library.Library.
        self.library = None
@cache
def _get_device_major(device: torch.device) -> int:
    """Return the major compute-capability number of ``device``.

    The lookup is memoized per ``device`` through ``functools.cache``.
    """
    capability = torch.cuda.get_device_capability(device)
    return capability[0]
def register_flash_attention_fa4(
    module_path: str = "flash_attn.cute.interface",
) -> _FA4Handle:
    """
    Register FA4 flash attention kernels with the PyTorch dispatcher.

    Args:
        module_path: Python module path to the FA4 implementation.

    Returns:
        A handle wrapping the ``Library`` that carries the registered kernels.

    Raises:
        RuntimeError: If the imported module does not expose the FA4 kernels.
    """
    global _FA4_MODULE_PATH
    # Import and validate before recording the path, so a bad path is never
    # stored in the module-level state.
    _fa4_import_module(module_path)
    _FA4_MODULE_PATH = module_path
    library = _fa4_register_kernels()
    return _FA4Handle(library)
@cache
def _fa4_import_module(module_path: str) -> ModuleType:
    """Import the module at ``module_path`` and verify its FA4 entry points.

    Successful imports are memoized per path through ``functools.cache``.

    Raises:
        RuntimeError: If the module lacks ``_flash_attn_fwd`` or
            ``_flash_attn_bwd``.
    """
    fa4 = importlib.import_module(module_path)
    required_entries = ("_flash_attn_fwd", "_flash_attn_bwd")
    if any(not hasattr(fa4, entry) for entry in required_entries):
        raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
    return fa4
def _fa4_register_kernels() -> Library:
    """Install the FA4 implementations as CUDA overrides of four aten ops.

    Returns:
        The ``Library`` object that holds the registrations.
    """
    lib = Library("aten", "IMPL", "CUDA")  # noqa: TOR901
    # (aten op name, Python implementation) pairs, registered in this order.
    overrides = (
        ("_flash_attention_forward", _fa4_flash_attention_forward_impl),
        ("_flash_attention_backward", _fa4_flash_attention_backward_impl),
        (
            "_scaled_dot_product_flash_attention",
            _fa4_scaled_dot_product_flash_attention_forward_impl,
        ),
        (
            "_scaled_dot_product_flash_attention_backward",
            _fa4_scaled_dot_product_flash_attention_backward_impl,
        ),
    )
    for op_name, implementation in overrides:
        lib.impl(op_name, implementation, "CUDA")
    return lib
def _fa4_common_support_error(
    query: torch.Tensor,
    tensors: tuple[torch.Tensor, ...],
    cum_seq_q: torch.Tensor | None,
    require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
) -> str | None:
    """Run the checks shared by the forward and backward FA4 paths.

    Args:
        query: Tensor whose dtype, rank and device are inspected.
        tensors: All tensors that must be on CUDA and share a single device.
        cum_seq_q: When not None, the query is expected to be 3D instead
            of 4D.
        require_fp32: ``(name, tensor)`` pairs whose dtype must be float32.

    Returns:
        A description of the first failed check, or None when every check
        passes.
    """
    if any(not t.is_cuda for t in tensors):
        return "inputs must be CUDA tensors"

    devices = {t.device for t in tensors}
    if len(devices) != 1:
        return "inputs must share device"

    if query.dtype not in (torch.float16, torch.bfloat16):
        return "query dtype must be float16 or bfloat16"

    # Report the first entry (in the given order) that is not float32.
    offending = next(
        (name for name, tensor in require_fp32 if tensor.dtype != torch.float32),
        None,
    )
    if offending is not None:
        return f"{offending} dtype must be float32"

    if cum_seq_q is None:
        if query.dim() != 4:
            return "dense query must be 4D"
    elif query.dim() != 3:
        return "ragged query must be 3D"

    if not torch.cuda.is_available():
        return "CUDA not available"

    if _get_device_major(query.device) not in (9, 10):
        return "FA4 requires compute capability 9.0 or 10.0"

    return None
def _fa4_forward_support_error(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float,
    return_debug_mask: bool,
    alibi_slopes: torch.Tensor | None,
    seqused_k: torch.Tensor | None,
    cum_seq_q: torch.Tensor | None,
) -> str | None:
    """Explain why a forward call cannot be served by FA4, if it cannot.

    Forward-only restrictions (dropout, debug mask, alibi slopes,
    ``seqused_k`` dtype and device) are checked first, then the shared
    checks run on ``query``, ``key`` and ``value``.

    Returns:
        A description of the first failed check, or None when the call is
        supported.
    """
    if dropout_p != 0.0:
        return "dropout_p must be 0"
    if return_debug_mask:
        return "return_debug_mask must be False"
    if alibi_slopes is not None:
        return "alibi_slopes not supported"
    if seqused_k is not None and seqused_k.dtype != torch.int32:
        return "seqused_k must be int32"
    if seqused_k is not None and not seqused_k.is_cuda:
        return "seqused_k must be CUDA"

    reason = _fa4_common_support_error(query, (query, key, value), cum_seq_q)
    # Reword the generic device-mismatch message to name the forward inputs.
    if reason == "inputs must share device":
        return "query, key, value must be on same device"
    return reason
def _fa4_backward_support_error(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    dropout_p: float,
    cum_seq_q: torch.Tensor | None,
    window_size_left: int | None,
    window_size_right: int | None,
) -> str | None:
    """Explain why a backward call cannot be served by FA4, if it cannot.

    Dropout and windowed attention are rejected first; the shared checks
    then run over all six tensors, with ``logsumexp`` required to be
    float32.

    Returns:
        A description of the first failed check, or None when the call is
        supported.
    """
    if dropout_p != 0.0:
        return "dropout_p must be 0"

    windowless = window_size_left is None and window_size_right is None
    if not windowless:
        return "windowed attention not supported"

    participants = (grad_out, query, key, value, out, logsumexp)
    return _fa4_common_support_error(
        query,
        participants,
        cum_seq_q,
        require_fp32=(("logsumexp", logsumexp),),
    )
# Variadic type variable so the helper's return arity mirrors its arguments.
Ts = TypeVarTuple("Ts")


def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
    """Return the inputs, in order, each with dimensions 1 and 2 swapped."""
    swapped = [t.transpose(1, 2) for t in tensors]  # type: ignore[attr-defined]
    return tuple(swapped)  # type: ignore[return-value]
def _fa4_run_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor | None,
    cu_seq_k: torch.Tensor | None,
    scale: float | None,
    is_causal: bool,
    window_size_left: int | None,
    window_size_right: int | None,
    seqused_k: torch.Tensor | None,
    out: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call the registered module's ``_flash_attn_fwd`` kernel.

    Args:
        out: Optional tensor forwarded to the kernel as ``out``; the keyword
            is omitted entirely when this is None.

    Returns:
        ``(out, lse)`` as returned by the kernel, with ``lse`` made
        contiguous.

    Raises:
        RuntimeError: If no FA4 module path has been registered.
    """
    registered_path = _FA4_MODULE_PATH
    if registered_path is None:
        raise RuntimeError("FA4 not registered")
    fa4 = _fa4_import_module(registered_path)

    packed_seqused_k = None if seqused_k is None else seqused_k.contiguous()
    call_kwargs: dict[str, Any] = dict(
        softmax_scale=scale,
        causal=is_causal,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        return_lse=True,
        cu_seqlens_q=cu_seq_q,
        cu_seqlens_k=cu_seq_k,
        seqused_k=packed_seqused_k,
    )
    if out is not None:
        call_kwargs["out"] = out

    attn_out, lse = fa4._flash_attn_fwd(query, key, value, **call_kwargs)
    return attn_out, lse.contiguous()
def _fa4_run_backward(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cu_seq_q: torch.Tensor | None,
    cu_seq_k: torch.Tensor | None,
    scale: float | None,
    is_causal: bool,
    deterministic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call the registered module's ``_flash_attn_bwd`` kernel.

    ``logsumexp`` is made contiguous before being handed to the kernel.

    Returns:
        The ``(dq, dk, dv)`` triple produced by the kernel.

    Raises:
        RuntimeError: If no FA4 module path has been registered.
    """
    registered_path = _FA4_MODULE_PATH
    if registered_path is None:
        raise RuntimeError("FA4 not registered")
    fa4 = _fa4_import_module(registered_path)

    # The kernel takes (q, k, v, out, dout, lse) positionally, so grad_out
    # moves from first argument here to fifth argument there.
    positional = (query, key, value, out, grad_out, logsumexp.contiguous())
    grad_q, grad_k, grad_v = fa4._flash_attn_bwd(
        *positional,
        softmax_scale=scale,
        causal=is_causal,
        cu_seqlens_q=cu_seq_q,
        cu_seqlens_k=cu_seq_k,
        deterministic=deterministic,
    )
    return grad_q, grad_k, grad_v
def _fa4_flash_attention_forward_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    return_debug_mask: bool,
    *,
    scale: float | None = None,
    window_size_left: int | None = None,
    window_size_right: int | None = None,
    seqused_k: torch.Tensor | None = None,
    alibi_slopes: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
):
    """FA4 implementation registered for ``_flash_attention_forward``.

    ``max_q`` and ``max_k`` are accepted but not used by this body.

    Returns:
        ``(out, lse, rng_state, philox_offset, debug_mask)``. The last three
        are freshly allocated placeholders (zeros / empty) on the query's
        device; dropout is rejected by the support check, so no RNG state is
        produced by the kernel call.

    Raises:
        RuntimeError: If the arguments fail the forward support checks.
    """
    reason = _fa4_forward_support_error(
        query,
        key,
        value,
        dropout_p,
        return_debug_mask,
        alibi_slopes,
        seqused_k,
        cum_seq_q,
    )
    if reason is not None:
        raise RuntimeError(f"FA4 flash_attention forward unsupported: {reason}")

    attn_out, lse = _fa4_run_forward(
        query,
        key,
        value,
        cu_seq_q=cum_seq_q,
        cu_seq_k=cum_seq_k,
        scale=scale,
        is_causal=is_causal,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        seqused_k=seqused_k,
        out=out,
    )

    device = query.device
    rng_state = torch.zeros((2,), dtype=torch.uint64, device=device)
    philox_offset = torch.zeros((), dtype=torch.uint64, device=device)
    debug_mask = torch.empty(0, dtype=query.dtype, device=device)
    return attn_out, lse, rng_state, philox_offset, debug_mask
def _fa4_flash_attention_backward_impl(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    rng_state: torch.Tensor,
    unused: torch.Tensor,
    *,
    scale: float | None = None,
    window_size_left: int | None = None,
    window_size_right: int | None = None,
):
    """FA4 implementation registered for ``_flash_attention_backward``.

    ``max_q``, ``max_k``, ``rng_state`` and ``unused`` are accepted but not
    used by this body. The deterministic flag passed to the kernel follows
    ``torch.are_deterministic_algorithms_enabled()``.

    Returns:
        The ``(dq, dk, dv)`` gradients.

    Raises:
        RuntimeError: If the arguments fail the backward support checks.
    """
    reason = _fa4_backward_support_error(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        dropout_p,
        cum_seq_q,
        window_size_left,
        window_size_right,
    )
    if reason is not None:
        raise RuntimeError(f"FA4 flash_attention backward unsupported: {reason}")

    grad_q, grad_k, grad_v = _fa4_run_backward(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        cu_seq_q=cum_seq_q,
        cu_seq_k=cum_seq_k,
        scale=scale,
        is_causal=is_causal,
        deterministic=torch.are_deterministic_algorithms_enabled(),
    )
    return grad_q, grad_k, grad_v
def _fa4_scaled_dot_product_flash_attention_forward_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: float | None = None,
):
    """FA4 implementation registered for ``_scaled_dot_product_flash_attention``.

    Validates the inputs, swaps dimensions 1 and 2 of query/key/value, and
    delegates to the flash-attention forward implementation.

    Returns:
        A 9-tuple ``(out, lse, None, None, max_q, max_k, rng_state,
        philox_offset, debug_mask)`` where ``max_q`` / ``max_k`` are
        ``query.size(2)`` / ``key.size(2)``.

    Raises:
        RuntimeError: If the arguments fail the forward support checks.
    """
    reason = _fa4_forward_support_error(
        query,
        key,
        value,
        dropout_p,
        return_debug_mask,
        None,
        None,
        None,
    )
    if reason is not None:
        raise RuntimeError(f"FA4 SDPA forward unsupported: {reason}")

    q_bshd, k_bshd, v_bshd = _transpose_dense(query, key, value)

    # The result buffer is allocated like the query (BHSD), so the returned
    # output shares the query's memory layout. The kernel writes into it
    # through a BSHD view.
    out_bhsd = torch.empty_like(query)

    _, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
        q_bshd,
        k_bshd,
        v_bshd,
        None,
        None,
        q_bshd.size(1),
        k_bshd.size(1),
        dropout_p,
        is_causal,
        return_debug_mask,
        scale=scale,
        out=out_bhsd.transpose(1, 2),
    )

    return (
        out_bhsd,
        lse,
        None,
        None,
        query.size(2),
        key.size(2),
        rng_state,
        philox_offset,
        debug_mask,
    )
def _fa4_scaled_dot_product_flash_attention_backward_impl(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    *,
    scale: float | None = None,
):
    """FA4 implementation registered for ``_scaled_dot_product_flash_attention_backward``.

    The incoming ``cum_seq_q``, ``cum_seq_k``, ``max_q`` and ``max_k`` are
    not read; None and the sizes of dimension 2 of ``query`` / ``key`` are
    passed on instead.

    Returns:
        ``(dq, dk, dv)`` with dimensions 1 and 2 swapped back after the
        flash-attention backward call.

    Raises:
        RuntimeError: If the arguments fail the backward support checks.
    """
    reason = _fa4_backward_support_error(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        dropout_p,
        None,
        None,
        None,
    )
    if reason is not None:
        raise RuntimeError(f"FA4 SDPA backward unsupported: {reason}")

    # Swap dims 1 and 2 on every dense tensor before delegating.
    q_t, k_t, v_t, out_t, grad_out_t = _transpose_dense(
        query, key, value, out, grad_out
    )

    grad_q, grad_k, grad_v = _fa4_flash_attention_backward_impl(
        grad_out_t,
        q_t,
        k_t,
        v_t,
        out_t,
        logsumexp,
        None,
        None,
        query.size(2),
        key.size(2),
        dropout_p,
        is_causal,
        philox_seed,
        philox_offset,
        scale=scale,
    )

    # Undo the swap so the gradients line up with the caller's inputs.
    grad_q, grad_k, grad_v = _transpose_dense(grad_q, grad_k, grad_v)
    return grad_q, grad_k, grad_v
# Runs at import time: hands this module's registration entry point to the
# sibling registry under the name "FA4".
_registry.register_flash_attention_impl(
    "FA4",
    register_fn=register_flash_attention_fa4,
)
|