sdpa.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936
  1. # mypy: allow-untyped-defs
  2. import logging
  3. import torch
  4. import torch.nn
  5. import torch.nn.functional as F
  6. from torch.backends.cuda import (
  7. can_use_cudnn_attention,
  8. can_use_efficient_attention,
  9. can_use_flash_attention,
  10. cudnn_sdp_enabled,
  11. flash_sdp_enabled,
  12. math_sdp_enabled,
  13. mem_efficient_sdp_enabled,
  14. SDPAParams,
  15. )
  16. from torch.nn.attention import SDPBackend
  17. from .nested_tensor import NestedTensor
  18. log = logging.getLogger(__name__)
  19. def _validate_sdpa_input(
  20. query: torch.Tensor,
  21. key: torch.Tensor,
  22. value: torch.Tensor,
  23. attn_mask: torch.Tensor | None = None,
  24. dropout_p=0.0,
  25. is_causal=False,
  26. scale=None,
  27. ) -> None:
  28. if (
  29. not isinstance(query, NestedTensor)
  30. or not isinstance(key, NestedTensor)
  31. or not isinstance(value, NestedTensor)
  32. ):
  33. raise ValueError(
  34. f"Expected query, key, and value to be nested tensors, "
  35. f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
  36. f"and value.is_nested: {value.is_nested} instead."
  37. )
  38. if query.dtype != key.dtype or query.dtype != value.dtype:
  39. raise ValueError(
  40. f"Expected query, key, and value to have the same dtype, "
  41. f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
  42. f"and value.dtype: {value.dtype} instead."
  43. )
  44. if query.device != key.device or query.device != value.device:
  45. raise ValueError(
  46. f"Expected query, key, and value to have the same device type, "
  47. f"but got query.device: {query.device}, key.device: {key.device}, "
  48. f"and value.device: {value.device} instead."
  49. )
  50. if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
  51. raise ValueError(
  52. f"Expected query, key, and value to all be at least 3 dimensional, but got query.dim: "
  53. f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
  54. )
  55. if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
  56. raise ValueError(
  57. f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
  58. f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
  59. )
  60. if attn_mask is not None:
  61. # TODO: Figure out whether masks are actually supported for this layout or not
  62. raise ValueError("Masks are not yet supported!")
  63. if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
  64. raise ValueError(
  65. f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
  66. f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
  67. )
  68. def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
  69. # This is expected to be called after check_tensor_shapes ensuring that the
  70. # size() calls won't error since the inputs are all 4 dimensional
  71. q_batch_size = params.query.size(0)
  72. k_batch_size = params.key.size(0)
  73. v_batch_size = params.value.size(0)
  74. # num_heads logic for nested input is checked in
  75. # check_for_seq_len_0_nested_tensor as there is handling there to make sure
  76. # num_heads is not ragged
  77. return q_batch_size == k_batch_size and q_batch_size == v_batch_size
  78. def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
  79. max_size = 256
  80. query_size_last = params.query.size(-1)
  81. key_size_last = params.key.size(-1)
  82. value_size_last = params.value.size(-1)
  83. same_head_dim_size = (
  84. query_size_last == key_size_last and query_size_last == value_size_last
  85. )
  86. if not (
  87. same_head_dim_size
  88. and (query_size_last % 8 == 0)
  89. and (query_size_last <= max_size)
  90. ):
  91. if debug:
  92. log.warning(
  93. "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
  94. "last dimension and to be a multiple of 8 and less than or equal to 256. "
  95. "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
  96. query_size_last,
  97. key_size_last,
  98. value_size_last,
  99. )
  100. return False
  101. return True
  102. def _check_head_dim_size_cudnn_nested(params: SDPAParams, debug=False) -> bool:
  103. max_size = 128
  104. query_size_last = params.query.size(-1)
  105. key_size_last = params.key.size(-1)
  106. value_size_last = params.value.size(-1)
  107. same_head_dim_size = (
  108. query_size_last == key_size_last and query_size_last == value_size_last
  109. )
  110. if not (
  111. same_head_dim_size
  112. and (query_size_last % 8 == 0)
  113. and (query_size_last <= max_size)
  114. ):
  115. if debug:
  116. log.warning(
  117. "For NestedTensor inputs, cuDNN attention requires q,k,v to have the same "
  118. "last dimension and to be a multiple of 8 and less than or equal to 128. "
  119. "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
  120. query_size_last,
  121. key_size_last,
  122. value_size_last,
  123. )
  124. return False
  125. return True
  126. def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
  127. param: torch.Tensor, param_name: str, debug=False
  128. ) -> bool:
  129. if not isinstance(param, NestedTensor):
  130. raise AssertionError("param should be a jagged NT")
  131. if param._ragged_idx == 1:
  132. # num_head_dims is ragged
  133. if debug:
  134. log.warning(
  135. "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
  136. param_name,
  137. )
  138. return False
  139. # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
  140. if param._get_min_seqlen() == 0:
  141. if debug:
  142. log.warning(
  143. "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
  144. param_name,
  145. )
  146. return False
  147. return True
  148. def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
  149. max_size = max(q_size, k_size, v_size)
  150. if (
  151. (q_size != max_size and q_size != 1)
  152. or (k_size != max_size and k_size != 1)
  153. or (v_size != max_size and v_size != 1)
  154. ):
  155. if debug:
  156. log.warning(
  157. "Both fused kernels require query, key and value to have broadcastable %s, "
  158. "got Query %s %d, Key %s %d, Value %s %d instead.",
  159. param_name,
  160. param_name,
  161. q_size,
  162. param_name,
  163. k_size,
  164. param_name,
  165. v_size,
  166. )
  167. return False
  168. return True
  169. def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
  170. # When this function is called we are assured that the nt is dim==4
  171. q_is_safe = (
  172. _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
  173. params.query, "query", debug
  174. )
  175. if params.query.is_nested
  176. else True
  177. )
  178. # short circuit if any is unsafe
  179. if not q_is_safe:
  180. return False
  181. k_is_safe = (
  182. _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
  183. params.key, "key", debug
  184. )
  185. if params.key.is_nested
  186. else True
  187. )
  188. # short circuit if any is unsafe
  189. if not k_is_safe:
  190. return False
  191. v_is_safe = (
  192. _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
  193. params.value, "value", debug
  194. )
  195. if params.value.is_nested
  196. else True
  197. )
  198. # short circuit if any is unsafe
  199. if not v_is_safe:
  200. return False
  201. # We now know none of the inputs have ragged num_heads, so we can safely
  202. # access .size(1)
  203. q_num_heads = params.query.size(1)
  204. k_num_heads = params.key.size(1)
  205. v_num_heads = params.value.size(1)
  206. same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads
  207. if not same_num_heads:
  208. if (
  209. params.query.requires_grad
  210. or params.key.requires_grad
  211. or params.value.requires_grad
  212. ):
  213. if debug:
  214. log.warning(
  215. "Both fused kernels do not support training with broadcasted NT inputs."
  216. )
  217. return False
  218. return _try_broadcast_param_size(
  219. q_num_heads, k_num_heads, v_num_heads, "num heads", debug
  220. )
  221. return True
  222. def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
  223. constraints = (
  224. _check_batch_size_nested,
  225. _check_head_dim_size_flash_nested,
  226. _check_for_seq_len_0_nested,
  227. )
  228. for constraint in constraints:
  229. if not constraint(params, debug):
  230. return False
  231. return True
  232. def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
  233. constraints = (
  234. _check_batch_size_nested,
  235. _check_for_seq_len_0_nested,
  236. )
  237. for constraint in constraints:
  238. if not constraint(params, debug):
  239. return False
  240. return True
  241. def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
  242. if (
  243. not params.query.transpose(1, 2).is_contiguous()
  244. or not params.key.transpose(1, 2).is_contiguous()
  245. or not params.value.transpose(1, 2).is_contiguous()
  246. ):
  247. if debug:
  248. log.warning(
  249. "If inputs are nested tensors they must be contiguous after transposing."
  250. )
  251. return False
  252. if params.is_causal:
  253. if debug:
  254. log.warning(
  255. "Nested tensors for query / key are not supported when is_causal=True."
  256. )
  257. return False
  258. return True
  259. def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
  260. if (
  261. not flash_sdp_enabled()
  262. and not mem_efficient_sdp_enabled()
  263. and not math_sdp_enabled()
  264. and not cudnn_sdp_enabled()
  265. ):
  266. return SDPBackend.ERROR
  267. ordering = (
  268. SDPBackend.FLASH_ATTENTION,
  269. SDPBackend.EFFICIENT_ATTENTION,
  270. SDPBackend.MATH,
  271. SDPBackend.CUDNN_ATTENTION,
  272. )
  273. params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
  274. for backend in ordering:
  275. if backend == SDPBackend.CUDNN_ATTENTION:
  276. if can_use_cudnn_attention(params):
  277. return SDPBackend.CUDNN_ATTENTION
  278. if backend == SDPBackend.FLASH_ATTENTION:
  279. if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
  280. return SDPBackend.FLASH_ATTENTION
  281. if backend == SDPBackend.EFFICIENT_ATTENTION:
  282. if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
  283. params
  284. ):
  285. return SDPBackend.EFFICIENT_ATTENTION
  286. if backend == SDPBackend.MATH:
  287. if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
  288. return SDPBackend.MATH
  289. log.warning("Memory efficient kernel not used because:")
  290. can_use_efficient_attention(params, debug=True)
  291. _can_use_efficient_sdpa_jagged(params, debug=True)
  292. log.warning("Flash attention kernel not used because:")
  293. can_use_flash_attention(params, debug=True)
  294. _can_use_flash_sdpa_jagged(params, debug=True)
  295. log.warning("Math attention kernel not used because:")
  296. _can_use_math_sdpa_jagged(params, debug=True)
  297. log.warning("cuDNN attention kernel not used because:")
  298. can_use_cudnn_attention(params, debug=True)
  299. return SDPBackend.ERROR
  300. def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, int, int]:
  301. # This function is used to calculate two pieces of metadata that are needed
  302. # for use with flash-attention and efficient_attention kernels. They are the
  303. # cumulative sequence_length over a batch of sequences and the maximum
  304. # sequence length.
  305. # It returns a tuple of cumulative sequence lengths and the maximum sequence
  306. # length, and the last element in the cumulative_sequence_lengths
  307. if not isinstance(qkv, NestedTensor):
  308. raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
  309. if qkv.lengths() is None:
  310. # TODO: Explore performance impact of copying
  311. cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
  312. max_seqlen = qkv._get_max_seqlen()
  313. n_elem = qkv.values().shape[0]
  314. else:
  315. # TODO: Explore performance impact of copying
  316. cumulative_seqlen = (
  317. qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
  318. )
  319. max_seqlen = qkv._get_max_seqlen()
  320. # TODO: Explore performance impact when compiling
  321. n_elem = int(cumulative_seqlen[-1].item())
  322. return cumulative_seqlen, max_seqlen, n_elem
  323. def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor) -> bool:
  324. # This function checks if a nested tensor is valid for
  325. # use with the flash-attention and efficient_attention kernels without
  326. # needing to call contiguous on the nested tensor input.
  327. # It checks that the storage offsets' adjacent_differences are a constant
  328. # multiple of the previous tensor in the nested tensor and that the strides
  329. # are monitonically decreasing. This check is done after calling transpose on
  330. # the nested tensor resulting in a Nt of shape [bsz, {seq_len}, num_heads, dim]
  331. # Returns a boolean indicating if contiguous needs to be called for input
  332. if not isinstance(tensor, NestedTensor):
  333. raise AssertionError("tensor must be a NestedTensor")
  334. offsets = tensor.offsets()
  335. strides = tensor._strides
  336. n_tensors = offsets.size(0) - 1
  337. if n_tensors <= 1:
  338. return True
  339. # Check initially that the tensor strides are in strictly descending order
  340. prev_stride = strides[1]
  341. for stride in strides[2:]:
  342. if prev_stride <= stride:
  343. # This would mean that the last stride is greater than the seq_len
  344. # stride
  345. return False
  346. prev_stride = stride
  347. # Congrats you made it!
  348. return True
  349. def _view_as_dense(
  350. tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
  351. ) -> torch.Tensor:
  352. if tensor.is_nested:
  353. return tensor.values()
  354. return tensor.view(Nnz, num_heads, head_dim)
  355. # TODO: Next iteration should add test cases and check it works
  356. # def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
  357. # # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
  358. # # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  359. # # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  360. # q_batch_size = query.size(0)
  361. # k_batch_size = key.size(0)
  362. # v_batch_size = value.size(0)
  363. # output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
  364. # q_num_heads = query.size(1)
  365. # k_num_heads = key.size(1)
  366. # v_num_heads = value.size(1)
  367. # output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
  368. # head_dim_qk = query.size(3)
  369. # head_dim_v = value.size(3)
  370. # q_t = query.transpose(1, 2)
  371. # k_t = key.transpose(1, 2)
  372. # v_t = value.transpose(1, 2)
  373. # # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
  374. # # output_batch_size/num_heads then they are 1
  375. # q_batch_size_needs_broadcast = q_batch_size != output_batch_size
  376. # k_batch_size_needs_broadcast = k_batch_size != output_batch_size
  377. # v_batch_size_needs_broadcast = v_batch_size != output_batch_size
  378. # # If {*}_batch_size_needs_broadcast, then
  379. # # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
  380. # # this is because needs_broadcast indicates that the batch_size is 1
  381. # # and hence there is only 1 value for seq_len
  382. # # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
  383. # # ..., output_batch_size * {*}_t.size(1)]
  384. # # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
  385. # if q_batch_size_needs_broadcast or not q_t.is_nested:
  386. # max_seqlen_batch_q = q_t.size(1)
  387. # cumulative_sequence_length_q = torch.arange(
  388. # 0,
  389. # (output_batch_size + 1) * max_seqlen_batch_q,
  390. # max_seqlen_batch_q,
  391. # device=q_t.device,
  392. # dtype=torch.int32,
  393. # )
  394. # Nnz_q = output_batch_size * max_seqlen_batch_q
  395. # else:
  396. # (
  397. # cumulative_sequence_length_q,
  398. # max_seqlen_batch_q,
  399. # Nnz_q,
  400. # ) = _cumulative_and_max_seq_len_nnz(q_t)
  401. # if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
  402. # assert k_t.size(1) == v_t.size(1)
  403. # max_seqlen_batch_kv = k_t.size(1)
  404. # cumulative_sequence_length_kv = torch.arange(
  405. # 0,
  406. # (output_batch_size + 1) * max_seqlen_batch_kv,
  407. # max_seqlen_batch_kv,
  408. # device=k_t.device,
  409. # dtype=torch.int32,
  410. # )
  411. # Nnz_kv = output_batch_size * max_seqlen_batch_kv
  412. # else:
  413. # cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
  414. # _cumulative_and_max_seq_len_nnz(v_t)
  415. # if k_batch_size_needs_broadcast
  416. # else _cumulative_and_max_seq_len_nnz(k_t)
  417. # )
  418. # q_num_heads_needs_broadcast = q_num_heads != output_num_heads
  419. # k_num_heads_needs_broadcast = k_num_heads != output_num_heads
  420. # v_num_heads_needs_broadcast = v_num_heads != output_num_heads
  421. # if not q_t.is_nested:
  422. # query_buffer_reshaped = q_t.expand(
  423. # output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
  424. # )
  425. # query_buffer_reshaped = query_buffer_reshaped.reshape(
  426. # Nnz_q, output_num_heads, head_dim_qk
  427. # )
  428. # else:
  429. # if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
  430. # q_t = q_t.contiguous()
  431. # # If we are broadcasting then Nnz_q will be the output_batch_size since
  432. # # seq_len is 1
  433. # effective_batch_size_q = (
  434. # output_batch_size if q_batch_size_needs_broadcast else Nnz_q
  435. # )
  436. # query_buffer_reshaped = _view_as_dense(
  437. # q_t, effective_batch_size_q, output_num_heads, head_dim_qk
  438. # )
  439. # # If the physical layout of the NestedTensor's storage
  440. # # is not: batch, {seq_len}, num_heads, head_dim then we need
  441. # # to call contiguous
  442. # if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
  443. # k_t = k_t.contiguous()
  444. # if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
  445. # v_t = v_t.contiguous()
  446. # effective_batch_size_k = (
  447. # output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
  448. # )
  449. # key_buffer_reshaped = _view_as_dense(
  450. # k_t, effective_batch_size_k, output_num_heads, head_dim_qk
  451. # )
  452. # effective_batch_size_v = (
  453. # output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
  454. # )
  455. # value_buffer_reshaped = _view_as_dense(
  456. # v_t, effective_batch_size_v, output_num_heads, head_dim_v
  457. # )
  458. # if not q_batch_size_needs_broadcast:
  459. # output_shape = q_t._size
  460. # if head_dim_v != head_dim_qk:
  461. # output_shape[-1] = head_dim_v
  462. # if q_num_heads_needs_broadcast:
  463. # output_shape[1] = output_num_heads
  464. # else:
  465. # output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
  466. # output_shape[0] = q_t.size(1)
  467. # output_shape[1] = output_num_heads
  468. # output_shape[2] = head_dim_v
  469. # return (
  470. # query_buffer_reshaped,
  471. # key_buffer_reshaped,
  472. # value_buffer_reshaped,
  473. # cumulative_sequence_length_q,
  474. # cumulative_sequence_length_kv,
  475. # max_seqlen_batch_q,
  476. # max_seqlen_batch_kv,
  477. # output_shape,
  478. # )
  479. def _sdpa_nested_preprocessing(query, key, value):
  480. # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
  481. # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  482. # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  483. q_batch_size = query.size(0)
  484. k_batch_size = key.size(0)
  485. v_batch_size = value.size(0)
  486. q_num_heads = query.size(1)
  487. k_num_heads = key.size(1)
  488. v_num_heads = value.size(1)
  489. if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
  490. q_num_heads == k_num_heads and k_num_heads == v_num_heads
  491. ):
  492. raise RuntimeError(
  493. "This path is currently not implemented for jagged layout NT."
  494. )
  495. # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
  496. num_heads = query.size(1)
  497. head_dim_qk = query.size(3)
  498. head_dim_v = value.size(3)
  499. q_t = query.transpose(1, 2)
  500. k_t = key.transpose(1, 2)
  501. v_t = value.transpose(1, 2)
  502. (
  503. cumulative_sequence_length_q,
  504. max_seqlen_batch_q,
  505. Nnz_q,
  506. ) = _cumulative_and_max_seq_len_nnz(q_t)
  507. (
  508. cumulative_sequence_length_kv,
  509. max_seqlen_batch_kv,
  510. Nnz_kv,
  511. ) = _cumulative_and_max_seq_len_nnz(k_t)
  512. # [TODO] K and V have to have the same Nnz, should probably torch_check
  513. # assume in order to not iterate over v
  514. # If the physical layout of the NestedTensor's storage
  515. # is not: batch, {seq_len}, num_heads, head_dim then we need
  516. # to call contiguous
  517. if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
  518. q_t = q_t.contiguous()
  519. if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
  520. k_t = k_t.contiguous()
  521. if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
  522. v_t = v_t.contiguous()
  523. query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
  524. key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
  525. value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
  526. output_nt_info = {
  527. "offsets": q_t.offsets(),
  528. "lengths": q_t.lengths(),
  529. "max_seqlen": q_t._get_max_seqlen(),
  530. "min_seqlen": q_t._get_min_seqlen(),
  531. }
  532. return (
  533. query_buffer_reshaped,
  534. key_buffer_reshaped,
  535. value_buffer_reshaped,
  536. cumulative_sequence_length_q,
  537. cumulative_sequence_length_kv,
  538. max_seqlen_batch_q,
  539. max_seqlen_batch_kv,
  540. output_nt_info,
  541. )
  542. def _pad_last_dim(
  543. tensor: torch.Tensor, alignment_size: int, slice: bool
  544. ) -> torch.Tensor:
  545. # FlashAttentionV2 requires that head dimension be a multiple of 8
  546. # This was previously done within the kernel, however
  547. # This causes the kernel to maybe alias query, key, value
  548. # So instead we pad the head_dimensions to be a multiple of 8
  549. # in the composite region
  550. last_dim_size = tensor.size(-1)
  551. if last_dim_size % alignment_size == 0:
  552. return tensor
  553. pad_count = alignment_size - (last_dim_size % alignment_size)
  554. tensor = torch.nn.functional.pad(tensor, [0, pad_count])
  555. if slice:
  556. return tensor[..., 0:last_dim_size]
  557. return tensor
  558. # TODO: coalesce with torch/nn/utils/attention.py
  559. def _calculate_scale(query, scale):
  560. # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
  561. softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
  562. return softmax_scale
  563. def _post_process_flash_output(out: torch.Tensor, og_size):
  564. if not out.is_nested and out.size(-1) != og_size:
  565. out = out[..., 0:og_size]
  566. return out
  567. def _is_computing_meta_flops(x):
  568. # Note: there's a use case of using meta tensors & the dispatch-based flop counter.
  569. # We can use this function to check for this scenario in order to handle it specially.
  570. if not torch.jit.is_scripting() and x.device.type == "meta":
  571. torch_dispatch_mode_stack = (
  572. torch.utils._python_dispatch._get_current_dispatch_mode_stack()
  573. )
  574. return any(
  575. type(x) is torch.utils.flop_counter._FlopCounterMode
  576. for x in torch_dispatch_mode_stack
  577. )
  578. return False
  579. def _autocast(
  580. query: torch.Tensor,
  581. key: torch.Tensor,
  582. value: torch.Tensor,
  583. attn_mask: torch.Tensor | None,
  584. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
  585. """
  586. [Autocasting SDPA for NJT]
  587. Normal autocasting doesn't work for NJT+SDPA right now:
  588. * NJT intercepts the __torch_function__ call for scaled_dot_product_attention, which happens
  589. before we get to any aten ops or dispatcher logic; then the torch_function logic calls into
  590. efficient attention or flash attention. So, autocasting on the scaled_dot_product_attention
  591. op won't work because we never see that aten op.
  592. * If we put autocasting on `_flash_attention_forward`, then we'll get autocasting to run, but
  593. the kernel selection logic in torch_function handling (ie. jagged_scaled_dot_product_attention)
  594. won't work correctly: the kernel selection logic will run before autocasting, and choose
  595. a kernel based on the un-autocasted dtypes; but then autocasting will run and the actual
  596. attention computation will happen in a different dtype.
  597. An alternative is to just change the backend selection logic for SDPA+NJT to be autocast-aware
  598. and rely on autocasting to do the actual conversions for flash attention / efficient attention.
  599. However, by manually doing the actual autocast before the backend selection, we ensure that the
  600. autocast handling for backend selection doesn't diverge from the autocast handling for the
  601. actual dtype conversions.
  602. """
  603. device_type = query.device.type
  604. # meta device is not supported by autocast, so break early for it
  605. if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
  606. return query, key, value, attn_mask
  607. def cvt(x):
  608. if x is None:
  609. return x
  610. target_dtype = torch.get_autocast_dtype(device_type)
  611. if (
  612. (not x.dtype.is_floating_point)
  613. or x.dtype == target_dtype
  614. or x.dtype == torch.float64
  615. ):
  616. return x
  617. return x.to(target_dtype)
  618. return cvt(query), cvt(key), cvt(value), cvt(attn_mask)
  619. def jagged_scaled_dot_product_attention(
  620. query: torch.Tensor,
  621. key: torch.Tensor,
  622. value: torch.Tensor,
  623. attn_mask: torch.Tensor | None = None,
  624. dropout_p=0.0,
  625. is_causal=False,
  626. scale=None,
  627. enable_gqa=False,
  628. ):
  629. query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
  630. _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
  631. # for mypy, ugh
  632. if not (
  633. isinstance(query, NestedTensor)
  634. and isinstance(key, NestedTensor)
  635. and isinstance(value, NestedTensor)
  636. ):
  637. raise AssertionError("query, key, and value must all be NestedTensor instances")
  638. from torch.nested._internal.nested_tensor import (
  639. nested_view_from_values_offsets_lengths,
  640. )
  641. # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
  642. # second batch dim instead). For this case, we can just send the dense buffers through
  643. # vanilla SDPA.
  644. if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
  645. output = F.scaled_dot_product_attention(
  646. query.values(),
  647. key.values(),
  648. value.values(),
  649. attn_mask=(
  650. attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
  651. ),
  652. dropout_p=dropout_p,
  653. is_causal=is_causal,
  654. scale=scale,
  655. )
  656. return nested_view_from_values_offsets_lengths(
  657. output,
  658. query.offsets(),
  659. query.lengths(),
  660. min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined]
  661. max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined]
  662. )
  663. compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
  664. backend_choice = _select_sdp_backend(
  665. query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
  666. )
  667. if _is_computing_meta_flops(query):
  668. # Backend choice will probably not be correct if we have a meta device,
  669. # because backend choice is device-aware. In this case, we mostly just
  670. # want to avoid using math backend (which does a .item() call).
  671. # Arbitrarily choose flash attention.
  672. backend_choice = SDPBackend.FLASH_ATTENTION
  673. if backend_choice == SDPBackend.FLASH_ATTENTION:
  674. og_size = query.size(-1)
  675. query_padded = _pad_last_dim(query, 8, False)
  676. key_padded = _pad_last_dim(key, 8, False)
  677. value_padded = _pad_last_dim(value, 8, False)
  678. # We need to calculate the scale based off the OG head dim size
  679. og_scale = _calculate_scale(query, scale)
  680. (
  681. query_buffer_reshaped,
  682. key_buffer_reshaped,
  683. value_buffer_reshaped,
  684. cumulative_sequence_length_q,
  685. cumulative_sequence_length_kv,
  686. max_seqlen_batch_q,
  687. max_seqlen_batch_kv,
  688. output_nt_info,
  689. ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
  690. (
  691. attention,
  692. _logsumexp,
  693. _philox_seed,
  694. _philox_offset,
  695. _debug_attn_mask,
  696. ) = torch.ops.aten._flash_attention_forward(
  697. query_buffer_reshaped,
  698. key_buffer_reshaped,
  699. value_buffer_reshaped,
  700. cumulative_sequence_length_q,
  701. cumulative_sequence_length_kv,
  702. max_seqlen_batch_q,
  703. max_seqlen_batch_kv,
  704. dropout_p,
  705. is_causal,
  706. False,
  707. scale=og_scale,
  708. )
  709. # Reshape output to convert nnz to batch_size and seq_len
  710. attention = nested_view_from_values_offsets_lengths(
  711. attention, # output from flash_attn is [total_q, num_heads, head_size_og]
  712. **output_nt_info,
  713. ).transpose(1, 2)
  714. return _post_process_flash_output(attention, og_size)
  715. elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
  716. (
  717. query_reshaped,
  718. key_reshaped,
  719. value_reshaped,
  720. cumulative_sequence_length_q,
  721. cumulative_sequence_length_kv,
  722. max_seqlen_batch_q,
  723. max_seqlen_batch_kv,
  724. output_nt_info,
  725. ) = _sdpa_nested_preprocessing(query, key, value)
  726. (
  727. attention,
  728. log_sumexp,
  729. seed,
  730. offset,
  731. max_seqlen_q,
  732. max_seqlen_batch_kv,
  733. ) = torch.ops.aten._efficient_attention_forward(
  734. query_reshaped.unsqueeze(0),
  735. key_reshaped.unsqueeze(0),
  736. value_reshaped.unsqueeze(0),
  737. None,
  738. cumulative_sequence_length_q,
  739. cumulative_sequence_length_kv,
  740. max_seqlen_batch_q,
  741. max_seqlen_batch_kv,
  742. dropout_p,
  743. int(is_causal),
  744. compute_logsumexp,
  745. scale=scale,
  746. )
  747. # Reshape output to convert nnz to batch_size and seq_len
  748. return nested_view_from_values_offsets_lengths(
  749. attention.squeeze(0),
  750. **output_nt_info,
  751. ).transpose(1, 2)
  752. elif backend_choice == SDPBackend.CUDNN_ATTENTION:
  753. (
  754. query_reshaped,
  755. key_reshaped,
  756. value_reshaped,
  757. cumulative_sequence_length_q,
  758. cumulative_sequence_length_kv,
  759. max_seqlen_batch_q,
  760. max_seqlen_batch_kv,
  761. output_nt_info,
  762. ) = _sdpa_nested_preprocessing(query, key, value)
  763. (
  764. attention,
  765. logsumexp,
  766. cum_seqlen_q,
  767. cum_seqlen_kv,
  768. max_seqlen_q,
  769. max_seqlen_kv,
  770. seed,
  771. offset,
  772. _,
  773. ) = torch.ops.aten._cudnn_attention_forward(
  774. query_reshaped,
  775. key_reshaped,
  776. value_reshaped,
  777. attn_mask,
  778. cumulative_sequence_length_q,
  779. cumulative_sequence_length_kv,
  780. max_seqlen_batch_q,
  781. max_seqlen_batch_kv,
  782. compute_logsumexp,
  783. dropout_p,
  784. is_causal,
  785. False,
  786. scale=scale,
  787. )
  788. return nested_view_from_values_offsets_lengths(
  789. attention,
  790. **output_nt_info,
  791. ).transpose(1, 2)
  792. elif backend_choice == SDPBackend.MATH:
  793. # save the offsets and shape of the inputs, so we can reshape the final output
  794. # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
  795. # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
  796. offsets = query.offsets()
  797. q_lengths = query.lengths()
  798. min_seqlen = query._maybe_min_seqlen
  799. max_seqlen = query._maybe_max_seqlen
  800. d1 = query._size[1]
  801. d2 = value._size[-1]
  802. # Convert the jagged layout Nested Tensor to a strided layout Nested Tensor,
  803. # which supports the math implementation of SDPA.
  804. def get_strided_layout_nested_tensor(jagged_layout_nt):
  805. lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
  806. transpose = torch.transpose(jagged_layout_nt, 1, 2)
  807. tensor_list = transpose.values().split(list(lengths), dim=0)
  808. strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
  809. strided_nt = strided_nt.transpose(1, 2).contiguous()
  810. return strided_nt
  811. query = get_strided_layout_nested_tensor(query)
  812. key = get_strided_layout_nested_tensor(key)
  813. value = get_strided_layout_nested_tensor(value)
  814. attn_out = torch._scaled_dot_product_attention_math(
  815. query, key, value, attn_mask, dropout_p, is_causal, scale=scale
  816. )[0]
  817. # convert strided layout Nested Tensor back to jagged layout Nested Tensor
  818. attn_out = attn_out.transpose(1, 2).contiguous().values()
  819. attn_out = attn_out.view(-1, d1, d2)
  820. attn_out = nested_view_from_values_offsets_lengths(
  821. attn_out,
  822. offsets,
  823. lengths=q_lengths,
  824. min_seqlen=min_seqlen,
  825. max_seqlen=max_seqlen,
  826. ).transpose(1, 2)
  827. return attn_out
  828. else:
  829. raise RuntimeError(
  830. "No viable backend for scaled_dot_product_attention was found."
  831. )