varlen.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. """
  2. Variable-length attention implementation using Flash Attention.
  3. This module provides a high-level Python interface for variable-length attention
  4. that calls into the optimized Flash Attention kernels.
  5. """
  6. import logging
  7. from functools import lru_cache
  8. from typing import Any, NamedTuple
  9. import torch
# Module-level logger named after this module; used to report which attention
# backend is selected on each forward/backward call.
log = logging.getLogger(__name__)
# Names exported as this module's public API.
__all__ = ["varlen_attn", "AuxRequest"]
  12. def _normalize_window_size(window_size: list[int] | None) -> list[int]:
  13. if window_size is None:
  14. window_size = [-1, -1]
  15. if len(window_size) != 2:
  16. raise ValueError(f"window_size must have length 2, got {len(window_size)}")
  17. return window_size
  18. @lru_cache(maxsize=8)
  19. def _should_use_cudnn(device_index: int) -> bool:
  20. """Cache device capability check to avoid repeated CUDA calls."""
  21. return False
  22. class AuxRequest(NamedTuple):
  23. """
  24. Request which auxiliary outputs to compute from varlen_attn.
  25. Each field is a boolean indicating whether that auxiliary output should be computed.
  26. """
  27. lse: bool = False
  28. @torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
  29. def _varlen_attn(
  30. query: torch.Tensor,
  31. key: torch.Tensor,
  32. value: torch.Tensor,
  33. cu_seq_q: torch.Tensor,
  34. cu_seq_k: torch.Tensor,
  35. max_q: int,
  36. max_k: int,
  37. is_causal: bool = False,
  38. scale: float | None = None,
  39. window_size: list[int] | None = None,
  40. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  41. """
  42. Private custom op for variable-length attention.
  43. This is the internal implementation. Users should use the public varlen_attn function instead.
  44. """
  45. window_size = _normalize_window_size(window_size)
  46. use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
  47. if use_cudnn:
  48. log.info("Using cuDNN backend for varlen_attn")
  49. if window_size[0] != -1 or window_size[1] != -1:
  50. raise RuntimeError(
  51. "cuDNN backend does not support window attention. Please use Flash Attention backend."
  52. )
  53. result = torch.ops.aten._cudnn_attention_forward(
  54. query,
  55. key,
  56. value,
  57. None, # attn_bias
  58. cu_seq_q,
  59. cu_seq_k,
  60. max_q,
  61. max_k,
  62. True, # compute_log_sumexp
  63. 0.0, # dropout_p hardcoded to 0.0
  64. is_causal,
  65. False, # return_debug_mask
  66. scale=scale,
  67. )
  68. # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
  69. output, softmax_lse, rng_state = result[0], result[1], result[6]
  70. else:
  71. log.info("Using Flash Attention backend for varlen_attn")
  72. output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
  73. query,
  74. key,
  75. value,
  76. cu_seq_q,
  77. cu_seq_k,
  78. max_q,
  79. max_k,
  80. 0.0, # dropout_p hardcoded to 0.0
  81. is_causal,
  82. return_debug_mask=False,
  83. scale=scale,
  84. window_size_left=window_size[0],
  85. window_size_right=window_size[1],
  86. )
  87. rng_state_ = torch.zeros(
  88. (2,), dtype=torch.uint64, device=query.device
  89. ) # hardcoded since dropout is hardcoded to 0
  90. return output, softmax_lse, rng_state_
  91. @_varlen_attn.register_fake
  92. def _varlen_attn_fake(
  93. query: torch.Tensor,
  94. key: torch.Tensor,
  95. value: torch.Tensor,
  96. cu_seq_q: torch.Tensor,
  97. cu_seq_k: torch.Tensor,
  98. max_q: int,
  99. max_k: int,
  100. is_causal: bool = False,
  101. scale: float | None = None,
  102. window_size: list[int] | None = None,
  103. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  104. """
  105. Fake implementation for meta tensor computation and tracing.
  106. Based on the 3D varlen path from meta__flash_attention_forward:
  107. - query shape: (total, num_heads, head_dim)
  108. - logsumexp shape: (num_heads, total_q)
  109. """
  110. window_size = _normalize_window_size(window_size)
  111. # Output has same shape as query
  112. output = torch.empty_like(query)
  113. # For varlen path: logsumexp shape is (num_heads, total_q)
  114. total_q = query.size(0)
  115. num_heads = query.size(1)
  116. if torch.version.hip:
  117. # ROCm uses batched format: [batch_size, num_heads, max_q]
  118. batch_size = cu_seq_q.size(0) - 1
  119. logsumexp = torch.empty(
  120. (batch_size, num_heads, max_q), dtype=torch.float, device=query.device
  121. )
  122. else:
  123. logsumexp = torch.empty(
  124. (num_heads, total_q), dtype=torch.float, device=query.device
  125. )
  126. rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)
  127. return output, logsumexp, rng_state
  128. def varlen_attn(
  129. query: torch.Tensor,
  130. key: torch.Tensor,
  131. value: torch.Tensor,
  132. cu_seq_q: torch.Tensor,
  133. cu_seq_k: torch.Tensor,
  134. max_q: int,
  135. max_k: int,
  136. *,
  137. return_aux: AuxRequest | None = None,
  138. scale: float | None = None,
  139. window_size: tuple[int, int] = (-1, -1),
  140. ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
  141. """
  142. Compute variable-length attention using Flash Attention.
  143. This function is similar to scaled_dot_product_attention but optimized for
  144. variable-length sequences using cumulative sequence position tensors.
  145. Args:
  146. query (Tensor): Query tensor; shape :math:`(T_q, H, D)`
  147. key (Tensor): Key tensor; shape :math:`(T_k, H, D)`
  148. value (Tensor): Value tensor; shape :math:`(T_k, H, D)`
  149. cu_seq_q (Tensor): Cumulative sequence positions for queries; shape :math:`(N+1,)`
  150. cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape :math:`(N+1,)`
  151. max_q (int): Maximum query sequence length in the batch.
  152. max_k (int): Maximum key/value sequence length in the batch.
  153. return_aux (Optional[AuxRequest]): If not None and ``return_aux.lse`` is True, also returns the logsumexp tensor.
  154. scale (float, optional): Scaling factor for attention scores
  155. window_size (tuple[int, int], optional): Window size for sliding window attention as (left, right).
  156. Use (-1, -1) for full attention (default), (-1, 0) for causal attention,
  157. or (W, 0) for causal attention with sliding window of size W.
  158. Returns:
  159. output (Tensor): Output tensor from attention computation; shape :math:`(T_q, H, D)`.
  160. If ``return_aux`` is not None and ``return_aux.lse`` is True:
  161. lse (Tensor): Log-sum-exp of attention scores; shape :math:`(T_q, H)`.
  162. Shape legend:
  163. - :math:`N`: Batch size
  164. - :math:`T_q`: Total number of query tokens in the batch (sum of all query sequence lengths)
  165. - :math:`T_k`: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)
  166. - :math:`H`: Number of attention heads
  167. - :math:`D`: Head dimension
  168. Example::
  169. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  170. >>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
  171. >>> head_dim = embed_dim // num_heads
  172. >>> seq_lengths = []
  173. >>> for _ in range(batch_size):
  174. ... length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
  175. ... seq_lengths.append(min(length, max_seq_len))
  176. >>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
  177. >>> total_tokens = seq_lengths.sum().item()
  178. >>>
  179. >>> # Create packed query, key, value tensors
  180. >>> query = torch.randn(
  181. ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
  182. ... )
  183. >>> key = torch.randn(
  184. ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
  185. ... )
  186. >>> value = torch.randn(
  187. ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
  188. ... )
  189. >>>
  190. >>> # Build cumulative sequence tensor
  191. >>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
  192. >>> cu_seq[1:] = seq_lengths.cumsum(0)
  193. >>> max_len = seq_lengths.max().item()
  194. >>>
  195. >>> # Call varlen_attn
  196. >>> output = varlen_attn(
  197. ... query, key, value, cu_seq, cu_seq, max_len, max_len
  198. ... )
  199. """
  200. is_causal = window_size == (-1, 0)
  201. out, lse, _ = torch.ops.torch_attn._varlen_attn(
  202. query,
  203. key,
  204. value,
  205. cu_seq_q,
  206. cu_seq_k,
  207. max_q,
  208. max_k,
  209. is_causal,
  210. scale,
  211. list(window_size),
  212. )
  213. if return_aux is not None and return_aux.lse:
  214. return out, lse
  215. return out
  216. def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
  217. (
  218. query,
  219. key,
  220. value,
  221. cu_seq_q,
  222. cu_seq_k,
  223. max_q,
  224. max_k,
  225. is_causal,
  226. scale,
  227. window_size,
  228. ) = inputs
  229. out, lse, rng_state = output
  230. ctx.save_for_backward(query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state)
  231. ctx.max_q = max_q
  232. ctx.max_k = max_k
  233. ctx.is_causal = is_causal
  234. ctx.scale = scale
  235. ctx.window_size = window_size
  236. @torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
  237. def _varlen_attn_backward(
  238. grad_out: torch.Tensor,
  239. query: torch.Tensor,
  240. key: torch.Tensor,
  241. value: torch.Tensor,
  242. out: torch.Tensor,
  243. lse: torch.Tensor,
  244. cu_seq_q: torch.Tensor,
  245. cu_seq_k: torch.Tensor,
  246. max_q: int,
  247. max_k: int,
  248. is_causal: bool,
  249. rng_state: torch.Tensor,
  250. scale: float | None = None,
  251. window_size: list[int] | None = None,
  252. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  253. window_size = _normalize_window_size(window_size)
  254. unused = torch.empty(0, device=query.device)
  255. use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
  256. if use_cudnn:
  257. log.info("Using cuDNN backend for varlen_attn")
  258. if window_size[0] != -1 or window_size[1] != -1:
  259. raise RuntimeError(
  260. "cuDNN backend does not support window attention. Please use Flash Attention backend."
  261. )
  262. dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
  263. grad_out,
  264. query,
  265. key,
  266. value,
  267. out,
  268. lse,
  269. cu_seq_q,
  270. cu_seq_k,
  271. max_q,
  272. max_k,
  273. 0.0,
  274. is_causal,
  275. rng_state,
  276. unused,
  277. scale=scale,
  278. )
  279. else:
  280. log.info("Using Flash Attention backend for varlen_attn")
  281. dq, dk, dv = torch.ops.aten._flash_attention_backward(
  282. grad_out,
  283. query,
  284. key,
  285. value,
  286. out,
  287. lse,
  288. cu_seq_q,
  289. cu_seq_k,
  290. max_q,
  291. max_k,
  292. 0.0,
  293. is_causal,
  294. rng_state,
  295. unused,
  296. scale=scale,
  297. window_size_left=window_size[0],
  298. window_size_right=window_size[1],
  299. )
  300. return dq, dk, dv
  301. @_varlen_attn_backward.register_fake
  302. def _varlen_attn_backward_fake(
  303. grad_out: torch.Tensor,
  304. query: torch.Tensor,
  305. key: torch.Tensor,
  306. value: torch.Tensor,
  307. out: torch.Tensor,
  308. lse: torch.Tensor,
  309. cu_seq_q: torch.Tensor,
  310. cu_seq_k: torch.Tensor,
  311. max_q: int,
  312. max_k: int,
  313. is_causal: bool,
  314. rng_state: torch.Tensor,
  315. scale: float | None = None,
  316. window_size: list[int] | None = None,
  317. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  318. """
  319. Fake implementation for meta tensor computation and tracing.
  320. """
  321. window_size = _normalize_window_size(window_size)
  322. grad_query = torch.empty_like(query)
  323. grad_key = torch.empty_like(key)
  324. grad_value = torch.empty_like(value)
  325. return grad_query, grad_key, grad_value
  326. def _backward(
  327. ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
  328. ) -> tuple[torch.Tensor | None, ...]:
  329. query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state = ctx.saved_tensors
  330. max_q = ctx.max_q
  331. max_k = ctx.max_k
  332. is_causal = ctx.is_causal
  333. scale = ctx.scale
  334. window_size = ctx.window_size
  335. dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
  336. grad_out,
  337. query,
  338. key,
  339. value,
  340. out,
  341. lse,
  342. cu_seq_q,
  343. cu_seq_k,
  344. max_q,
  345. max_k,
  346. is_causal,
  347. rng_state,
  348. scale,
  349. window_size,
  350. )
  351. return dq, dk, dv, None, None, None, None, None, None, None
  352. _varlen_attn.register_autograd(_backward, setup_context=_setup_context)