attention.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. from typing import Final, Optional, Type
  2. import torch
  3. from torch import nn as nn
  4. from torch.nn import functional as F
  5. from ._fx import register_notrace_function
  6. from .config import use_fused_attn
  7. from .pos_embed_sincos import apply_rot_embed_cat
# Public API of this module: the two self-attention layers plus the two
# mask helpers they share for the non-fused (manual) attention path.
__all__ = ['Attention', 'AttentionRope', 'maybe_add_mask', 'resolve_self_attn_mask']
  9. @torch.fx.wrap
  10. @register_notrace_function
  11. def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
  12. return scores if attn_mask is None else scores + attn_mask
  13. @torch.fx.wrap
  14. @register_notrace_function
  15. def resolve_self_attn_mask(
  16. seq_len: int,
  17. attn: torch.Tensor,
  18. attn_mask: Optional[torch.Tensor] = None,
  19. is_causal: bool = False,
  20. ) -> Optional[torch.Tensor]:
  21. # Build additive bias matching SDPA semantics for self-attention
  22. # is_causal and attn_mask are mutually exclusive (is_causal takes precedence)
  23. if is_causal:
  24. attn_bias = attn.new_full((seq_len, seq_len), float('-inf')).triu_(1)
  25. elif attn_mask is None:
  26. attn_bias = None
  27. elif attn_mask.dtype == torch.bool:
  28. attn_bias = torch.zeros_like(attn_mask, dtype=attn.dtype)
  29. attn_bias.masked_fill_(~attn_mask, float('-inf'))
  30. else:
  31. attn_bias = attn_mask
  32. return attn_bias
  33. class Attention(nn.Module):
  34. """Standard Multi-head Self Attention module with QKV projection.
  35. This module implements the standard multi-head attention mechanism used in transformers.
  36. It supports both the fused attention implementation (scaled_dot_product_attention) for
  37. efficiency when available, and a manual implementation otherwise. The module includes
  38. options for QK normalization, attention dropout, and projection dropout.
  39. """
  40. fused_attn: Final[bool]
  41. def __init__(
  42. self,
  43. dim: int,
  44. num_heads: int = 8,
  45. attn_head_dim: Optional[int] = None,
  46. dim_out: Optional[int] = None,
  47. qkv_bias: bool = False,
  48. qk_norm: bool = False,
  49. scale_norm: bool = False,
  50. proj_bias: bool = True,
  51. attn_drop: float = 0.,
  52. proj_drop: float = 0.,
  53. norm_layer: Optional[Type[nn.Module]] = None,
  54. device=None,
  55. dtype=None,
  56. ) -> None:
  57. """Initialize the Attention module.
  58. Args:
  59. dim: Input dimension of the token embeddings.
  60. num_heads: Number of attention heads.
  61. attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
  62. dim_out: Output dimension. If None, same as dim.
  63. qkv_bias: Whether to use bias in the query, key, value projections.
  64. qk_norm: Whether to apply normalization to query and key vectors.
  65. scale_norm: Whether to apply normalization to attention output before projection.
  66. proj_bias: Whether to use bias in the output projection.
  67. attn_drop: Dropout rate applied to the attention weights.
  68. proj_drop: Dropout rate applied after the output projection.
  69. norm_layer: Normalization layer constructor for QK normalization if enabled.
  70. """
  71. super().__init__()
  72. dd = {'device': device, 'dtype': dtype}
  73. dim_out = dim_out or dim
  74. head_dim = attn_head_dim
  75. if head_dim is None:
  76. assert dim % num_heads == 0, 'dim should be divisible by num_heads'
  77. head_dim = dim // num_heads
  78. if qk_norm or scale_norm:
  79. assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
  80. self.num_heads = num_heads
  81. self.head_dim = head_dim
  82. self.attn_dim = num_heads * head_dim
  83. self.scale = head_dim ** -0.5
  84. self.fused_attn = use_fused_attn()
  85. self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
  86. self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
  87. self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
  88. self.attn_drop = nn.Dropout(attn_drop)
  89. self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
  90. self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
  91. self.proj_drop = nn.Dropout(proj_drop)
  92. def forward(
  93. self,
  94. x: torch.Tensor,
  95. attn_mask: Optional[torch.Tensor] = None,
  96. is_causal: bool = False,
  97. ) -> torch.Tensor:
  98. B, N, C = x.shape
  99. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  100. q, k, v = qkv.unbind(0)
  101. q, k = self.q_norm(q), self.k_norm(k)
  102. if self.fused_attn:
  103. x = F.scaled_dot_product_attention(
  104. q, k, v,
  105. attn_mask=attn_mask,
  106. dropout_p=self.attn_drop.p if self.training else 0.,
  107. is_causal=is_causal,
  108. )
  109. else:
  110. q = q * self.scale
  111. attn = q @ k.transpose(-2, -1)
  112. attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal)
  113. attn = maybe_add_mask(attn, attn_bias)
  114. attn = attn.softmax(dim=-1)
  115. attn = self.attn_drop(attn)
  116. x = attn @ v
  117. x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
  118. x = self.norm(x)
  119. x = self.proj(x)
  120. x = self.proj_drop(x)
  121. return x
  122. class AttentionRope(nn.Module):
  123. """ A Self Attention module with ROPE support.
  124. Includes options for:
  125. * QK normalization option
  126. * Attention output (scale) normalization
  127. * Fused or unfused QKV projection support
  128. """
  129. fused_attn: torch.jit.Final[bool]
  130. def __init__(
  131. self,
  132. dim: int,
  133. num_heads: int = 8,
  134. dim_out: Optional[int] = None,
  135. qkv_bias: bool = True,
  136. qkv_fused: bool = True,
  137. num_prefix_tokens: int = 1,
  138. attn_drop: float = 0.,
  139. proj_drop: float = 0.,
  140. attn_head_dim: Optional[int] = None,
  141. norm_layer: Type[nn.Module] = None,
  142. qk_norm: bool = False,
  143. scale_norm: bool = False,
  144. proj_bias: bool = True,
  145. rotate_half: bool = False,
  146. device=None,
  147. dtype=None,
  148. ):
  149. """Initialize the Attention module.
  150. Args:
  151. dim: Input dimension of the token embeddings
  152. num_heads: Number of attention heads
  153. dim_out: Output dimension. If None, same as dim.
  154. qkv_bias: Whether to add a bias term to the query, key, and value projections
  155. qkv_fused: Whether to use fused QKV projection (single linear) or separate projections
  156. num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
  157. should not have position embeddings applied
  158. attn_drop: Dropout rate for attention weights
  159. proj_drop: Dropout rate for the output projection
  160. attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
  161. norm_layer: Normalization layer constructor to use for QK and scale normalization
  162. qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
  163. scale_norm: Enable normalization (scaling) of attention output with norm_layer
  164. proj_bias: Whether to use bias in the output projection
  165. rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
  166. """
  167. super().__init__()
  168. dd = {'device': device, 'dtype': dtype}
  169. dim_out = dim_out or dim
  170. head_dim = attn_head_dim
  171. if head_dim is None:
  172. assert dim % num_heads == 0, 'dim should be divisible by num_heads'
  173. head_dim = dim // num_heads
  174. if scale_norm or qk_norm:
  175. assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
  176. self.num_heads = num_heads
  177. self.head_dim = head_dim
  178. self.attn_dim = head_dim * num_heads
  179. self.scale = head_dim ** -0.5
  180. self.num_prefix_tokens = num_prefix_tokens
  181. self.fused_attn = use_fused_attn()
  182. self.rotate_half = rotate_half
  183. if qkv_fused:
  184. self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
  185. self.q_proj = self.k_proj = self.v_proj = None
  186. else:
  187. self.qkv = None
  188. self.q_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
  189. self.k_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
  190. self.v_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
  191. self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
  192. self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
  193. self.attn_drop = nn.Dropout(attn_drop)
  194. self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
  195. self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
  196. self.proj_drop = nn.Dropout(proj_drop)
  197. def forward(
  198. self,
  199. x,
  200. rope: Optional[torch.Tensor] = None,
  201. attn_mask: Optional[torch.Tensor] = None,
  202. is_causal: bool = False,
  203. ):
  204. """Forward pass for the attention module.
  205. Args:
  206. x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
  207. rope: Rotary position embeddings tensor for position-aware attention
  208. attn_mask: Optional attention mask to apply during attention computation
  209. is_causal: If True, use causal (autoregressive) masking
  210. Returns:
  211. Tensor of shape (batch_size, sequence_length, dim_out)
  212. """
  213. B, N, C = x.shape
  214. if self.qkv is not None:
  215. qkv = self.qkv(x)
  216. qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  217. q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
  218. else:
  219. q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  220. k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  221. v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  222. q, k = self.q_norm(q), self.k_norm(k)
  223. if rope is not None:
  224. npt = self.num_prefix_tokens
  225. half = getattr(self, 'rotate_half', False)
  226. q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
  227. k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
  228. if self.fused_attn:
  229. x = F.scaled_dot_product_attention(
  230. q, k, v,
  231. attn_mask=attn_mask,
  232. dropout_p=self.attn_drop.p if self.training else 0.,
  233. is_causal=is_causal,
  234. )
  235. else:
  236. q = q * self.scale
  237. attn = q @ k.transpose(-2, -1)
  238. attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal)
  239. attn = maybe_add_mask(attn, attn_bias)
  240. attn = attn.softmax(dim=-1)
  241. attn = self.attn_drop(attn)
  242. x = attn @ v
  243. x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
  244. x = self.norm(x)
  245. x = self.proj(x)
  246. x = self.proj_drop(x)
  247. return x