flex_attention.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742
  1. # mypy: allow-untyped-defs
  2. # flake8: noqa: B950
  3. """This module implements the user facing API for flex_attention in PyTorch."""
  4. import functools
  5. import inspect
  6. import itertools
  7. import math
  8. import operator
  9. import typing
  10. import warnings
  11. from collections.abc import Callable
  12. from enum import Enum
  13. from typing import Any, Literal, NamedTuple, TypeAlias
  14. from typing_extensions import NotRequired, TypedDict
  15. import torch
  16. from torch import Tensor
  17. from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
  18. from torch._higher_order_ops.utils import setup_compilation_env
  19. from torch._prims_common import DeviceLikeType
  20. from torch.nn.attention._utils import _validate_sdpa_input
  21. from torch.utils._pytree import GetAttrKey, tree_map_only
  22. # Private debug flag to disable internal compilation wrapping for debugging purposes.
  23. # WARNING: This is intended ONLY for debugging score_mod and mask_mod functions.
  24. # When enabled, this bypasses the required internal compilation that ensures correctness
  25. # and performance. Only use this temporarily when you need to set breakpoints
  26. # in your score_mod/mask_mod functions during development.
  27. #
  28. # This flag only affects the internal compilation when flex_attention is called directly.
  29. # If you have already wrapped flex_attention in torch.compile(), this flag has no effect
  30. # and the user's compilation will still occur.
  31. #
  32. # Usage:
  33. # import torch.nn.attention.flex_attention as fa
  34. # fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True
  35. # # Now you can set breakpoints in your score_mod/mask_mod
  36. # output = fa.flex_attention(q, k, v, score_mod=my_score_mod)
  37. #
  38. _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False
  39. _WARNINGS_SHOWN: set[str] = set()
  40. def _warn_once(
  41. warning_id: str, message: str, category: type[Warning] = UserWarning
  42. ) -> None:
  43. """Helper to ensure each warning is shown only once per process."""
  44. if warning_id not in _WARNINGS_SHOWN:
  45. if not torch.compiler.is_compiling():
  46. warnings.warn(message, category, stacklevel=2)
  47. _WARNINGS_SHOWN.add(warning_id)
  48. __all__ = [
  49. "BlockMask",
  50. "flex_attention",
  51. "AuxOutput",
  52. "AuxRequest",
  53. "FlexKernelOptions",
  54. "create_block_mask",
  55. "create_mask",
  56. "or_masks",
  57. "and_masks",
  58. "noop_mask",
  59. ]
  60. _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
  61. _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
  62. _Backend: TypeAlias = Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"]
  63. class FlexKernelOptions(TypedDict, total=False):
  64. """Options for controlling the behavior of FlexAttention kernels.
  65. These options are passed to the underlying Triton kernels to control performance
  66. and numerical behavior. Most users will not need to specify these options as the
  67. default autotuning provides good performance.
  68. The options can be prefixed with ``fwd_`` or ``bwd_`` to apply only to forward or
  69. backward pass respectively. For example: ``fwd_BLOCK_M`` and ``bwd_BLOCK_M1``.
  70. Note:
  71. We currently do not provide any backward compatibility guarantees for these options.
  72. That being said most of these have remained pretty stable since their introduction. But
  73. We do not consider this part of the public API just yet. We think that some documentation
  74. Is better than secret hidden flags, but we may change these options in the future.
  75. Example Usage:
  76. .. code-block:: python
  77. # Using dictionary (backward compatible)
  78. kernel_opts = {"BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": True}
  79. output = flex_attention(q, k, v, kernel_options=kernel_opts)
  80. # Using TypedDict (recommended for type safety)
  81. from torch.nn.attention.flex_attention import FlexKernelOptions
  82. kernel_opts: FlexKernelOptions = {
  83. "BLOCK_M": 64,
  84. "BLOCK_N": 64,
  85. "PRESCALE_QK": True,
  86. }
  87. output = flex_attention(q, k, v, kernel_options=kernel_opts)
  88. # Forward/backward specific options
  89. kernel_opts: FlexKernelOptions = {
  90. "fwd_BLOCK_M": 64,
  91. "bwd_BLOCK_M1": 32,
  92. "PRESCALE_QK": False,
  93. }
  94. output = flex_attention(q, k, v, kernel_options=kernel_opts)
  95. """
  96. # Performance tuning options
  97. num_warps: NotRequired[int]
  98. """Number of warps to use in the CUDA kernel. Higher values may improve performance
  99. but increase register pressure. Default is determined by autotuning."""
  100. num_stages: NotRequired[int]
  101. """Number of pipeline stages in the CUDA kernel. Higher values may improve performance
  102. but increase shared memory usage. Default is determined by autotuning."""
  103. BLOCK_M: NotRequired[int]
  104. """Thread block size for the sequence length dimension of Q in forward pass.
  105. Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""
  106. BLOCK_N: NotRequired[int]
  107. """Thread block size for the sequence length dimension of K/V in forward pass.
  108. Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""
  109. # Backward-specific block sizes (when prefixed with 'bwd_')
  110. BLOCK_M1: NotRequired[int]
  111. """Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'.
  112. Default is determined by autotuning."""
  113. BLOCK_N1: NotRequired[int]
  114. """Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'.
  115. Default is determined by autotuning."""
  116. BLOCK_M2: NotRequired[int]
  117. """Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'.
  118. Default is determined by autotuning."""
  119. BLOCK_N2: NotRequired[int]
  120. """Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'.
  121. Default is determined by autotuning."""
  122. PRESCALE_QK: NotRequired[bool]
  123. """Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but
  124. may have more numerical error. Default: False."""
  125. ROWS_GUARANTEED_SAFE: NotRequired[bool]
  126. """If True, guarantees that at least one value in each row is not masked out.
  127. Allows skipping safety checks for better performance. Only set this if you are certain
  128. your mask guarantees this property. For example, causal attention is guaranteed safe
  129. because each query has at least 1 key-value to attend to. Default: False."""
  130. BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]
  131. """If True, guarantees that all blocks in the mask are contiguous.
  132. Allows optimizing block traversal. For example, causal masks would satisfy this,
  133. but prefix_lm + sliding window would not. Default: False."""
  134. WRITE_DQ: NotRequired[bool]
  135. """Controls whether gradient scatters are done in the DQ iteration loop of the backward pass.
  136. Setting this to False will force this to happen in the DK loop which depending on your
  137. specific score_mod and mask_mod might be faster. Default: True."""
  138. FORCE_USE_FLEX_ATTENTION: NotRequired[bool]
  139. """If True, forces the use of the flex attention kernel instead of potentially using
  140. the more optimized flex-decoding kernel for short sequences. This can be a helpful
  141. option for debugging. Default: False."""
  142. USE_TMA: NotRequired[bool]
  143. """Whether to use Tensor Memory Accelerator (TMA) on supported hardware.
  144. This is experimental and may not work on all hardware, currently specific
  145. to NVIDIA GPUs Hopper+. Default: False."""
  146. # ROCm-specific options
  147. kpack: NotRequired[int]
  148. """ROCm-specific kernel packing parameter."""
  149. matrix_instr_nonkdim: NotRequired[int]
  150. """ROCm-specific matrix instruction non-K dimension."""
  151. waves_per_eu: NotRequired[int]
  152. """ROCm-specific waves per execution unit."""
  153. BACKEND: NotRequired[_Backend]
  154. """Selects a specific kernel backend.
  155. Options:
  156. - "AUTO": Use current heuristics (typically Triton-based kernels with
  157. automatic selection between flex_attention and flex_decoding)
  158. - "TRITON": Standard Triton flex_attention kernel
  159. - "TRITON_DECODE": Triton flex_decoding kernel, only available for short sequence lengths with specific configurations
  160. - "FLASH": Experimental: Flash Attention kernel (cute-dsl), user needs to have flash installed
  161. This option cannot be combined with legacy knobs such as ``FORCE_USE_FLEX_ATTENTION``.
  162. Raises an error if the requested backend cannot be used. Default: "AUTO"
  163. """
  164. class AuxRequest(NamedTuple):
  165. """Request which auxiliary outputs to compute from flex_attention.
  166. Each field is a boolean indicating whether that auxiliary output should be computed.
  167. """
  168. lse: bool = False
  169. max_scores: bool = False
  170. class AuxOutput(NamedTuple):
  171. """Auxiliary outputs from flex_attention operation.
  172. Fields will be None if not requested, or contain the tensor if requested.
  173. """
  174. lse: Tensor | None = None
  175. max_scores: Tensor | None = None
  176. class _ModificationType(Enum):
  177. """Enum for the type of modification function.
  178. - SCORE_MOD: score_mod function which accepts a score as the first argument
  179. - mask_mod: mask function which does not accept a score and is only used for generating
  180. block mask
  181. """
  182. SCORE_MOD = 1
  183. MASK_MOD = 2
  184. UNKNOWN = 3
  185. def _get_mod_type(fn: Callable) -> _ModificationType:
  186. """Get the type of modification function.
  187. This function inspects the number of positional arguments of the function to determine
  188. the type of modification function. If the function has 5 positional arguments, it is
  189. considered as a score_mod function. If the function has 4 positional arguments, it is
  190. considered as a mask function.
  191. """
  192. if hasattr(fn, "__code__"):
  193. code = fn.__code__
  194. num_positional_total = code.co_argcount
  195. defaults = ()
  196. if hasattr(fn, "__defaults__"):
  197. defaults = fn.__defaults__ or ()
  198. num_defaults = len(defaults)
  199. num_positional_args = num_positional_total - num_defaults
  200. else:
  201. num_positional_args = sum(
  202. 1
  203. for param in inspect.signature(fn).parameters.values()
  204. if param.default is inspect.Parameter.empty
  205. )
  206. if num_positional_args != 5 and num_positional_args != 4:
  207. raise AssertionError(
  208. f"Expected 4 or 5 positional args, got {num_positional_args}"
  209. )
  210. if num_positional_args == 5:
  211. return _ModificationType.SCORE_MOD
  212. elif num_positional_args == 4:
  213. return _ModificationType.MASK_MOD
  214. else:
  215. return _ModificationType.UNKNOWN
  216. # Need to define it here so that Dynamo doesn't skip it
  217. def _vmap_for_bhqkv(
  218. fn: Callable,
  219. prefix: tuple[int | None, ...],
  220. suffix: tuple[int | None, ...] = (),
  221. out_dims: int | list[int | None] = 0,
  222. group_dim: bool = False,
  223. ):
  224. """Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimension inputs.
  225. Mapping over the [b, hq, q_idx, kv_idx] or [b, hkv, g, q_idx, kv_idx] dimensions.
  226. Args:
  227. fn (callable): The function to vmap.
  228. prefix (tuple): The prefix of the vmap. For score mod functions,
  229. this should be set to (0,). For mask_mods = ()
  230. suffix (tuple): We need to add (0,) if gradOut is being mapped over,
  231. and (None,) * len(other_buffers).
  232. out_dims (tuple): For forward cases, keep this as the default 0 since
  233. we are only returning 1 output. For backwards, the joint
  234. graph returns grads for B, H, Q_idx, KV_idx and other_buffers,
  235. so we set this to (0, None, None, None, None) + (None,) * len(other_buffers).
  236. Returns:
  237. callable: The vmapped function.
  238. """
  239. # We vamp a function 4 times, broadcasting the [b, h, q_idx, kv_idx] dimensions
  240. dimensions: list[tuple[int | None, int | None, int | None, int | None]] = []
  241. dimensions = [
  242. (None, None, None, 0),
  243. (None, None, 0, None),
  244. (None, 0, None, None),
  245. ]
  246. if group_dim:
  247. dimensions += [
  248. (None, 0, None, None),
  249. ]
  250. dimensions += [
  251. (0, None, None, None),
  252. ]
  253. for dims in dimensions:
  254. fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims) # type: ignore[arg-type]
  255. return fn
  256. def _identity(
  257. score: Tensor,
  258. batch: Tensor,
  259. head: Tensor,
  260. token_q: Tensor,
  261. token_kv: Tensor,
  262. ) -> Tensor:
  263. return score
  264. def noop_mask(
  265. batch: Tensor,
  266. head: Tensor,
  267. token_q: Tensor,
  268. token_kv: Tensor,
  269. ) -> Tensor:
  270. """Returns a noop mask_mod"""
  271. return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)
  272. def _sliced_mask_mod_error(
  273. batch: Tensor,
  274. head: Tensor,
  275. token_q: Tensor,
  276. token_kv: Tensor,
  277. ) -> Tensor:
  278. """
  279. Raises helpful error when using mask_mod from a sliced BlockMask.
  280. After slicing a BlockMask, the mask_mod is reset and cannot be used directly.
  281. Users must reassign mask_mod from the original (unsliced) BlockMask.
  282. """
  283. raise RuntimeError(
  284. "Cannot use mask_mod from a sliced BlockMask. "
  285. "When you slice a BlockMask using [], the mask_mod attribute is reset. "
  286. "You must set it from the original BlockMask's mask_mod."
  287. "\n\nIncorrect usage:"
  288. "\n base_mask = create_block_mask(my_mask_fn, ...)"
  289. "\n sliced_mask = base_mask[:, :, block_idx]"
  290. "\n sliced_mask.mask_mod = apply_offset(sliced_mask.mask_mod, offset) # WRONG!"
  291. "\n\nCorrect usage:"
  292. "\n base_mask = create_block_mask(my_mask_fn, ...)"
  293. "\n sliced_mask = base_mask[:, :, block_idx]"
  294. "\n sliced_mask.mask_mod = apply_offset(base_mask.mask_mod, offset) # Use base_mask!"
  295. )
  296. _DEFAULT_SPARSE_BLOCK_SIZE = 128
  297. _LARGE_SPARSE_BLOCK_SIZE = 1 << 30
  298. def _ordered_to_dense(num_blocks_in_row: Tensor, col_indices: Tensor):
  299. num_rows = col_indices.shape[-2]
  300. num_cols = col_indices.shape[-1]
  301. batch_dims = num_blocks_in_row.shape[:-1]
  302. device = num_blocks_in_row.device
  303. def create_dense_one(kv_num_blocks, kv_indices):
  304. dense_mask = kv_indices.new_zeros(num_rows, num_cols + 1, dtype=torch.int32)
  305. row_indices = torch.arange(num_rows, dtype=torch.int, device=device).unsqueeze(
  306. -1
  307. )
  308. col_range = torch.arange(num_cols, dtype=torch.int, device=device)
  309. index_mask = col_range < kv_num_blocks.unsqueeze(-1)
  310. # We write to one spot "out of bounds"
  311. valid_indices = torch.where(index_mask, kv_indices, num_cols)
  312. # set the values in 'a' to 1 where the indices are valid
  313. dense_mask[row_indices, valid_indices] = dense_mask.new_ones(())
  314. return dense_mask[:, :num_cols].contiguous()
  315. create_dense_batched = create_dense_one
  316. for _ in range(len(batch_dims)):
  317. create_dense_batched = torch.vmap(create_dense_batched, in_dims=(0, 0))
  318. out = create_dense_batched(num_blocks_in_row, col_indices)
  319. return out
  320. def _dense_to_ordered(dense_mask) -> tuple[Tensor, Tensor]:
  321. dense_mask = dense_mask.to(dtype=torch.int32)
  322. num_blocks_in_row = dense_mask.sum(dim=-1)
  323. col_indices = torch.argsort(dense_mask, dim=-1, descending=True, stable=True)
  324. return (
  325. num_blocks_in_row.to(torch.int32, memory_format=torch.contiguous_format),
  326. col_indices.to(torch.int32, memory_format=torch.contiguous_format),
  327. )
  328. def _transpose_ordered(num_blocks_in_row: Tensor, col_indices: Tensor):
  329. dense = _ordered_to_dense(num_blocks_in_row, col_indices)
  330. return _dense_to_ordered(dense.transpose(-2, -1))
  331. def _adjust_num_blocks_and_indices(
  332. num_blocks: Tensor,
  333. indices: Tensor,
  334. new_num_rows: int,
  335. new_num_cols: int,
  336. ):
  337. indices = indices[:, :, :new_num_rows, :new_num_cols]
  338. num_blocks = num_blocks[:, :, :new_num_rows]
  339. num_blocks = torch.where(num_blocks < new_num_cols, num_blocks, new_num_cols)
  340. num_blocks = torch.sum(indices < num_blocks[:, :, :, None], dim=-1).to(torch.int32)
  341. return num_blocks, indices
  342. def _closure_contents(fn: object) -> tuple[object, ...]:
  343. """Extract closure cell contents for comparison."""
  344. closure = getattr(fn, "__closure__", None)
  345. if closure is None:
  346. return ()
  347. return tuple(cell.cell_contents for cell in closure)
  348. class _MaskModWrapper:
  349. """Wraps a mask_mod function with value-based equality.
  350. BlockMask stores an arbitrary callable (mask_mod) in its pytree context.
  351. The default __eq__ for functions uses identity comparison, which is too
  352. strict when the same closure is recreated (e.g., defined inside forward()).
  353. This wrapper compares functions by their code object and closure contents.
  354. """
  355. __slots__ = ("fn",)
  356. def __init__(self, fn: _mask_mod_signature) -> None:
  357. self.fn = fn
  358. def __call__(self, *args, **kwargs):
  359. return self.fn(*args, **kwargs)
  360. def __eq__(self, other: object) -> bool:
  361. if not isinstance(other, _MaskModWrapper):
  362. return False
  363. if self.fn is other.fn:
  364. return True
  365. if (
  366. inspect.isfunction(self.fn)
  367. and inspect.isfunction(other.fn)
  368. and self.fn.__code__ == other.fn.__code__
  369. and _closure_contents(self.fn) == _closure_contents(other.fn)
  370. ):
  371. return True
  372. return False
  373. def __hash__(self) -> int:
  374. if inspect.isfunction(self.fn):
  375. return hash(self.fn.__code__)
  376. return hash(self.fn)
  377. def __repr__(self) -> str:
  378. return f"_MaskModWrapper({self.fn})"
  379. class BlockMask:
  380. r"""
  381. BlockMask is our format for representing a block-sparse attention mask.
  382. It is somewhat of a cross in-between BCSR and a non-sparse format.
  383. **Basics**
  384. A block-sparse mask means that instead of representing the sparsity of
  385. individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is
  386. considered sparse only if every element within that block is sparse.
  387. This aligns well with hardware, which generally expects to perform
  388. contiguous loads and computation.
  389. This format is primarily optimized for 1. simplicity, and 2. kernel
  390. efficiency. Notably, it is *not* optimized for size, as this mask is always
  391. reduced by a factor of KV_BLOCK_SIZE * Q_BLOCK_SIZE. If the size is a
  392. concern, the tensors can be reduced in size by increasing the block size.
  393. The essentials of our format are:
  394. num_blocks_in_row: Tensor[ROWS]:
  395. Describes the number of blocks present in each row.
  396. col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]:
  397. `col_indices[i]` is the sequence of block positions for row i. The values of
  398. this row after `col_indices[i][num_blocks_in_row[i]]` are undefined.
  399. For example, to reconstruct the original tensor from this format:
  400. .. code-block:: python
  401. dense_mask = torch.zeros(ROWS, COLS)
  402. for row in range(ROWS):
  403. for block_idx in range(num_blocks_in_row[row]):
  404. dense_mask[row, col_indices[row, block_idx]] = 1
  405. Notably, this format makes it easier to implement a reduction along the
  406. *rows* of the mask.
  407. **Details**
  408. The basics of our format require only kv_num_blocks and kv_indices. But, we
  409. have up to 8 tensors on this object. This represents 4 pairs:
  410. 1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as
  411. we reduce along the KV dimension.
  412. 2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and
  413. purely an optimization. As it turns out, applying masking to every block
  414. is quite expensive! If we specifically know which blocks are "full" and
  415. don't require masking at all, then we can skip applying mask_mod to these
  416. blocks. This requires the user to split out a separate mask_mod from the
  417. score_mod. For causal masks, this is about a 15% speedup.
  418. 3. [GENERATED] (q_num_blocks, q_indices): Required for the backwards pass,
  419. as computing dKV requires iterating along the mask along the Q dimension. These are autogenerated from 1.
  420. 4. [GENERATED] (full_q_num_blocks, full_q_indices): Same as above, but for
  421. the backwards pass. These are autogenerated from 2.
  422. """
  423. seq_lengths: tuple[int, int]
  424. kv_num_blocks: Tensor
  425. kv_indices: Tensor
  426. full_kv_num_blocks: Tensor | None
  427. full_kv_indices: Tensor | None
  428. q_num_blocks: Tensor | None
  429. q_indices: Tensor | None
  430. full_q_num_blocks: Tensor | None
  431. full_q_indices: Tensor | None
  432. BLOCK_SIZE: tuple[int, int]
  433. mask_mod: _mask_mod_signature
  434. # Attribute lists for pytree flatten/unflatten
  435. _TENSOR_ATTRS = [
  436. "kv_num_blocks",
  437. "kv_indices",
  438. "full_kv_num_blocks",
  439. "full_kv_indices",
  440. "q_num_blocks",
  441. "q_indices",
  442. "full_q_num_blocks",
  443. "full_q_indices",
  444. ]
  445. _CONTEXT_ATTRS = [
  446. "seq_lengths",
  447. "BLOCK_SIZE",
  448. "mask_mod",
  449. ]
  450. def __init__(
  451. self,
  452. seq_lengths: tuple[int, int],
  453. kv_num_blocks: Tensor,
  454. kv_indices: Tensor,
  455. full_kv_num_blocks: Tensor | None,
  456. full_kv_indices: Tensor | None,
  457. q_num_blocks: Tensor | None,
  458. q_indices: Tensor | None,
  459. full_q_num_blocks: Tensor | None,
  460. full_q_indices: Tensor | None,
  461. BLOCK_SIZE: tuple[int, int],
  462. mask_mod: _mask_mod_signature,
  463. ) -> None:
  464. if kv_indices.dim() < 2:
  465. raise RuntimeError("BlockMask must have at least 2 dimensions")
  466. if kv_num_blocks is None:
  467. raise AssertionError("kv_num_blocks must be provided")
  468. if kv_indices is None:
  469. raise AssertionError("kv_indices must be provided")
  470. if (full_kv_num_blocks is None) != (full_kv_indices is None):
  471. raise AssertionError(
  472. "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
  473. )
  474. if (full_q_num_blocks is None) != (full_q_indices is None):
  475. raise AssertionError(
  476. "full_q_num_blocks and full_q_indices must be both provided or omitted"
  477. )
  478. self.seq_lengths = seq_lengths
  479. self.kv_num_blocks = kv_num_blocks
  480. self.kv_indices = kv_indices
  481. self.full_kv_num_blocks = full_kv_num_blocks
  482. self.full_kv_indices = full_kv_indices
  483. self.q_num_blocks = q_num_blocks
  484. self.q_indices = q_indices
  485. self.full_q_num_blocks = full_q_num_blocks
  486. self.full_q_indices = full_q_indices
  487. self.BLOCK_SIZE = BLOCK_SIZE
  488. self.mask_mod = mask_mod
  489. @classmethod
  490. def from_kv_blocks(
  491. cls,
  492. kv_num_blocks: Tensor,
  493. kv_indices: Tensor,
  494. full_kv_num_blocks: Tensor | None = None,
  495. full_kv_indices: Tensor | None = None,
  496. BLOCK_SIZE: int | tuple[int, int] = _DEFAULT_SPARSE_BLOCK_SIZE,
  497. mask_mod: _mask_mod_signature | None = None,
  498. seq_lengths: tuple[int, int] | None = None,
  499. compute_q_blocks: bool = True,
  500. ):
  501. """
  502. Creates a BlockMask instance from key-value block information.
  503. Args:
  504. kv_num_blocks (Tensor): Number of kv_blocks in each Q_BLOCK_SIZE row tile.
  505. kv_indices (Tensor): Indices of key-value blocks in each Q_BLOCK_SIZE row tile.
  506. full_kv_num_blocks (Optional[Tensor]): Number of full kv_blocks in each Q_BLOCK_SIZE row tile.
  507. full_kv_indices (Optional[Tensor]): Indices of full key-value blocks in each Q_BLOCK_SIZE row tile.
  508. BLOCK_SIZE (Union[int, tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles.
  509. mask_mod (Optional[Callable]): Function to modify the mask.
  510. Returns:
  511. BlockMask: Instance with full Q information generated via _transposed_ordered
  512. Raises:
  513. RuntimeError: If kv_indices has < 2 dimensions.
  514. AssertionError: If only one of full_kv_* args is provided.
  515. """
  516. if kv_indices.dim() < 2:
  517. raise RuntimeError("BlockMask must have at least 2 dimensions")
  518. if (full_kv_num_blocks is None) != (full_kv_indices is None):
  519. raise AssertionError(
  520. "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
  521. )
  522. # Generate q_num_blocks and q_indices
  523. if compute_q_blocks:
  524. q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
  525. if full_kv_num_blocks is not None:
  526. if full_kv_indices is None:
  527. raise AssertionError("full_kv_indices must not be None")
  528. full_q_num_blocks, full_q_indices = _transpose_ordered(
  529. full_kv_num_blocks, full_kv_indices
  530. )
  531. else:
  532. full_q_num_blocks, full_q_indices = None, None
  533. else:
  534. q_num_blocks, q_indices = None, None
  535. full_q_num_blocks, full_q_indices = None, None
  536. if isinstance(BLOCK_SIZE, int):
  537. BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)
  538. mask_mod = mask_mod if mask_mod is not None else noop_mask
  539. if seq_lengths is None:
  540. q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]
  541. kv_length = kv_indices.shape[-1] * BLOCK_SIZE[1]
  542. seq_lengths = (q_length, kv_length)
  543. return cls(
  544. seq_lengths=seq_lengths,
  545. kv_num_blocks=kv_num_blocks,
  546. kv_indices=kv_indices,
  547. full_kv_num_blocks=full_kv_num_blocks,
  548. full_kv_indices=full_kv_indices,
  549. q_num_blocks=q_num_blocks,
  550. q_indices=q_indices,
  551. full_q_num_blocks=full_q_num_blocks,
  552. full_q_indices=full_q_indices,
  553. BLOCK_SIZE=BLOCK_SIZE,
  554. mask_mod=mask_mod,
  555. )
  556. def as_tuple(self, flatten: bool = True):
  557. """
  558. Returns a tuple of the attributes of the BlockMask.
  559. Args:
  560. flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE)
  561. """
  562. if flatten:
  563. block_size = (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) # type: ignore[assignment]
  564. seq_lengths = (self.seq_lengths[0], self.seq_lengths[1]) # type: ignore[assignment]
  565. else:
  566. block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
  567. seq_lengths = (self.seq_lengths,) # type: ignore[assignment]
  568. # pyrefly: ignore [not-iterable]
  569. return (
  570. *seq_lengths,
  571. self.kv_num_blocks,
  572. self.kv_indices,
  573. self.full_kv_num_blocks,
  574. self.full_kv_indices,
  575. self.q_num_blocks,
  576. self.q_indices,
  577. self.full_q_num_blocks,
  578. self.full_q_indices,
  579. *block_size,
  580. self.mask_mod,
  581. )
  582. @property
  583. def shape(self):
  584. *batch_dims, _, _ = self.kv_indices.shape
  585. return tuple(batch_dims) + self.seq_lengths
  586. def __str__(self) -> str:
  587. s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
  588. mask_str = self.to_string().strip()
  589. s += mask_str
  590. s += "\n)"
  591. return s
  592. def __getitem__(self, index) -> "BlockMask":
  593. """
  594. Returns a new BlockMask instance by getting the mask for the given index position.
  595. Args:
  596. index: Index to apply to all attributes.
  597. Example Usage:
  598. .. code-block:: python
  599. def causal_mask(b, h, q_idx, kv_idx):
  600. return q_idx >= kv_idx
  601. block_mask = create_block_mask(
  602. causal_mask, 4, 2, 512, 512, device="cuda"
  603. )
  604. assert block_mask.kv_num_blocks.shape == (4, 2, 4)
  605. assert block_mask.kv_indices.shape == (4, 2, 4, 4)
  606. # Index on batch dimension
  607. new_block_mask = block_mask[0]
  608. assert new_block_mask.kv_num_blocks.shape == (2, 4)
  609. assert new_block_mask.kv_indices.shape == (2, 4, 4)
  610. # Index on batch and head dimension
  611. new_block_mask = block_mask[0, 1]
  612. assert new_block_mask.kv_num_blocks.shape == (4,)
  613. assert new_block_mask.kv_indices.shape == (4, 4)
  614. # slicing on batch and head dimension
  615. new_block_mask = block_mask[0:2, 1:2]
  616. assert new_block_mask.kv_num_blocks.shape == (2, 1, 4)
  617. assert new_block_mask.kv_indices.shape == (2, 1, 4, 4)
  618. # slicing on batch, head, and query dimension
  619. new_block_mask = block_mask[
  620. 0:2, 1:2, torch.tensor([1], dtype=torch.int32)
  621. ]
  622. assert new_block_mask.kv_num_blocks.shape == (2, 1, 1)
  623. assert new_block_mask.kv_indices.shape == (2, 1, 1, 4)
  624. """
  625. index = (index,) if not isinstance(index, tuple) else index
  626. padded = (*index, slice(None), slice(None), slice(None))[:3]
  627. sizes = self.kv_num_blocks.shape[:3]
  628. index = tuple(
  629. (slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1))
  630. if isinstance(i, int)
  631. else i
  632. for i, n in zip(padded, sizes, strict=True)
  633. )
  634. new_kv_num_blocks = self.kv_num_blocks[index]
  635. new_kv_indices = self.kv_indices[index]
  636. if self.full_kv_num_blocks is not None:
  637. if self.full_kv_indices is None:
  638. raise AssertionError("full_kv_indices must not be None")
  639. new_full_kv_num_blocks = self.full_kv_num_blocks[index]
  640. new_full_kv_indices = self.full_kv_indices[index]
  641. else:
  642. new_full_kv_num_blocks = None
  643. new_full_kv_indices = None
  644. return BlockMask.from_kv_blocks(
  645. new_kv_num_blocks,
  646. new_kv_indices,
  647. new_full_kv_num_blocks,
  648. new_full_kv_indices,
  649. BLOCK_SIZE=self.BLOCK_SIZE,
  650. mask_mod=_sliced_mask_mod_error,
  651. seq_lengths=self.seq_lengths,
  652. compute_q_blocks=self.q_indices is not None,
  653. )
  654. def __repr__(self) -> str:
  655. def shape_or_none(x: torch.Tensor | None):
  656. return x.shape if x is not None else None
  657. return (
  658. f"BlockMask(\n"
  659. f" kv_num_blocks={self.kv_num_blocks.shape},\n"
  660. f" kv_indices={self.kv_indices.shape},\n"
  661. f" full_kv_num_blocks={shape_or_none(self.full_kv_num_blocks)},\n"
  662. f" full_kv_indices={shape_or_none(self.full_kv_indices)},\n"
  663. f" q_num_blocks={shape_or_none(self.q_num_blocks)},\n"
  664. f" q_indices={shape_or_none(self.q_indices)},\n"
  665. f" full_q_num_blocks={shape_or_none(self.full_q_num_blocks)},\n"
  666. f" full_q_indices={shape_or_none(self.full_q_indices)},\n"
  667. f" BLOCK_SIZE={self.BLOCK_SIZE},\n"
  668. f" shape={self.shape},\n"
  669. f" sparsity={self.sparsity():.2f}%,\n"
  670. f" mask_mod={self.mask_mod.__name__ if hasattr(self.mask_mod, '__name__') else self.mask_mod}\n"
  671. f")"
  672. )
  673. def _adjust(self, new_q_len: int, new_kv_len: int):
  674. new_num_rows = (new_q_len + self.BLOCK_SIZE[0] - 1) // self.BLOCK_SIZE[0]
  675. new_num_cols = (new_kv_len + self.BLOCK_SIZE[1] - 1) // self.BLOCK_SIZE[1]
  676. new_kv_num_blocks, new_kv_indices = _adjust_num_blocks_and_indices(
  677. self.kv_num_blocks, self.kv_indices, new_num_rows, new_num_cols
  678. )
  679. if self.full_kv_num_blocks is not None:
  680. if self.full_kv_indices is None:
  681. raise AssertionError("full_kv_indices must not be None")
  682. (
  683. new_full_kv_num_blocks,
  684. new_full_kv_indices,
  685. ) = _adjust_num_blocks_and_indices(
  686. self.full_kv_num_blocks,
  687. self.full_kv_indices,
  688. new_num_rows,
  689. new_num_cols,
  690. )
  691. else:
  692. new_full_kv_num_blocks = None
  693. new_full_kv_indices = None
  694. return self.from_kv_blocks(
  695. new_kv_num_blocks,
  696. new_kv_indices,
  697. new_full_kv_num_blocks,
  698. new_full_kv_indices,
  699. self.BLOCK_SIZE,
  700. self.mask_mod,
  701. )
  702. def numel(self):
  703. """Returns the number of elements (not accounting for sparsity) in the mask."""
  704. shape = self.shape
  705. def _prod(xs):
  706. return functools.reduce(operator.mul, xs, 1)
  707. return _prod(shape)
  708. def sparsity(self) -> float:
  709. """Computes the percentage of blocks that are sparse (i.e. not computed)"""
  710. total_size = self.numel()
  711. computed_blocks = self.kv_num_blocks.sum()
  712. if self.full_kv_num_blocks is not None:
  713. computed_blocks += self.full_kv_num_blocks.sum()
  714. computed_size = computed_blocks.item() * self.BLOCK_SIZE[0] * self.BLOCK_SIZE[1]
  715. dense_ratio = computed_size / total_size
  716. return 100 * (1 - dense_ratio)
  717. def to_dense(self) -> Tensor:
  718. """Returns a dense block that is equivalent to the block mask."""
  719. partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
  720. if self.full_kv_num_blocks is not None:
  721. if self.full_kv_indices is None:
  722. raise AssertionError("full_kv_indices must not be None")
  723. # pyrefly: ignore [bad-return]
  724. return partial_dense | _ordered_to_dense(
  725. self.full_kv_num_blocks, self.full_kv_indices
  726. )
  727. return partial_dense
  728. def to_string(self, grid_size=(20, 20), limit=4):
  729. """Returns a string representation of the block mask. Quite nifty.
  730. If grid_size is -1, prints out an uncompressed version. Warning, it can be quite big!
  731. """
  732. dense_mask = self.to_dense()
  733. *batch_dims, num_rows, num_cols = dense_mask.shape
  734. if isinstance(grid_size, int):
  735. max_rows = grid_size
  736. max_cols = grid_size
  737. elif grid_size == -1:
  738. max_rows = num_rows
  739. max_cols = num_cols
  740. else:
  741. max_rows, max_cols = grid_size
  742. def create_block_vis(*batch_idx):
  743. descriptors = []
  744. descriptors.append(f"{batch_idx}")
  745. vis = ", ".join(reversed(descriptors)) + "\n"
  746. def summarize_section(section) -> str:
  747. percentage = section.float().mean().item()
  748. if percentage == 1:
  749. return "█"
  750. elif percentage == 0:
  751. return " "
  752. else:
  753. return "░"
  754. def cdiv(a, b):
  755. return (a + (b - 1)) // b
  756. row_step = max(1, cdiv(num_rows, max_rows))
  757. col_step = max(1, cdiv(num_cols, max_cols))
  758. for r in range(0, num_rows, row_step):
  759. for c in range(0, num_cols, col_step):
  760. cur_mask = dense_mask
  761. for idx in batch_idx:
  762. cur_mask = cur_mask[idx]
  763. char = summarize_section(
  764. cur_mask[r : r + row_step, c : c + col_step]
  765. )
  766. vis += char * 2
  767. vis += "\n"
  768. return vis
  769. total_vis = []
  770. for idx, batch_idx in enumerate(
  771. itertools.product(*[range(i) for i in batch_dims])
  772. ):
  773. if idx == limit:
  774. total_vis.append("...")
  775. total_vis.append("To print out more, set BlockMask.to_string(limit=N)")
  776. total_vis.append(
  777. "You can also index (BlockMask[batch, head]) to choose a specific batch or head"
  778. )
  779. break
  780. block_vis = create_block_vis(*batch_idx)
  781. total_vis.append(block_vis)
  782. return "\n".join(total_vis)
  783. def to(self, device: torch.device | str) -> "BlockMask":
  784. """Moves the BlockMask to the specified device.
  785. Args:
  786. device (torch.device or str): The target device to move the BlockMask to.
  787. Can be a torch.device object or a string (e.g., 'cpu', 'cuda:0').
  788. Returns:
  789. BlockMask: A new BlockMask instance with all tensor components moved
  790. to the specified device.
  791. Note:
  792. This method does not modify the original BlockMask in-place.
  793. Instead, it returns a new BlockMask instance where individual tensor attributes
  794. may or may not be moved to the specified device, depending on their
  795. current device placement.
  796. """
  797. mapped_attributes = tree_map_only(
  798. torch.Tensor,
  799. lambda x: x.to(device),
  800. self.as_tuple(flatten=False),
  801. )
  802. return BlockMask(*mapped_attributes)
  803. @staticmethod
  804. def _wrap_context_value(attr: str, value: Any) -> Any:
  805. if attr == "mask_mod":
  806. return _MaskModWrapper(value)
  807. return value
  808. @staticmethod
  809. def _unwrap_context_value(attr: str, value: Any) -> Any:
  810. if attr == "mask_mod":
  811. if not isinstance(value, _MaskModWrapper):
  812. raise AssertionError(f"Expected _MaskModWrapper, got {type(value)}")
  813. return value.fn
  814. return value
  815. def _flatten(self):
  816. """Flatten BlockMask into a list of tensors and context.
  817. Wraps mask_mod in _MaskModWrapper for value-based comparison in TreeSpec.
  818. """
  819. tensors = tuple(getattr(self, attr) for attr in self._TENSOR_ATTRS)
  820. context = tuple(
  821. self._wrap_context_value(attr, getattr(self, attr))
  822. for attr in self._CONTEXT_ATTRS
  823. )
  824. return tensors, context
  825. @classmethod
  826. def _unflatten(cls, tensors, context):
  827. """Unflatten tensors and context back into a BlockMask."""
  828. kwargs = {
  829. attr: cls._unwrap_context_value(attr, val)
  830. for attr, val in zip(cls._CONTEXT_ATTRS, context)
  831. }
  832. kwargs.update(zip(cls._TENSOR_ATTRS, tensors))
  833. # pyrefly: ignore [bad-argument-type]
  834. return cls(**kwargs)
  835. def _flatten_with_keys(self):
  836. """Flatten BlockMask with keys for better tracing.
  837. Wraps mask_mod in _MaskModWrapper for value-based comparison in TreeSpec.
  838. """
  839. tensors = tuple(
  840. (GetAttrKey(attr), getattr(self, attr)) for attr in self._TENSOR_ATTRS
  841. )
  842. context = tuple(
  843. (GetAttrKey(attr), self._wrap_context_value(attr, getattr(self, attr)))
  844. for attr in self._CONTEXT_ATTRS
  845. )
  846. return tensors, context
  847. def _broadcast_to_dim(x, dim):
  848. while x.dim() < dim:
  849. x = x.unsqueeze(0)
  850. return x
  851. def _round_up_to_multiple(x, multiple):
  852. return (x + multiple - 1) // multiple * multiple
  853. def _convert_mask_to_block_mask(
  854. mask: Tensor,
  855. Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
  856. KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
  857. separate_full_blocks: bool = False,
  858. ) -> tuple[Tensor, Tensor | None]:
  859. if mask.dtype != torch.bool:
  860. raise AssertionError(f"mask.dtype must be torch.bool, got {mask.dtype}")
  861. mask = _broadcast_to_dim(mask, 4)
  862. def padding_needed_for_multiple(x, multiple):
  863. return _round_up_to_multiple(x, multiple) - x
  864. mask = torch.nn.functional.pad(
  865. mask,
  866. (
  867. 0,
  868. padding_needed_for_multiple(mask.shape[-1], KV_BLOCK_SIZE),
  869. 0,
  870. padding_needed_for_multiple(mask.shape[-2], Q_BLOCK_SIZE),
  871. ),
  872. )
  873. B, H, Q, KV = mask.shape
  874. if Q % Q_BLOCK_SIZE != 0:
  875. raise AssertionError(
  876. f"Q ({Q}) must be divisible by Q_BLOCK_SIZE ({Q_BLOCK_SIZE})"
  877. )
  878. if KV % KV_BLOCK_SIZE != 0:
  879. raise AssertionError(
  880. f"KV ({KV}) must be divisible by KV_BLOCK_SIZE ({KV_BLOCK_SIZE})"
  881. )
  882. mask = mask.view(
  883. B, H, Q // Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV // KV_BLOCK_SIZE, KV_BLOCK_SIZE
  884. ) # [B, H, Q//Q_BLOCK_SIZE, Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, KV_BLOCK_SIZE]
  885. mask = mask.permute(
  886. 0, 1, 2, 4, 3, 5
  887. ) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE, Q_BLOCK_SIZE, KV_BLOCK_SIZE]
  888. mask_block_sum = mask.sum(
  889. dim=[-2, -1]
  890. ) # [B, H, Q//Q_BLOCK_SIZE, KV//KV_BLOCK_SIZE]
  891. if separate_full_blocks:
  892. full_block_sum = Q_BLOCK_SIZE * KV_BLOCK_SIZE
  893. full_blocks = mask_block_sum == full_block_sum
  894. partial_blocks = (mask_block_sum > 0) & (mask_block_sum < full_block_sum)
  895. partial_blocks = partial_blocks.to(dtype=torch.int8)
  896. full_blocks = full_blocks.to(dtype=torch.int8)
  897. return partial_blocks, full_blocks
  898. else:
  899. partial_blocks = mask_block_sum > 0
  900. partial_blocks = partial_blocks.to(dtype=torch.int8)
  901. return partial_blocks, None
  902. def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
  903. """Returns a mask_mod that's the union of provided mask_mods"""
  904. if not all(callable(arg) for arg in mask_mods):
  905. raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")
  906. def or_mask(b, h, q_idx, kv_idx):
  907. result = b.new_zeros((), dtype=torch.bool)
  908. for mask in mask_mods:
  909. result = result | mask(b, h, q_idx, kv_idx)
  910. return result
  911. return or_mask
  912. def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
  913. """Returns a mask_mod that's the intersection of provided mask_mods"""
  914. if not all(callable(arg) for arg in mask_mods):
  915. raise RuntimeError(f"All inputs should be callable mask_mods: {mask_mods}")
  916. def and_mask(b, h, q_idx, kv_idx):
  917. result = b.new_ones((), dtype=torch.bool)
  918. for mask in mask_mods:
  919. result = result & mask(b, h, q_idx, kv_idx)
  920. return result
  921. return and_mask
  922. def _convert_block_mask_to_mask(
  923. block_mask,
  924. KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
  925. Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
  926. ) -> Tensor:
  927. if block_mask.dim() != 4:
  928. raise AssertionError(f"block_mask.dim() must be 4, got {block_mask.dim()}")
  929. B, H, Q, KV = block_mask.shape
  930. block_mask = block_mask.expand(Q_BLOCK_SIZE, KV_BLOCK_SIZE, *block_mask.shape)
  931. block_mask = block_mask.permute(2, 3, 4, 0, 5, 1).reshape(
  932. B, H, Q * Q_BLOCK_SIZE, KV * KV_BLOCK_SIZE
  933. )
  934. return block_mask
  935. def _create_sparse_block_from_block_mask(
  936. block_mask: tuple[Tensor, Tensor | None],
  937. mask_mod: Callable | None,
  938. seq_lengths: tuple[int, int],
  939. Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
  940. KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
  941. ) -> BlockMask:
  942. partial_blocks, full_blocks = block_mask
  943. partial_bm = _dense_to_ordered(partial_blocks)
  944. if full_blocks is not None:
  945. full_bm: tuple[Tensor | None, Tensor | None] = _dense_to_ordered(full_blocks)
  946. else:
  947. full_bm = (None, None)
  948. return BlockMask.from_kv_blocks(
  949. partial_bm[0],
  950. partial_bm[1],
  951. full_bm[0],
  952. full_bm[1],
  953. BLOCK_SIZE=(Q_BLOCK_SIZE, KV_BLOCK_SIZE),
  954. mask_mod=mask_mod,
  955. seq_lengths=seq_lengths,
  956. )
  957. def create_mask(
  958. mod_fn: _score_mod_signature | _mask_mod_signature,
  959. B: int | None,
  960. H: int | None,
  961. Q_LEN: int,
  962. KV_LEN: int,
  963. device: DeviceLikeType | None = None,
  964. ) -> Tensor:
  965. r"""This function creates a mask tensor from a mod_fn function.
  966. Args:
  967. mod_fn (Union[_score_mod_signature, _mask_mod_signature]): Function to modify attention scores.
  968. B (int): Batch size.
  969. H (int): Number of query heads.
  970. Q_LEN (int): Sequence length of query.
  971. KV_LEN (int): Sequence length of key/value.
  972. device (str): Device to run the mask creation on.
  973. Returns:
  974. mask (Tensor): A mask tensor with shape (B, H, M, N).
  975. """
  976. if device is None:
  977. device = torch.accelerator.current_accelerator() or "cpu"
  978. if B is None:
  979. B = 1
  980. if H is None:
  981. H = 1
  982. b = torch.arange(0, B, device=device)
  983. h = torch.arange(0, H, device=device)
  984. m = torch.arange(0, Q_LEN, device=device)
  985. n = torch.arange(0, KV_LEN, device=device)
  986. mod_type = _get_mod_type(mod_fn)
  987. from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
  988. with TransformGetItemToIndex():
  989. if mod_type == _ModificationType.SCORE_MOD:
  990. score_mod = mod_fn
  991. score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,)) # first input is score
  992. out = score_mod(torch.zeros(B, H, Q_LEN, KV_LEN, device=device), b, h, m, n)
  993. mask = torch.where(torch.isneginf(out), False, True)
  994. return mask
  995. elif mod_type == _ModificationType.MASK_MOD:
  996. mask_mod = mod_fn
  997. mask_mod = _vmap_for_bhqkv(mask_mod, prefix=())
  998. mask = mask_mod(b, h, m, n)
  999. return mask
  1000. else:
  1001. raise AssertionError
  1002. def create_block_mask(
  1003. mask_mod: _mask_mod_signature,
  1004. B: int | None,
  1005. H: int | None,
  1006. Q_LEN: int,
  1007. KV_LEN: int,
  1008. device: DeviceLikeType | None = None,
  1009. BLOCK_SIZE: int | tuple[int, int] = _DEFAULT_SPARSE_BLOCK_SIZE,
  1010. _compile=False,
  1011. ) -> BlockMask:
  1012. r"""This function creates a block mask tuple from a mask_mod function.
  1013. Args:
  1014. mask_mod (Callable): mask_mod function. This is a callable that defines the
  1015. masking pattern for the attention mechanism. It takes four arguments:
  1016. b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index).
  1017. It should return a boolean tensor indicating which attention connections are allowed (True)
  1018. or masked out (False).
  1019. B (int): Batch size.
  1020. H (int): Number of query heads.
  1021. Q_LEN (int): Sequence length of query.
  1022. KV_LEN (int): Sequence length of key/value.
  1023. device (str): Device to run the mask creation on.
  1024. BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is provided it is used for both query and key/value.
  1025. Returns:
  1026. BlockMask: A BlockMask object that contains the block mask information.
  1027. Example Usage:
  1028. .. code-block:: python
  1029. def causal_mask(b, h, q_idx, kv_idx):
  1030. return q_idx >= kv_idx
  1031. block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda")
  1032. query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
  1033. key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
  1034. value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
  1035. output = flex_attention(query, key, value, block_mask=block_mask)
  1036. """
  1037. if device is None:
  1038. device = torch.accelerator.current_accelerator() or "cpu"
  1039. mod_type = _get_mod_type(mask_mod)
  1040. if mod_type != _ModificationType.MASK_MOD:
  1041. raise AssertionError(
  1042. f"create-block_mask requires a mask_mod function! Got {mask_mod}"
  1043. )
  1044. if B is None:
  1045. B = 1
  1046. if H is None:
  1047. H = 1
  1048. if isinstance(BLOCK_SIZE, int):
  1049. Q_BLOCK_SIZE = BLOCK_SIZE
  1050. KV_BLOCK_SIZE = BLOCK_SIZE
  1051. else:
  1052. Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE
  1053. if _compile:
  1054. warnings.warn(
  1055. "_compile flag on create_block_mask was originally added to work around a torch.compile limitation. That limitation has since been addressed. So, to compile create_block_mask, we suggest doing torch.compile(create_block_mask). This still works for now, but will be removed in the future.",
  1056. DeprecationWarning,
  1057. stacklevel=2,
  1058. )
  1059. return torch.compile(create_block_mask)(
  1060. mask_mod, B, H, Q_LEN, KV_LEN, device, BLOCK_SIZE
  1061. )
  1062. mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device)
  1063. partial_block_mask, full_block_mask = _convert_mask_to_block_mask(
  1064. mask_tensor,
  1065. Q_BLOCK_SIZE=Q_BLOCK_SIZE,
  1066. KV_BLOCK_SIZE=KV_BLOCK_SIZE,
  1067. separate_full_blocks=True,
  1068. )
  1069. block_mask = _create_sparse_block_from_block_mask(
  1070. (partial_block_mask, full_block_mask),
  1071. mask_mod,
  1072. (Q_LEN, KV_LEN),
  1073. Q_BLOCK_SIZE,
  1074. KV_BLOCK_SIZE,
  1075. )
  1076. return block_mask
  1077. def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
  1078. r"""Default block mask for flex attention.
  1079. If users don't specify any block sparse mask info, we create this
  1080. empty block sparse mask. Which creates a BlockMask with 1 block that is the full length
  1081. of the query and key tensors.
  1082. """
  1083. device = query.device
  1084. return BlockMask.from_kv_blocks(
  1085. kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
  1086. kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
  1087. BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE,
  1088. seq_lengths=(1, 1),
  1089. )
  1090. def _apply_kernel_options(
  1091. query: Tensor,
  1092. key: Tensor,
  1093. value: Tensor,
  1094. return_lse: bool,
  1095. kernel_options,
  1096. return_aux: AuxRequest | None = None,
  1097. ):
  1098. kernel_options = {} if kernel_options is None else dict(kernel_options)
  1099. if "BACKEND" in kernel_options and kernel_options.get(
  1100. "FORCE_USE_FLEX_ATTENTION", False
  1101. ):
  1102. # TODO: remove FORCE_USE_FLEX_ATTENTION once BACKEND is fully adopted.
  1103. raise RuntimeError(
  1104. "BACKEND cannot be combined with legacy FORCE_USE_FLEX_ATTENTION. "
  1105. "BACKEND supersedes the legacy knob; please drop FORCE_USE_FLEX_ATTENTION "
  1106. "and only specify the desired BACKEND."
  1107. )
  1108. if "BACKEND" in kernel_options:
  1109. valid_backends = typing.get_args(_Backend)
  1110. if kernel_options["BACKEND"] not in valid_backends:
  1111. raise ValueError(
  1112. f"Invalid BACKEND value '{kernel_options['BACKEND']}'. "
  1113. f"Must be one of {valid_backends}"
  1114. )
  1115. kernel_options.setdefault("BACKEND", "AUTO")
  1116. kernel_options.setdefault("PRESCALE_QK", False)
  1117. kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False)
  1118. kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False)
  1119. # This forces all biases grad scatters to be done in the DQ iteration loop of the backwards
  1120. kernel_options.setdefault("WRITE_DQ", True)
  1121. any_inputs_on_cpu_device = (
  1122. query.device.type == "cpu"
  1123. or key.device.type == "cpu"
  1124. or value.device.type == "cpu"
  1125. )
  1126. # Determine what auxiliary outputs are needed
  1127. output_lse = return_lse
  1128. output_max = False
  1129. if return_aux is not None:
  1130. # New API takes precedence over legacy parameters
  1131. output_lse = return_aux.lse
  1132. output_max = return_aux.max_scores
  1133. # If forward kernel needs to return logsumexp is decided by this rule internally.
  1134. if "OUTPUT_LOGSUMEXP" in kernel_options:
  1135. raise AssertionError("OUTPUT_LOGSUMEXP must not be in kernel_options")
  1136. kernel_options["OUTPUT_LOGSUMEXP"] = True
  1137. if not output_lse:
  1138. # We used to check if q,k,v required grads but since captured buffers can require grad
  1139. # we always write unless in no_grad
  1140. kernel_options["OUTPUT_LOGSUMEXP"] = torch.is_grad_enabled()
  1141. if any_inputs_on_cpu_device:
  1142. # CPU with torch.compile now supports inference, and will not return lse
  1143. # TODO: support CPU for training and return lse
  1144. kernel_options["OUTPUT_LOGSUMEXP"] = False
  1145. # If forward kernel needs to return max is decided by this rule internally.
  1146. if "OUTPUT_MAX" in kernel_options:
  1147. raise AssertionError("OUTPUT_MAX must not be in kernel_options")
  1148. kernel_options["OUTPUT_MAX"] = output_max
  1149. if any_inputs_on_cpu_device and output_max:
  1150. # CPU doesn't support returning max yet
  1151. # TODO: support CPU for returning max
  1152. raise NotImplementedError("Returning max scores is not supported on CPU.")
  1153. kernel_options["OUTPUT_MAX"] = False
  1154. return kernel_options
  1155. def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor) -> None:
  1156. if query.size(-1) != key.size(-1):
  1157. raise ValueError(
  1158. f"Expect query and key/value to have the same embedding dimension "
  1159. f"but got E={query.size(-1)} and E={key.size(-1)}."
  1160. )
  1161. def _validate_device(query: Tensor, key: Tensor, value: Tensor) -> None:
  1162. """TODO: Remove once non cuda/cpu devices support is added
  1163. We only need to check query since we have already that q,k,v are on the same device
  1164. """
  1165. if query.device.type == "cpu" and (
  1166. query.requires_grad or key.requires_grad or value.requires_grad
  1167. ):
  1168. raise NotImplementedError(
  1169. "FlexAttention does not support backward on CPU. Please set the input requires_grad to False or use another device."
  1170. )
  1171. supported_devices = {"cuda", "cpu", "xpu", "hpu"}
  1172. if query.device.type not in supported_devices:
  1173. raise ValueError(
  1174. "FlexAttention is only supported on CUDA, CPU or HPU devices. "
  1175. f"Found input tensors on {query.device.type} device."
  1176. )
  1177. def _enforce_mem_layouts(
  1178. query: Tensor, key: Tensor, value: Tensor
  1179. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  1180. """
  1181. Enforce memory layouts for query, key, and value tensors.
  1182. For non-FP8 dtypes, no action is taken.
  1183. For FP8 dtypes, we enforce the following memory layouts:
  1184. - Query tensor must be in row-major memory layout, as it will be the left-operand in the FP8 GEMM `q @ k.T`.
  1185. - Key tensor must be in row-major memory layout, as it will be transposed when used as the right-operand
  1186. in the FP8 GEMM `q @ k.T`, meaning it will correctly be in column-major memory layout for the GEMM.
  1187. - Value tensor must be in column-major memory layout, as it will be the right-operand in the FP8 GEMM `softmax_scores @ v`.
  1188. Returns the query, key, and value tensors with the enforced memory layouts.
  1189. """
  1190. def is_row_major(tensor: Tensor) -> bool:
  1191. return tensor.stride()[-1] == 1
  1192. def is_col_major(tensor: Tensor) -> bool:
  1193. return tensor.stride()[-2] == 1
  1194. # These memory layout constraint are only for FP8 GEMMs on NVIDIA GPU architectures >= SM89 and < SM100.
  1195. # This is because GPU arch < SM89 does not not support FP8 GEMMs, and
  1196. # SM100 has support for TN, NT, TT, NN layouts for FP8 GEMMs
  1197. # (i.e., left and right operands can be in row or column major layouts)
  1198. # so this check is only needed for older architectures.
  1199. # See: https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md
  1200. fp8_dtypes = (
  1201. torch.float8_e4m3fn,
  1202. torch.float8_e5m2,
  1203. )
  1204. gemm_precision = query.dtype
  1205. should_enforce_mem_layout = (
  1206. gemm_precision in fp8_dtypes
  1207. and torch.version.cuda is not None
  1208. and torch.cuda.get_device_capability("cuda") >= (8, 9)
  1209. and torch.cuda.get_device_capability("cuda") < (10, 0)
  1210. )
  1211. if not should_enforce_mem_layout:
  1212. return query, key, value
  1213. # Query must be in row-major memory layout as the left-operand in the FP8 GEMM `q @ k.T`
  1214. if not is_row_major(query):
  1215. query = query.contiguous()
  1216. # Key must be in row-major memory layout as it will be transposed when used as the right-operand
  1217. # in the FP8 GEMM `q @ k.T`, meaning it will correctly be in column-major memory layout for the GEMM.
  1218. if not is_row_major(key):
  1219. key = key.contiguous()
  1220. # Value must be in column-major memory layout as the right-operand in the FP8 GEMM `softmax_scores @ v`
  1221. if not is_col_major(value):
  1222. value = value.transpose(-2, -1).contiguous().transpose(-2, -1)
  1223. return query, key, value
  1224. def flex_attention(
  1225. query: Tensor,
  1226. key: Tensor,
  1227. value: Tensor,
  1228. score_mod: _score_mod_signature | None = None,
  1229. block_mask: BlockMask | None = None,
  1230. scale: float | None = None,
  1231. enable_gqa: bool = False,
  1232. return_lse: bool = False,
  1233. kernel_options: FlexKernelOptions | None = None,
  1234. *,
  1235. return_aux: AuxRequest | None = None,
  1236. ) -> Tensor | tuple[Tensor, Tensor] | tuple[Tensor, AuxOutput]:
  1237. r"""This function implements scaled dot product attention with an arbitrary attention score modification function
  1238. described in the `Flex Attention <https://arxiv.org/abs/2412.05496>`_ paper. See also the
  1239. `blog post <https://pytorch.org/blog/flexattention/>`_.
  1240. This function computes the scaled dot product attention between query, key, and value tensors with a user-defined
  1241. attention score modification function. The attention score modification function will be applied after the attention
  1242. scores have been calculated between the query and key tensors. The attention scores are calculated as follows:
  1243. The ``score_mod`` function should have the following signature:
  1244. .. code-block:: python
  1245. def score_mod(
  1246. score: Tensor,
  1247. batch: Tensor,
  1248. head: Tensor,
  1249. q_idx: Tensor,
  1250. k_idx: Tensor
  1251. ) -> Tensor:
  1252. Where:
  1253. - ``score``: A scalar tensor representing the attention score,
  1254. with the same data type and device as the query, key, and value tensors.
  1255. - ``batch``, ``head``, ``q_idx``, ``k_idx``: Scalar tensors indicating
  1256. the batch index, query head index, query index, and key/value index, respectively.
  1257. These should have the ``torch.int`` data type and be located on the same device as the score tensor.
  1258. Args:
  1259. query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`. For FP8 dtypes, should be in row-major memory layout for optimal performance.
  1260. key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`. For FP8 dtypes, should be in row-major memory layout for optimal performance.
  1261. value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`. For FP8 dtypes, should be in column-major memory layout for optimal performance.
  1262. score_mod (Optional[Callable]): Function to modify attention scores. By default no score_mod is applied.
  1263. block_mask (Optional[BlockMask]): BlockMask object that controls the blocksparsity pattern of the attention.
  1264. scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`.
  1265. enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads.
  1266. return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False. **Deprecated**: Use ``return_aux=AuxRequest(lse=True)`` instead.
  1267. kernel_options (Optional[FlexKernelOptions]):
  1268. Options to control the behavior of the underlying Triton kernels.
  1269. See :class:`FlexKernelOptions` for available options and usage examples.
  1270. return_aux (Optional[AuxRequest]): Specifies which auxiliary outputs to compute and return.
  1271. If None, only the attention output is returned. Use ``AuxRequest(lse=True, max_scores=True)``
  1272. to request both auxiliary outputs.
  1273. Returns:
  1274. output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`.
  1275. When ``return_aux`` is not None:
  1276. aux (AuxOutput): Auxiliary outputs with requested fields populated.
  1277. When ``return_aux`` is None (deprecated paths):
  1278. lse (Tensor): Log-sum-exp of attention scores; shape :math:`(B, Hq, L)`. Only returned if ``return_lse=True``.
  1279. Shape legend:
  1280. - :math:`B: \text{Batch size}`, :math:`Hq: \text{Number of query heads}`, :math:`Hkv: \text{Number of key/value heads}`
  1281. - :math:`S: \text{Source sequence length}`
  1282. - :math:`L: \text{Target sequence length}`
  1283. - :math:`E: \text{Embedding dimension of the query and key}`
  1284. - :math:`Ev: \text{Embedding dimension of the value}`
  1285. .. warning::
  1286. ``torch.nn.attention.flex_attention`` is a prototype feature in PyTorch.
  1287. Please look forward to a more stable implementation in a future version of PyTorch.
  1288. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
  1289. """
  1290. # Some basic input validation
  1291. _validate_sdpa_input(query, key, value, allow_lowp_kv=True)
  1292. _validate_embed_dim(query, key, value)
  1293. _validate_device(query, key, value)
  1294. query, key, value = _enforce_mem_layouts(query, key, value)
  1295. if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
  1296. raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
  1297. if (not enable_gqa) and query.size(-3) != key.size(-3):
  1298. raise ValueError(
  1299. f"Expect query and key/value to have the same number of heads "
  1300. f"but got Hq={query.size(-3)} and Hkv={key.size(-3)}. "
  1301. f"Try setting enable_gqa=True for GQA."
  1302. )
  1303. if enable_gqa:
  1304. Hq = query.size(1)
  1305. Hkv = key.size(1)
  1306. if Hq % Hkv != 0:
  1307. raise ValueError(
  1308. f"Expect number of query heads to be a multiple of kv heads for GQA "
  1309. f"but got Hq={Hq} and Hkv={Hkv}."
  1310. )
  1311. if query.size(0) != key.size(0):
  1312. if block_mask is None:
  1313. raise ValueError(
  1314. f"Expect query and key/value to have the same batch size, "
  1315. f"or non-none block_mask, "
  1316. f"but got block_mask=None, Bq={query.size(0)}, and Bkv={key.size(0)}."
  1317. )
  1318. if block_mask.kv_num_blocks.size(0) != query.size(0):
  1319. raise ValueError(
  1320. f"Expect query and key/value to have the same batch size, "
  1321. f"or block_mask and query to have the same batch size, "
  1322. f"but got Bq={query.size(0)}, Bkv={key.size(0)}, B_block_mask={block_mask.kv_num_blocks.size(0)}."
  1323. )
  1324. if score_mod is None:
  1325. score_mod = _identity
  1326. if block_mask is None:
  1327. block_mask = _create_empty_block_mask(query, key)
  1328. # If BlockMask was sliced, its mask_mod is intentionally replaced with an error-raising stub.
  1329. # This guard ensures we surface the intended error message before any shape-based checks.
  1330. if getattr(block_mask, "mask_mod", None) is _sliced_mask_mod_error:
  1331. raise RuntimeError("Cannot use mask_mod from a sliced BlockMask")
  1332. if (
  1333. block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE
  1334. and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE
  1335. ):
  1336. # This corresponds to the case where we essentially have a "no-op" block mask.
  1337. pass
  1338. else:
  1339. block_mask_q_len = block_mask.shape[-2]
  1340. block_mask_kv_len = block_mask.shape[-1]
  1341. if query.size(-2) > block_mask_q_len or key.size(-2) > block_mask_kv_len:
  1342. raise ValueError(
  1343. f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
  1344. "As the block mask was created for a smaller length than you're using it for, you likely need to create a new block mask."
  1345. )
  1346. elif (
  1347. query.size(-2) < block_mask_q_len and key.size(-2) <= block_mask_kv_len
  1348. ) or (query.size(-2) <= block_mask_q_len and key.size(-2) < block_mask_kv_len):
  1349. raise ValueError(
  1350. f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
  1351. "As the block mask was created for a larger length than you're using it for, you can either 1. create a new block mask with the correct length, or 2. 'adjust' the existing block mask to the correct length by calling block_mask._adjust(q_len, kv_len). This essentially 'crops' the block mask to the upper left corner, which does not work for all mask_mods!"
  1352. )
  1353. if query.size(-2) != block_mask_q_len:
  1354. raise AssertionError(
  1355. f"query.size(-2) ({query.size(-2)}) != block_mask_q_len ({block_mask_q_len})"
  1356. )
  1357. if key.size(-2) != block_mask_kv_len:
  1358. raise AssertionError(
  1359. f"key.size(-2) ({key.size(-2)}) != block_mask_kv_len ({block_mask_kv_len})"
  1360. )
  1361. if scale is None:
  1362. scale = 1.0 / math.sqrt(query.size(-1))
  1363. if query.device != block_mask.kv_num_blocks.device: # type: ignore[union-attr]
  1364. raise RuntimeError(
  1365. f"Expect q/k/v and block_mask to be on the same device "
  1366. f"but got {query.device} and {block_mask.kv_num_blocks.device}." # type: ignore[union-attr]
  1367. )
  1368. # Handle deprecation warnings for old parameters
  1369. if return_lse and return_aux is not None:
  1370. raise ValueError(
  1371. "Cannot specify both return_lse and return_aux. "
  1372. "return_lse is deprecated, please use return_aux=AuxRequest(lse=True) instead."
  1373. )
  1374. elif return_lse and return_aux is None:
  1375. _warn_once(
  1376. "deprecated_return_lse",
  1377. "return_lse is deprecated and will be removed in v2.10. "
  1378. "Please use return_aux=AuxRequest(lse=True) instead.",
  1379. category=FutureWarning,
  1380. )
  1381. kernel_options = _apply_kernel_options(
  1382. query,
  1383. key,
  1384. value,
  1385. return_lse,
  1386. kernel_options,
  1387. return_aux,
  1388. )
  def _finalize_outputs(
      out,
      lse,
      max_scores,
      *,
      return_aux: AuxRequest | None,
      return_lse: bool,
  ):
      """Rescale the kernel statistics and assemble the caller-facing result.

      Each requested statistic (log-sum-exp and/or max scores) is multiplied
      by ln(2), i.e. converted from base-2 to natural-log units, assuming the
      kernel reports them in base 2 (TODO confirm against the kernel). A
      statistic that was not requested, or whose tensor is empty, is reported
      as ``None``.

      Args:
          out: Attention output; returned unchanged.
          lse: Log-sum-exp tensor returned alongside ``out`` by the HOP call.
          max_scores: Max-score tensor returned alongside ``out`` by the HOP call.
          return_aux: When not None, its ``lse`` / ``max_scores`` flags select
              which statistics are placed in the returned ``AuxOutput``.
          return_lse: Legacy (deprecated) flag requesting the log-sum-exp.

      Returns:
          ``(out, AuxOutput(...))`` when ``return_aux`` is not None; otherwise
          ``(out, lse)`` when ``return_lse`` is truthy; otherwise ``out`` alone.
      """
      log_two = math.log(2.0)
      # Either the legacy flag or the aux request may ask for the LSE.
      want_lse = return_lse or (return_aux is not None and return_aux.lse)
      want_max = return_aux is not None and return_aux.max_scores

      def _to_natural_log(stat, wanted):
          # Unrequested stats and empty tensors (presumably: stat not produced)
          # are reported as None; numel() is only queried when requested.
          if wanted and stat.numel() > 0:
              return stat * log_two
          return None

      lse_scaled = _to_natural_log(lse, want_lse)
      max_scaled = _to_natural_log(max_scores, want_max)

      if return_aux is not None:
          return out, AuxOutput(lse=lse_scaled, max_scores=max_scaled)
      # Legacy path: with no aux request, want_lse has the truthiness of return_lse.
      return (out, lse_scaled) if want_lse else out
  1413. if torch.compiler.is_dynamo_compiling():
  1414. # mark head_dim and number of heads to be static
  1415. for x in [query, key, value]:
  1416. torch._dynamo.mark_static(x, -3)
  1417. torch._dynamo.mark_static(x, -1)
  1418. out, lse, max_scores = flex_attention_hop(
  1419. query,
  1420. key,
  1421. value,
  1422. score_mod,
  1423. block_mask.as_tuple(),
  1424. scale,
  1425. kernel_options, # type: ignore[union-attr]
  1426. )
  1427. return _finalize_outputs(
  1428. out, lse, max_scores, return_aux=return_aux, return_lse=return_lse
  1429. )
  1430. if not _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG:
  1431. _warn_once(
  1432. warning_id="flex_attention_performance",
  1433. message=(
  1434. "flex_attention called without torch.compile() - this will use an unfused implementation that materializes the full scores matrix instead of generating a fused kernel.\n\n"
  1435. "SOLUTION: Use torch.compile(flex_attention)(...)\n\n"
  1436. "If you want to debug your score_mod/mask_mod, you can set:\n"
  1437. "torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True\n\n"
  1438. "This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results."
  1439. ),
  1440. )
  1441. if not torch._dynamo.is_dynamo_supported():
  1442. raise RuntimeError("flex_attention requires dynamo support")
  1443. # Dynamo is expecting a callable with "__code__" attribute.
  1444. # We cannot directly pass hop to it. So we wrap it in a dummy function.
  def _flex_attention_hop_wrapper(*args, **kwargs):
      """Forward every positional and keyword argument to ``flex_attention_hop``.

      Dynamo expects a callable that has a ``__code__`` attribute, and
      ``flex_attention_hop`` cannot be handed to it directly, so this plain
      Python function stands in for it and returns its result unchanged.
      """
      hop_outputs = flex_attention_hop(*args, **kwargs)
      return hop_outputs
  1447. with setup_compilation_env() as backend:
  1448. if _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG:
  1449. flex_fn = _flex_attention_hop_wrapper
  1450. else:
  1451. flex_fn = torch.compile(
  1452. _flex_attention_hop_wrapper, backend=backend, fullgraph=True
  1453. )
  1454. out, lse, max_scores = flex_fn(
  1455. query,
  1456. key,
  1457. value,
  1458. score_mod,
  1459. block_mask.as_tuple(), # type: ignore[union-attr]
  1460. scale,
  1461. kernel_options,
  1462. )
  1463. return _finalize_outputs(
  1464. out, lse, max_scores, return_aux=return_aux, return_lse=return_lse
  1465. )