_fa4.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. """UBER PROTOTYPE!!!"""
  2. # mypy: allow-untyped-defs
  3. from __future__ import annotations
  4. import importlib
  5. from dataclasses import dataclass
  6. from functools import cache
  7. from typing import Any, TYPE_CHECKING
  8. from typing_extensions import TypeVarTuple, Unpack
  9. from . import _registry
  10. if TYPE_CHECKING:
  11. from types import ModuleType
  12. import torch
  13. from torch.library import Library
  14. __all__ = [
  15. "register_flash_attention_fa4",
  16. ]
  17. _FA4_MODULE_PATH: str | None = None
  18. @dataclass
  19. class _FA4Handle:
  20. library: Library | None
  21. def remove(self) -> None:
  22. self.library = None
  23. @cache
  24. def _get_device_major(device: torch.device) -> int:
  25. major, _ = torch.cuda.get_device_capability(device)
  26. return major
  27. def register_flash_attention_fa4(
  28. module_path: str = "flash_attn.cute.interface",
  29. ) -> _FA4Handle:
  30. """
  31. Register FA4 flash attention kernels with the PyTorch dispatcher.
  32. Args:
  33. module_path: Python module path to the FA4 implementation.
  34. """
  35. global _FA4_MODULE_PATH
  36. _ = _fa4_import_module(module_path)
  37. _FA4_MODULE_PATH = module_path
  38. return _FA4Handle(_fa4_register_kernels())
  39. @cache
  40. def _fa4_import_module(module_path: str) -> ModuleType:
  41. module = importlib.import_module(module_path)
  42. if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
  43. raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
  44. return module
  45. def _fa4_register_kernels() -> Library:
  46. lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
  47. lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
  48. lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
  49. lib.impl(
  50. "_scaled_dot_product_flash_attention",
  51. _fa4_scaled_dot_product_flash_attention_forward_impl,
  52. "CUDA",
  53. )
  54. lib.impl(
  55. "_scaled_dot_product_flash_attention_backward",
  56. _fa4_scaled_dot_product_flash_attention_backward_impl,
  57. "CUDA",
  58. )
  59. return lib
  60. def _fa4_common_support_error(
  61. query: torch.Tensor,
  62. tensors: tuple[torch.Tensor, ...],
  63. cum_seq_q: torch.Tensor | None,
  64. require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
  65. ) -> str | None:
  66. if not all(t.is_cuda for t in tensors):
  67. return "inputs must be CUDA tensors"
  68. if len({t.device for t in tensors}) != 1:
  69. return "inputs must share device"
  70. if query.dtype not in (torch.float16, torch.bfloat16):
  71. return "query dtype must be float16 or bfloat16"
  72. for name, tensor in require_fp32:
  73. if tensor.dtype != torch.float32:
  74. return f"{name} dtype must be float32"
  75. if cum_seq_q is None and query.dim() != 4:
  76. return "dense query must be 4D"
  77. if cum_seq_q is not None and query.dim() != 3:
  78. return "ragged query must be 3D"
  79. if not torch.cuda.is_available():
  80. return "CUDA not available"
  81. if _get_device_major(query.device) not in (9, 10):
  82. return "FA4 requires compute capability 9.0 or 10.0"
  83. return None
  84. def _fa4_forward_support_error(
  85. query: torch.Tensor,
  86. key: torch.Tensor,
  87. value: torch.Tensor,
  88. dropout_p: float,
  89. return_debug_mask: bool,
  90. alibi_slopes: torch.Tensor | None,
  91. seqused_k: torch.Tensor | None,
  92. cum_seq_q: torch.Tensor | None,
  93. ) -> str | None:
  94. if dropout_p != 0.0:
  95. return "dropout_p must be 0"
  96. if return_debug_mask:
  97. return "return_debug_mask must be False"
  98. if alibi_slopes is not None:
  99. return "alibi_slopes not supported"
  100. if seqused_k is not None:
  101. if seqused_k.dtype != torch.int32:
  102. return "seqused_k must be int32"
  103. if not seqused_k.is_cuda:
  104. return "seqused_k must be CUDA"
  105. error = _fa4_common_support_error(
  106. query,
  107. (query, key, value),
  108. cum_seq_q,
  109. )
  110. if error is not None:
  111. if error == "inputs must share device":
  112. return "query, key, value must be on same device"
  113. return error
  114. return None
  115. def _fa4_backward_support_error(
  116. grad_out: torch.Tensor,
  117. query: torch.Tensor,
  118. key: torch.Tensor,
  119. value: torch.Tensor,
  120. out: torch.Tensor,
  121. logsumexp: torch.Tensor,
  122. dropout_p: float,
  123. cum_seq_q: torch.Tensor | None,
  124. window_size_left: int | None,
  125. window_size_right: int | None,
  126. ) -> str | None:
  127. if dropout_p != 0.0:
  128. return "dropout_p must be 0"
  129. if window_size_left is not None or window_size_right is not None:
  130. return "windowed attention not supported"
  131. error = _fa4_common_support_error(
  132. query,
  133. (grad_out, query, key, value, out, logsumexp),
  134. cum_seq_q,
  135. require_fp32=(("logsumexp", logsumexp),),
  136. )
  137. if error is not None:
  138. return error
  139. return None
  140. Ts = TypeVarTuple("Ts")
  141. def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
  142. return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
  143. def _fa4_run_forward(
  144. query: torch.Tensor,
  145. key: torch.Tensor,
  146. value: torch.Tensor,
  147. cu_seq_q: torch.Tensor | None,
  148. cu_seq_k: torch.Tensor | None,
  149. scale: float | None,
  150. is_causal: bool,
  151. window_size_left: int | None,
  152. window_size_right: int | None,
  153. seqused_k: torch.Tensor | None,
  154. out: torch.Tensor | None = None,
  155. ) -> tuple[torch.Tensor, torch.Tensor]:
  156. if _FA4_MODULE_PATH is None:
  157. raise RuntimeError("FA4 not registered")
  158. module = _fa4_import_module(_FA4_MODULE_PATH)
  159. kwargs: dict[str, Any] = {
  160. "softmax_scale": scale,
  161. "causal": is_causal,
  162. "window_size_left": window_size_left,
  163. "window_size_right": window_size_right,
  164. "return_lse": True,
  165. "cu_seqlens_q": cu_seq_q,
  166. "cu_seqlens_k": cu_seq_k,
  167. "seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
  168. }
  169. if out is not None:
  170. kwargs["out"] = out
  171. out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
  172. return out, lse.contiguous()
  173. def _fa4_run_backward(
  174. grad_out: torch.Tensor,
  175. query: torch.Tensor,
  176. key: torch.Tensor,
  177. value: torch.Tensor,
  178. out: torch.Tensor,
  179. logsumexp: torch.Tensor,
  180. cu_seq_q: torch.Tensor | None,
  181. cu_seq_k: torch.Tensor | None,
  182. scale: float | None,
  183. is_causal: bool,
  184. deterministic: bool = False,
  185. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  186. if _FA4_MODULE_PATH is None:
  187. raise RuntimeError("FA4 not registered")
  188. module = _fa4_import_module(_FA4_MODULE_PATH)
  189. dq, dk, dv = module._flash_attn_bwd(
  190. query,
  191. key,
  192. value,
  193. out,
  194. grad_out,
  195. logsumexp.contiguous(),
  196. softmax_scale=scale,
  197. causal=is_causal,
  198. cu_seqlens_q=cu_seq_q,
  199. cu_seqlens_k=cu_seq_k,
  200. deterministic=deterministic,
  201. )
  202. return dq, dk, dv
  203. def _fa4_flash_attention_forward_impl(
  204. query: torch.Tensor,
  205. key: torch.Tensor,
  206. value: torch.Tensor,
  207. cum_seq_q: torch.Tensor | None,
  208. cum_seq_k: torch.Tensor | None,
  209. max_q: int,
  210. max_k: int,
  211. dropout_p: float,
  212. is_causal: bool,
  213. return_debug_mask: bool,
  214. *,
  215. scale: float | None = None,
  216. window_size_left: int | None = None,
  217. window_size_right: int | None = None,
  218. seqused_k: torch.Tensor | None = None,
  219. alibi_slopes: torch.Tensor | None = None,
  220. out: torch.Tensor | None = None,
  221. ):
  222. error = _fa4_forward_support_error(
  223. query,
  224. key,
  225. value,
  226. dropout_p,
  227. return_debug_mask,
  228. alibi_slopes,
  229. seqused_k,
  230. cum_seq_q,
  231. )
  232. if error is not None:
  233. raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
  234. out, lse = _fa4_run_forward(
  235. query,
  236. key,
  237. value,
  238. cum_seq_q,
  239. cum_seq_k,
  240. scale,
  241. is_causal,
  242. window_size_left,
  243. window_size_right,
  244. seqused_k,
  245. out,
  246. )
  247. rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
  248. philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
  249. debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
  250. return out, lse, rng_state, philox_offset, debug_mask
  251. def _fa4_flash_attention_backward_impl(
  252. grad_out: torch.Tensor,
  253. query: torch.Tensor,
  254. key: torch.Tensor,
  255. value: torch.Tensor,
  256. out: torch.Tensor,
  257. logsumexp: torch.Tensor,
  258. cum_seq_q: torch.Tensor | None,
  259. cum_seq_k: torch.Tensor | None,
  260. max_q: int,
  261. max_k: int,
  262. dropout_p: float,
  263. is_causal: bool,
  264. rng_state: torch.Tensor,
  265. unused: torch.Tensor,
  266. *,
  267. scale: float | None = None,
  268. window_size_left: int | None = None,
  269. window_size_right: int | None = None,
  270. ):
  271. error = _fa4_backward_support_error(
  272. grad_out,
  273. query,
  274. key,
  275. value,
  276. out,
  277. logsumexp,
  278. dropout_p,
  279. cum_seq_q,
  280. window_size_left,
  281. window_size_right,
  282. )
  283. if error is not None:
  284. raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
  285. deterministic = torch.are_deterministic_algorithms_enabled()
  286. dq, dk, dv = _fa4_run_backward(
  287. grad_out,
  288. query,
  289. key,
  290. value,
  291. out,
  292. logsumexp,
  293. cum_seq_q,
  294. cum_seq_k,
  295. scale,
  296. is_causal,
  297. deterministic,
  298. )
  299. return dq, dk, dv
  300. def _fa4_scaled_dot_product_flash_attention_forward_impl(
  301. query: torch.Tensor,
  302. key: torch.Tensor,
  303. value: torch.Tensor,
  304. dropout_p: float = 0.0,
  305. is_causal: bool = False,
  306. return_debug_mask: bool = False,
  307. *,
  308. scale: float | None = None,
  309. ):
  310. error = _fa4_forward_support_error(
  311. query,
  312. key,
  313. value,
  314. dropout_p,
  315. return_debug_mask,
  316. None,
  317. None,
  318. None,
  319. )
  320. if error is not None:
  321. raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
  322. q, k, v = _transpose_dense(query, key, value)
  323. # Pre-allocate output with query's strides (BHSD layout), then create
  324. # a BSHD view for the kernel. This ensures the returned output has
  325. # the same memory layout as the input query.
  326. out_bhsd = torch.empty_like(query)
  327. out_bshd = out_bhsd.transpose(1, 2)
  328. max_q_flash = q.size(1)
  329. max_k_flash = k.size(1)
  330. _, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
  331. q,
  332. k,
  333. v,
  334. None,
  335. None,
  336. max_q_flash,
  337. max_k_flash,
  338. dropout_p,
  339. is_causal,
  340. return_debug_mask,
  341. scale=scale,
  342. out=out_bshd,
  343. )
  344. max_q = query.size(2)
  345. max_k = key.size(2)
  346. return (
  347. out_bhsd,
  348. lse,
  349. None,
  350. None,
  351. max_q,
  352. max_k,
  353. rng_state,
  354. philox_offset,
  355. debug_mask,
  356. )
  357. def _fa4_scaled_dot_product_flash_attention_backward_impl(
  358. grad_out: torch.Tensor,
  359. query: torch.Tensor,
  360. key: torch.Tensor,
  361. value: torch.Tensor,
  362. out: torch.Tensor,
  363. logsumexp: torch.Tensor,
  364. cum_seq_q: torch.Tensor | None,
  365. cum_seq_k: torch.Tensor | None,
  366. max_q: int,
  367. max_k: int,
  368. dropout_p: float,
  369. is_causal: bool,
  370. philox_seed: torch.Tensor,
  371. philox_offset: torch.Tensor,
  372. *,
  373. scale: float | None = None,
  374. ):
  375. error = _fa4_backward_support_error(
  376. grad_out,
  377. query,
  378. key,
  379. value,
  380. out,
  381. logsumexp,
  382. dropout_p,
  383. None,
  384. None,
  385. None,
  386. )
  387. if error is not None:
  388. raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
  389. q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
  390. max_q = query.size(2)
  391. max_k = key.size(2)
  392. dq, dk, dv = _fa4_flash_attention_backward_impl(
  393. go,
  394. q,
  395. k,
  396. v,
  397. o,
  398. logsumexp,
  399. None,
  400. None,
  401. max_q,
  402. max_k,
  403. dropout_p,
  404. is_causal,
  405. philox_seed,
  406. philox_offset,
  407. scale=scale,
  408. )
  409. dq, dk, dv = _transpose_dense(dq, dk, dv)
  410. return dq, dk, dv
  411. _registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)