bias.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # mypy: allow-untyped-defs
  2. """Defines bias subclasses that work with scaled_dot_product_attention"""
  3. from enum import auto, IntEnum
  4. from warnings import warn
  5. import torch
  6. import torch.nn.functional as F
  7. from torch.backends.cuda import (
  8. can_use_efficient_attention,
  9. can_use_flash_attention,
  10. is_flash_attention_available,
  11. SDPAParams,
  12. )
  13. from torch.nn.attention import _raise_kernel_warnings
  14. from torch.nn.attention._utils import (
  15. _calculate_scale,
  16. _input_requires_grad,
  17. _postprocess_flash_output,
  18. _validate_sdpa_input,
  19. )
  20. __all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]
  21. torch._dynamo.allow_in_graph(is_flash_attention_available)
  22. torch._dynamo.allow_in_graph(can_use_flash_attention)
  23. torch._dynamo.allow_in_graph(can_use_efficient_attention)
  24. torch._dynamo.allow_in_graph(SDPAParams)
  25. class CausalVariant(IntEnum):
  26. r"""
  27. Enum for causal variants used in attention mechanisms.
  28. Defines two types of causal biases:
  29. ``UPPER_LEFT``: Represents upper-left triangular bias for standard causal attention.
  30. The equivalent pytorch code for constructing this bias is:
  31. .. code-block:: python
  32. torch.tril(torch.ones(size, dtype=torch.bool))
  33. For instance, with ``shape=(3,4)``, the materialized bias tensor will be:
  34. .. code-block:: text
  35. [[1, 0, 0, 0],
  36. [1, 1, 0, 0],
  37. [1, 1, 1, 0]]
  38. ``LOWER_RIGHT``: Represents lower-right triangular bias, the include values are aligned to the lower
  39. right corner of the matrix.
  40. The equivalent pytorch code for constructing this bias is:
  41. .. code-block:: python
  42. diagonal_offset = size[1] - size[0]
  43. torch.tril(
  44. torch.ones(size, dtype=torch.bool),
  45. diagonal=diagonal_offset,
  46. )
  47. For instance, with ``shape=(3,4)``, the materialized bias tensor will be:
  48. .. code-block:: text
  49. [[1, 1, 0, 0],
  50. [1, 1, 1, 0],
  51. [1, 1, 1, 1]]
  52. Note that these variants are equivalent to each other when the sequence lengths of the query and key/value
  53. tensors are equal since the triangular matrix is square.
  54. .. warning:: This enum is a prototype and subject to change.
  55. """
  56. UPPER_LEFT = auto()
  57. LOWER_RIGHT = auto()
  58. class CausalBias(torch.Tensor):
  59. """
  60. A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.
  61. This class is used for defining causal (triangular) attention biases. For construing the bias, there exist
  62. two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.
  63. Example:
  64. .. code-block:: python
  65. from torch.nn.attention.bias import causal_lower_right
  66. bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8
  67. # Create a lower-right causal bias
  68. attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
  69. q = torch.randn(
  70. bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16
  71. )
  72. k = torch.randn(
  73. bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
  74. )
  75. v = torch.randn(
  76. bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
  77. )
  78. out = F.scaled_dot_product_attention(q, k, v, attn_bias)
  79. .. warning:: This class is a prototype and subject to change.
  80. """
  81. def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int) -> None:
  82. """
  83. Initializes the CausalBias instance with a specified variant and sequence lengths.
  84. Args:
  85. variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
  86. seq_len_q (int): The sequence length of the query tensor.
  87. seq_len_kv (int): The sequence length of the key/value tensor.
  88. Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
  89. """
  90. if not isinstance(variant, CausalVariant):
  91. raise AssertionError(
  92. f"variant must be a CausalVariant, got {type(variant).__name__}"
  93. )
  94. super().__init__()
  95. self.variant = variant
  96. self.seq_len_q = seq_len_q
  97. self.seq_len_kv = seq_len_kv
  98. if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
  99. warn(
  100. "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!",
  101. stacklevel=2,
  102. )
  103. def _upper_left(self, device: torch.device) -> torch.Tensor:
  104. """Upper left causal bias"""
  105. return torch.tril(
  106. torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
  107. )
  108. def _lower_right(self, device: torch.device) -> torch.Tensor:
  109. """Lower right causal bias"""
  110. diagonal_offset = self.seq_len_kv - self.seq_len_q
  111. return torch.tril(
  112. torch.ones(
  113. self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
  114. ),
  115. diagonal=diagonal_offset,
  116. )
  117. # pyrefly: ignore [bad-return]
  118. def _materialize(self, device: torch.device | None = None) -> torch.Tensor:
  119. """
  120. Materializes the causal bias into a tensor form.
  121. Depending on the variant, this method generates either an upper-left or lower-right
  122. triangular matrix to represent the causal bias.
  123. Args:
  124. device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.
  125. Returns:
  126. torch.Tensor: The materialized bias tensor.
  127. """
  128. if device is None:
  129. device = torch.device("cpu")
  130. if self.variant == CausalVariant.UPPER_LEFT:
  131. return self._upper_left(device)
  132. elif self.variant == CausalVariant.LOWER_RIGHT:
  133. return self._lower_right(device)
  134. @staticmethod
  135. def _dispatch(
  136. query: torch.Tensor,
  137. key: torch.Tensor,
  138. value: torch.Tensor,
  139. attn_mask: "CausalBias",
  140. dropout_p: float = 0.0,
  141. is_causal: bool = False,
  142. scale: float | None = None,
  143. enable_gqa: bool = False,
  144. ) -> torch.Tensor:
  145. r"""
  146. Handles the logic for computing attention with the specified causal bias.
  147. Args:
  148. query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
  149. key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
  150. value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
  151. attn_mask (CausalBias): The type of causal attention to apply.
  152. A boolean mask where a value of True indicates that the element *should* take part in attention.
  153. A float mask of the same type as query, key, value that is added to the attention score.
  154. dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
  155. is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal
  156. are set.
  157. scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
  158. to :math:`\frac{1}{\sqrt{E}}`.
  159. enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
  160. Returns:
  161. output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
  162. Raises:
  163. ValueError: If the causal bias variant is not a CausalVariant type.
  164. """
  165. if is_causal:
  166. raise ValueError("CausalBias should not be used with causal=True")
  167. if (
  168. attn_mask.seq_len_q == attn_mask.seq_len_kv
  169. or attn_mask.variant == CausalVariant.UPPER_LEFT
  170. ):
  171. return F.scaled_dot_product_attention(
  172. query,
  173. key,
  174. value,
  175. attn_mask=None,
  176. dropout_p=dropout_p,
  177. is_causal=True,
  178. scale=scale,
  179. enable_gqa=enable_gqa,
  180. )
  181. elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
  182. _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
  183. sdpa_params = SDPAParams(
  184. query, key, value, None, dropout_p, is_causal, enable_gqa
  185. )
  186. if can_use_flash_attention(sdpa_params):
  187. alignment = 64 if query.device.type == "xpu" else 8
  188. og_head_size = query.size(-1)
  189. og_scale = _calculate_scale(og_head_size, scale)
  190. needs_padding = og_head_size % alignment != 0
  191. if needs_padding:
  192. pad_len = alignment - (og_head_size % alignment)
  193. query = torch.nn.functional.pad(query, (0, pad_len))
  194. key = torch.nn.functional.pad(key, (0, pad_len))
  195. value = torch.nn.functional.pad(value, (0, pad_len))
  196. out = torch.ops.aten._scaled_dot_product_flash_attention(
  197. query,
  198. key,
  199. value,
  200. dropout_p,
  201. is_causal=True, # TODO: Flash accepts causal = True and for this particular op it means lower right
  202. return_debug_mask=False,
  203. scale=og_scale,
  204. )[0]
  205. return _postprocess_flash_output(out, og_head_size)
  206. if can_use_efficient_attention(sdpa_params):
  207. compute_log_sumexp = False
  208. if _input_requires_grad(query, key, value):
  209. compute_log_sumexp = True
  210. return torch.ops.aten._efficient_attention_forward(
  211. query.transpose(1, 2),
  212. key.transpose(1, 2),
  213. value.transpose(1, 2),
  214. bias=None,
  215. cu_seqlens_q=None,
  216. cu_seqlens_k=None,
  217. max_seqlen_q=None,
  218. max_seqlen_k=None,
  219. dropout_p=dropout_p,
  220. custom_mask_type=int(attn_mask.variant),
  221. compute_log_sumexp=compute_log_sumexp,
  222. scale=scale,
  223. seqlen_k=None,
  224. )[0].transpose(1, 2)
  225. else:
  226. _raise_kernel_warnings(sdpa_params)
  227. # We can't use efficient attention the only support for lower right is via materialization
  228. return F.scaled_dot_product_attention(
  229. query,
  230. key,
  231. value,
  232. attn_mask=attn_mask._materialize(query.device),
  233. dropout_p=dropout_p,
  234. is_causal=False,
  235. scale=scale,
  236. enable_gqa=enable_gqa,
  237. )
  238. else:
  239. raise ValueError(
  240. f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
  241. )
  242. @classmethod
  243. def __torch_function__(cls, func, types, args=(), kwargs=None):
  244. """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias"""
  245. if kwargs is None:
  246. kwargs = {}
  247. if func is torch.nn.functional.scaled_dot_product_attention:
  248. return cls._dispatch(*args, **kwargs)
  249. return super().__torch_function__(func, types, args, kwargs)
  250. def __repr__(self) -> str: # type:ignore[override]
  251. return self._materialize().__repr__()
  252. def causal_upper_left(*size) -> CausalBias:
  253. """
  254. Creates an upper-left triangular causal bias.
  255. This function generates a upper-left triangular matrix to represent causal attention bias with a
  256. diagonal offset set so that the inclusive values are aligned to the upper left corner of the matrix.
  257. This equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`.
  258. The equivalent pytorch code for constructing this bias is:
  259. .. code-block:: python
  260. torch.tril(torch.ones(size, dtype=torch.bool))
  261. For instance, with `shape=(3,4)`, the materialized bias tensor will be:
  262. .. code-block:: text
  263. [[1, 0, 0, 0],
  264. [1, 1, 0, 0],
  265. [1, 1, 1, 0]]
  266. Args:
  267. size: The size of the bias matrix.
  268. Returns:
  269. CausalBias: The UPPER_LEFT triangular causal bias variant.
  270. """
  271. if len(size) != 2:
  272. raise AssertionError("causal_upper_left only supports 2D tensors")
  273. seq_len_q, seq_len_kv = size
  274. return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)
  275. def causal_lower_right(*size) -> CausalBias:
  276. """
  277. Creates a lower-right triangular causal bias.
  278. This function generates a lower-right triangular matrix to represent causal attention bias with a
  279. diagonal offset set so that the inclusive values are aligned to the lower right corner of the matrix.
  280. The equivalent pytorch code for constructing this bias is:
  281. .. code-block:: python
  282. diagonal_offset = size[1] - size[0]
  283. torch.tril(
  284. torch.ones(size, dtype=torch.bool),
  285. diagonal=diagonal_offset,
  286. )
  287. For instance, with `shape=(3,4)`, the materialized bias tensor will be:
  288. .. code-block:: text
  289. [[1, 1, 0, 0],
  290. [1, 1, 1, 0],
  291. [1, 1, 1, 1]]
  292. Args:
  293. size: The size of the bias matrix.
  294. Returns:
  295. CausalBias: The LOWER_RIGHT triangular causal bias variant.
  296. """
  297. if len(size) != 2:
  298. raise AssertionError("causal_lower_right only supports 2D tensors")
  299. seq_len_q, seq_len_kv = size
  300. return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)