_fa3.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
  1. """
  2. PROTOTYPE!
  3. Flash Attention 3 implementation.
  4. For fp8: only supports forward pass right now.
  5. For fp16/bf16: supports forward and backward pass.
  6. """
  7. # mypy: allow-untyped-defs
  8. from __future__ import annotations
  9. import importlib
  10. import warnings
  11. from typing import TYPE_CHECKING
  12. if TYPE_CHECKING:
  13. from collections.abc import Callable
  14. from dataclasses import dataclass
  15. from functools import cache
  16. from typing_extensions import TypeVarTuple, Unpack
  17. import torch
  18. from torch.library import Library
  19. from . import _registry
# Public API of this module: only the registration entry point is exported.
__all__ = [
    "register_flash_attention_fa3",
]


# Handles to the FA3 custom ops. Both stay None until _fa3_import_module()
# has imported the FA3 module and bound them; the run helpers raise
# "FA3 not registered" while they are still None.
_FA3_CUDA_FWD: Callable | None = None  # Cache for torch.ops.flash_attn_3.fwd
_FA3_CUDA_BWD: Callable | None = None  # Cache for torch.ops.flash_attn_3.bwd
@dataclass
class _FA3Handle:
    """Handle for an FA3 registration; ``remove()`` releases it."""

    # Library object holding the FA3 kernel overrides, or None after remove().
    library: Library | None

    def remove(self) -> None:
        """Drop the held library reference and clear the C++ FA3 flag."""
        # Release our reference to the Library.
        # NOTE(review): presumably this lets the Library be finalized and its
        # kernel overrides dropped - confirm against torch.library.Library.
        self.library = None
        # Clear the C++ flag
        torch._C._set_sdp_use_fa3(False)
  32. @cache
  33. def _get_device_major(device: torch.device) -> int:
  34. major, _ = torch.cuda.get_device_capability(device)
  35. return major
  36. def register_flash_attention_fa3(
  37. module_path: str = "flash_attn_interface",
  38. ) -> _FA3Handle:
  39. """
  40. Register FA3 flash attention kernels with the PyTorch dispatcher.
  41. Args:
  42. module_path: Python module path to the FA3 implementation.
  43. """
  44. _fa3_import_module(module_path)
  45. # Expose FA3 registration status to C++
  46. torch._C._set_sdp_use_fa3(True)
  47. return _FA3Handle(_fa3_register_kernels())
  48. def _fa3_import_module(module_path: str) -> None:
  49. importlib.import_module(module_path)
  50. if not hasattr(torch.ops, "flash_attn_3"):
  51. raise RuntimeError(f"Module '{module_path}' does not expose FA3 kernels")
  52. if not hasattr(torch.ops.flash_attn_3, "fwd"):
  53. raise RuntimeError(
  54. f"Module '{module_path}' does not expose FA3 forward kernels"
  55. )
  56. if not hasattr(torch.ops.flash_attn_3, "bwd"):
  57. raise RuntimeError(
  58. f"Module '{module_path}' does not expose FA3 backward kernels"
  59. )
  60. global _FA3_CUDA_FWD, _FA3_CUDA_BWD
  61. _FA3_CUDA_FWD = torch.ops.flash_attn_3.fwd
  62. _FA3_CUDA_BWD = torch.ops.flash_attn_3.bwd
  63. def _fa3_register_kernels() -> Library:
  64. lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
  65. lib.impl(
  66. "_flash_attention_forward.quantized", _fa3_flash_attention_forward_impl, "CUDA"
  67. )
  68. lib.impl(
  69. "_scaled_dot_product_flash_attention.quantized",
  70. _fa3_scaled_dot_product_flash_attention_forward_impl,
  71. "CUDA",
  72. )
  73. lib.impl(
  74. "_flash_attention_forward", _fa3_flash_attention_forward_impl_default, "CUDA"
  75. )
  76. lib.impl(
  77. "_scaled_dot_product_flash_attention",
  78. _fa3_scaled_dot_product_flash_attention_forward_impl_default,
  79. "CUDA",
  80. )
  81. lib.impl("_flash_attention_backward", _fa3_flash_attention_backward_impl, "CUDA")
  82. lib.impl(
  83. "_scaled_dot_product_flash_attention_backward",
  84. _fa3_scaled_dot_product_flash_attention_backward_impl,
  85. "CUDA",
  86. )
  87. return lib
  88. def _fa3_common_support_error(
  89. query: torch.Tensor,
  90. tensors: tuple[torch.Tensor, ...],
  91. dropout_p: float,
  92. cum_seq_q: torch.Tensor | None,
  93. q_descale: torch.Tensor | None,
  94. k_descale: torch.Tensor | None,
  95. v_descale: torch.Tensor | None,
  96. ) -> str | None:
  97. if dropout_p != 0.0:
  98. return "dropout_p must be 0"
  99. if not all(t.is_cuda for t in tensors):
  100. return "inputs must be CUDA tensors"
  101. if len({t.device for t in tensors}) != 1:
  102. return "inputs must share device"
  103. if query.dtype == torch.float8_e4m3fn and (
  104. q_descale is None or k_descale is None or v_descale is None
  105. ):
  106. warnings.warn(
  107. "When using SDPA with fp8, descale tensor should always be used"
  108. " for accurate dequantization. Please use "
  109. "_scaled_dot_product_attention_quantized and "
  110. "provide the descale tensors.",
  111. UserWarning,
  112. )
  113. if cum_seq_q is None and query.dim() != 4:
  114. return "dense query must be 4D"
  115. if cum_seq_q is not None and query.dim() != 3:
  116. return "ragged query must be 3D"
  117. if not torch.cuda.is_available():
  118. return "CUDA not available"
  119. if _get_device_major(query.device) != 9:
  120. return "FA3 requires compute capability 9.0"
  121. return None
  122. def _fa3_forward_support_error(
  123. query: torch.Tensor,
  124. key: torch.Tensor,
  125. value: torch.Tensor,
  126. dropout_p: float,
  127. return_debug_mask: bool,
  128. alibi_slopes: torch.Tensor | None,
  129. seqused_k: torch.Tensor | None,
  130. cum_seq_q: torch.Tensor | None,
  131. q_descale: torch.Tensor | None,
  132. k_descale: torch.Tensor | None,
  133. v_descale: torch.Tensor | None,
  134. ) -> str | None:
  135. if return_debug_mask:
  136. return "return_debug_mask must be False"
  137. if alibi_slopes is not None:
  138. return "alibi_slopes not supported"
  139. if seqused_k is not None:
  140. if seqused_k.dtype != torch.int32:
  141. return "seqused_k must be int32"
  142. if not seqused_k.is_cuda:
  143. return "seqused_k must be CUDA"
  144. supported_dtypes = (torch.float8_e4m3fn, torch.float16, torch.bfloat16)
  145. if not all(t.dtype in supported_dtypes for t in {query, key, value}):
  146. return f"inputs must be one of {supported_dtypes}"
  147. if len({t.dtype for t in {query, key, value}}) != 1:
  148. return "all inputs must have the same dtype"
  149. error = _fa3_common_support_error(
  150. query,
  151. (query, key, value),
  152. dropout_p,
  153. cum_seq_q,
  154. q_descale,
  155. k_descale,
  156. v_descale,
  157. )
  158. if error is not None:
  159. if error == "inputs must share device":
  160. return "query, key, value must be on same device"
  161. return error
  162. return None
  163. def _fa3_backward_support_error(
  164. grad_out: torch.Tensor,
  165. query: torch.Tensor,
  166. key: torch.Tensor,
  167. value: torch.Tensor,
  168. out: torch.Tensor,
  169. logsumexp: torch.Tensor,
  170. dropout_p: float,
  171. cum_seq_q: torch.Tensor | None,
  172. window_size_left: int | None,
  173. window_size_right: int | None,
  174. ) -> str | None:
  175. # FA3 backward ONLY supports fp16/bf16, NOT fp8
  176. if query.dtype == torch.float8_e4m3fn:
  177. return (
  178. "FA3 backward does not support fp8 - use inference only (torch.no_grad())"
  179. )
  180. if logsumexp.dtype != torch.float32:
  181. return "logsumexp dtype must be float32"
  182. supported_dtypes = (torch.float16, torch.bfloat16)
  183. if not all(t.dtype in supported_dtypes for t in {grad_out, query, key, value, out}):
  184. return f"inputs must be one of {supported_dtypes}"
  185. if len({t.dtype for t in {grad_out, query, key, value, out}}) != 1:
  186. return "all inputs must have the same dtype"
  187. error = _fa3_common_support_error(
  188. query,
  189. (grad_out, query, key, value, out, logsumexp),
  190. dropout_p,
  191. cum_seq_q,
  192. None,
  193. None,
  194. None,
  195. )
  196. if error is not None:
  197. return error
  198. return None
  199. Ts = TypeVarTuple("Ts")
  200. def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
  201. return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
  202. def _maybe_contiguous(x: torch.Tensor | None) -> torch.Tensor | None:
  203. """Ensure tensor is contiguous in the last dimension."""
  204. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  205. def _fa3_run_forward(
  206. query: torch.Tensor,
  207. key: torch.Tensor,
  208. value: torch.Tensor,
  209. cu_seq_q: torch.Tensor | None,
  210. cu_seq_k: torch.Tensor | None,
  211. max_q: int,
  212. max_k: int,
  213. scale: float | None,
  214. is_causal: bool,
  215. window_size_left: int | None,
  216. window_size_right: int | None,
  217. seqused_k: torch.Tensor | None,
  218. out: torch.Tensor | None = None,
  219. q_descale: torch.Tensor | None = None,
  220. k_descale: torch.Tensor | None = None,
  221. v_descale: torch.Tensor | None = None,
  222. ) -> tuple[torch.Tensor, torch.Tensor]:
  223. """
  224. Run the FA3 forward pass by calling the C++ kernel directly.
  225. """
  226. if _FA3_CUDA_FWD is None:
  227. raise RuntimeError("FA3 not registered")
  228. # Ensure contiguous in the last dimension
  229. q = _maybe_contiguous(query)
  230. k = _maybe_contiguous(key)
  231. v = (
  232. value.contiguous()
  233. if value.dtype == torch.float8_e4m3fn
  234. and value.stride(-1) != 1
  235. and value.stride(-3) != 1
  236. else _maybe_contiguous(value)
  237. )
  238. cu_seqlens_q = _maybe_contiguous(cu_seq_q)
  239. cu_seqlens_k = _maybe_contiguous(cu_seq_k)
  240. seqused_k = _maybe_contiguous(seqused_k)
  241. out, softmax_lse, out_accum, softmax_lse_accum = _FA3_CUDA_FWD(
  242. q,
  243. k,
  244. v,
  245. None, # k_new
  246. None, # v_new
  247. None, # qv
  248. out, # out_ (pre-allocated output)
  249. cu_seqlens_q, # cu_seqlens_q
  250. cu_seqlens_k, # cu_seqlens_k
  251. None, # cu_seqlens_k_new
  252. None, # seqused_q
  253. seqused_k, # seqused_k
  254. max_q, # max_seqlen_q
  255. max_k, # max_seqlen_k
  256. None, # page_table,
  257. None, # kv_batch_idx,
  258. None, # leftpad_k,
  259. None, # rotary_cos,
  260. None, # rotary_sin,
  261. None, # seqlens_rotary,
  262. q_descale, # q_descale,
  263. k_descale, # k_descale,
  264. v_descale, # v_descale,
  265. scale, # softmax_scale,
  266. is_causal, # causal,
  267. window_size_left if window_size_left is not None else -1, # window_size_left
  268. window_size_right if window_size_right is not None else -1, # window_size_right
  269. 0, # attention_chunk,
  270. 0.0, # softcap,
  271. True, # rotary_interleaved,
  272. None, # scheduler_metadata,
  273. 1 if torch.are_deterministic_algorithms_enabled() else 0, # num_splits,
  274. None, # pack_gqa,
  275. torch._C._get_sm_carveout_experimental() or 0, # sm_margin,
  276. )
  277. return out, softmax_lse.contiguous()
  278. def _fa3_run_backward(
  279. grad_out: torch.Tensor,
  280. query: torch.Tensor,
  281. key: torch.Tensor,
  282. value: torch.Tensor,
  283. out: torch.Tensor,
  284. logsumexp: torch.Tensor,
  285. cu_seq_q: torch.Tensor | None,
  286. cu_seq_k: torch.Tensor | None,
  287. max_seqlen_q: int | None,
  288. max_seqlen_k: int | None,
  289. scale: float | None,
  290. is_causal: bool,
  291. window_size_left: int,
  292. window_size_right: int,
  293. deterministic: bool = False,
  294. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  295. if _FA3_CUDA_BWD is None:
  296. raise RuntimeError("FA3 not registered")
  297. # Ensure contiguous
  298. dout = _maybe_contiguous(grad_out)
  299. q = query.contiguous() if query.stride(-1) != 1 else query
  300. k = key.contiguous() if key.stride(-1) != 1 else key
  301. v = value.contiguous() if value.stride(-1) != 1 else value
  302. o = _maybe_contiguous(out)
  303. lse = _maybe_contiguous(logsumexp)
  304. # Pre-allocate gradient tensors
  305. dq = torch.empty_like(q)
  306. dk = torch.empty_like(k)
  307. dv = torch.empty_like(v)
  308. _FA3_CUDA_BWD(
  309. dout,
  310. q,
  311. k,
  312. v,
  313. o,
  314. lse,
  315. dq,
  316. dk,
  317. dv,
  318. cu_seq_q,
  319. cu_seq_k,
  320. None,
  321. None,
  322. max_seqlen_q,
  323. max_seqlen_k,
  324. scale,
  325. is_causal,
  326. window_size_left,
  327. window_size_right,
  328. 0.0,
  329. deterministic,
  330. torch._C._get_sm_carveout_experimental() or 0,
  331. )
  332. return dq, dk, dv
  333. def _fa3_flash_attention_forward_impl(
  334. query: torch.Tensor,
  335. key: torch.Tensor,
  336. value: torch.Tensor,
  337. cum_seq_q: torch.Tensor | None,
  338. cum_seq_k: torch.Tensor | None,
  339. max_q: int,
  340. max_k: int,
  341. dropout_p: float,
  342. is_causal: bool,
  343. return_debug_mask: bool,
  344. q_descale: torch.Tensor | None = None,
  345. k_descale: torch.Tensor | None = None,
  346. v_descale: torch.Tensor | None = None,
  347. *,
  348. scale: float | None = None,
  349. window_size_left: int = -1,
  350. window_size_right: int = -1,
  351. seqused_k: torch.Tensor | None = None,
  352. alibi_slopes: torch.Tensor | None = None,
  353. out: torch.Tensor | None = None,
  354. ):
  355. error = _fa3_forward_support_error(
  356. query,
  357. key,
  358. value,
  359. dropout_p,
  360. return_debug_mask,
  361. alibi_slopes,
  362. seqused_k,
  363. cum_seq_q,
  364. q_descale,
  365. k_descale,
  366. v_descale,
  367. )
  368. if error is not None:
  369. raise RuntimeError(f"FA3 flash_attention forward unsupported: {error}")
  370. out, lse = _fa3_run_forward(
  371. query,
  372. key,
  373. value,
  374. cum_seq_q,
  375. cum_seq_k,
  376. max_q,
  377. max_k,
  378. scale,
  379. is_causal,
  380. window_size_left,
  381. window_size_right,
  382. seqused_k,
  383. out,
  384. q_descale,
  385. k_descale,
  386. v_descale,
  387. )
  388. rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
  389. philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
  390. debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
  391. return out, lse, rng_state, philox_offset, debug_mask
  392. def _fa3_flash_attention_forward_impl_default(
  393. query: torch.Tensor,
  394. key: torch.Tensor,
  395. value: torch.Tensor,
  396. cum_seq_q: torch.Tensor | None,
  397. cum_seq_k: torch.Tensor | None,
  398. max_q: int,
  399. max_k: int,
  400. dropout_p: float,
  401. is_causal: bool,
  402. return_debug_mask: bool,
  403. *,
  404. scale: float | None = None,
  405. window_size_left: int = -1,
  406. window_size_right: int = -1,
  407. seqused_k: torch.Tensor | None = None,
  408. alibi_slopes: torch.Tensor | None = None,
  409. out: torch.Tensor | None = None,
  410. ):
  411. return _fa3_flash_attention_forward_impl(
  412. query,
  413. key,
  414. value,
  415. cum_seq_q,
  416. cum_seq_k,
  417. max_q,
  418. max_k,
  419. dropout_p,
  420. is_causal,
  421. return_debug_mask,
  422. None,
  423. None,
  424. None,
  425. scale=scale,
  426. window_size_left=window_size_left,
  427. window_size_right=window_size_right,
  428. seqused_k=seqused_k,
  429. alibi_slopes=alibi_slopes,
  430. out=out,
  431. )
  432. def _fa3_flash_attention_backward_impl(
  433. grad_out: torch.Tensor,
  434. query: torch.Tensor,
  435. key: torch.Tensor,
  436. value: torch.Tensor,
  437. out: torch.Tensor,
  438. logsumexp: torch.Tensor,
  439. cum_seq_q: torch.Tensor | None,
  440. cum_seq_k: torch.Tensor | None,
  441. max_q: int,
  442. max_k: int,
  443. dropout_p: float,
  444. is_causal: bool,
  445. rng_state: torch.Tensor,
  446. unused: torch.Tensor,
  447. *,
  448. scale: float | None = None,
  449. window_size_left: int | None = None,
  450. window_size_right: int | None = None,
  451. ):
  452. """FA3 implementation of _flash_attention_backward."""
  453. error = _fa3_backward_support_error(
  454. grad_out,
  455. query,
  456. key,
  457. value,
  458. out,
  459. logsumexp,
  460. dropout_p,
  461. cum_seq_q,
  462. window_size_left,
  463. window_size_right,
  464. )
  465. if error is not None:
  466. raise RuntimeError(f"FA3 flash_attention backward unsupported: {error}")
  467. deterministic = torch.are_deterministic_algorithms_enabled()
  468. dq, dk, dv = _fa3_run_backward(
  469. grad_out,
  470. query,
  471. key,
  472. value,
  473. out,
  474. logsumexp,
  475. cum_seq_q,
  476. cum_seq_k,
  477. max_q,
  478. max_k,
  479. scale,
  480. is_causal,
  481. window_size_left if window_size_left is not None else -1,
  482. window_size_right if window_size_right is not None else -1,
  483. deterministic,
  484. )
  485. return dq, dk, dv
  486. def _fa3_scaled_dot_product_flash_attention_forward_impl(
  487. query: torch.Tensor,
  488. key: torch.Tensor,
  489. value: torch.Tensor,
  490. q_descale: torch.Tensor | None = None,
  491. k_descale: torch.Tensor | None = None,
  492. v_descale: torch.Tensor | None = None,
  493. dropout_p: float = 0.0,
  494. is_causal: bool = False,
  495. return_debug_mask: bool = False,
  496. *,
  497. scale: float | None = None,
  498. ):
  499. error = _fa3_forward_support_error(
  500. query,
  501. key,
  502. value,
  503. dropout_p,
  504. return_debug_mask,
  505. None,
  506. None,
  507. None,
  508. q_descale,
  509. k_descale,
  510. v_descale,
  511. )
  512. if error is not None:
  513. raise RuntimeError(f"FA3 SDPA forward unsupported: {error}")
  514. q, k, v = _transpose_dense(query, key, value)
  515. # Pre-allocate output with query's strides (BHSD layout), then create
  516. # a BSHD view for the kernel. This ensures the returned output has
  517. # the same memory layout as the input query.
  518. out_dtype = torch.bfloat16 if query.dtype == torch.float8_e4m3fn else query.dtype
  519. out_bhsd = torch.empty_like(query, dtype=out_dtype)
  520. out_bshd = out_bhsd.transpose(1, 2)
  521. max_q_flash = q.size(1)
  522. max_k_flash = k.size(1)
  523. _, lse, rng_state, philox_offset, debug_mask = _fa3_flash_attention_forward_impl(
  524. q,
  525. k,
  526. v,
  527. None,
  528. None,
  529. max_q_flash,
  530. max_k_flash,
  531. dropout_p,
  532. is_causal,
  533. return_debug_mask,
  534. scale=scale,
  535. out=out_bshd,
  536. q_descale=q_descale,
  537. k_descale=k_descale,
  538. v_descale=v_descale,
  539. )
  540. max_q = query.size(2)
  541. max_k = key.size(2)
  542. return (
  543. out_bhsd,
  544. lse,
  545. None,
  546. None,
  547. max_q,
  548. max_k,
  549. rng_state,
  550. philox_offset,
  551. debug_mask,
  552. )
  553. def _fa3_scaled_dot_product_flash_attention_forward_impl_default(
  554. query: torch.Tensor,
  555. key: torch.Tensor,
  556. value: torch.Tensor,
  557. dropout_p: float = 0.0,
  558. is_causal: bool = False,
  559. return_debug_mask: bool = False,
  560. *,
  561. scale: float | None = None,
  562. ):
  563. return _fa3_scaled_dot_product_flash_attention_forward_impl(
  564. query,
  565. key,
  566. value,
  567. None,
  568. None,
  569. None,
  570. dropout_p,
  571. is_causal,
  572. return_debug_mask,
  573. scale=scale,
  574. )
  575. def _fa3_scaled_dot_product_flash_attention_backward_impl(
  576. grad_out: torch.Tensor,
  577. query: torch.Tensor,
  578. key: torch.Tensor,
  579. value: torch.Tensor,
  580. out: torch.Tensor,
  581. logsumexp: torch.Tensor,
  582. cum_seq_q: torch.Tensor | None,
  583. cum_seq_k: torch.Tensor | None,
  584. max_q: int,
  585. max_k: int,
  586. dropout_p: float,
  587. is_causal: bool,
  588. philox_seed: torch.Tensor,
  589. philox_offset: torch.Tensor,
  590. *,
  591. scale: float | None = None,
  592. ):
  593. """FA3 implementation of _scaled_dot_product_flash_attention_backward."""
  594. error = _fa3_backward_support_error(
  595. grad_out, query, key, value, out, logsumexp, dropout_p, None, None, None
  596. )
  597. if error is not None:
  598. raise RuntimeError(f"FA3 SDPA backward unsupported: {error}")
  599. # SDPA uses BHSD layout, FA3 uses BSHD - transpose
  600. grad_out_t, q_t, k_t, v_t, out_t = _transpose_dense(
  601. grad_out, query, key, value, out
  602. )
  603. dq, dk, dv = _fa3_flash_attention_backward_impl(
  604. grad_out_t,
  605. q_t,
  606. k_t,
  607. v_t,
  608. out_t,
  609. logsumexp,
  610. None, # cum_seq_q (dense attention)
  611. None, # cum_seq_k
  612. max_q, # max_seqlen_q
  613. max_k, # max_seqlen_k
  614. dropout_p,
  615. is_causal,
  616. philox_seed,
  617. philox_offset,
  618. scale=scale,
  619. )
  620. # Transpose gradients back to BHSD layout
  621. dq_out, dk_out, dv_out = _transpose_dense(dq, dk, dv)
  622. return dq_out, dk_out, dv_out
# Import-time side effect: make this backend known to the flash-attention
# registry under the name "FA3", with register_flash_attention_fa3 as the
# function the registry is given to perform the registration.
_registry.register_flash_attention_impl("FA3", register_fn=register_flash_attention_fa3)