diff_attention.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
"""Differential Attention

Paper: 'Differential Transformer' - https://arxiv.org/abs/2410.05258

Reference impl: https://github.com/microsoft/unilm/tree/master/Diff-Transformer

Hacked together by / Copyright 2024, Ross Wightman
"""
  6. import math
  7. from typing import Optional, Type
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from .attention import maybe_add_mask, resolve_self_attn_mask
  12. from .config import use_fused_attn
  13. from .norm import RmsNorm
  14. class DiffAttention(nn.Module):
  15. """Differential Attention module.
  16. Computes attention as the difference between two softmax attention maps, which helps
  17. cancel out noise and promotes sparse attention patterns. The module splits Q and K
  18. into two groups, computes separate attention maps, and subtracts one from the other
  19. scaled by a learnable lambda parameter.
  20. The attention output is computed as:
  21. Attn = softmax(Q1 @ K1^T) - lambda * softmax(Q2 @ K2^T)
  22. Output = Attn @ V
  23. Supports both fused (scaled_dot_product_attention) and manual implementations.
  24. """
  25. fused_attn: torch.jit.Final[bool]
  26. def __init__(
  27. self,
  28. dim: int,
  29. num_heads: int = 8,
  30. qkv_bias: bool = False,
  31. qk_norm: bool = False,
  32. scale_norm: bool = False,
  33. proj_bias: bool = True,
  34. attn_drop: float = 0.,
  35. proj_drop: float = 0.,
  36. norm_layer: Optional[Type[nn.Module]] = None,
  37. depth: int = 0,
  38. dual_lambda: bool = False,
  39. device=None,
  40. dtype=None,
  41. ) -> None:
  42. """Initialize the DiffAttention module.
  43. Args:
  44. dim: Input dimension of the token embeddings.
  45. num_heads: Number of attention heads.
  46. qkv_bias: Whether to use bias in the query, key, value projections.
  47. qk_norm: Whether to apply normalization to query and key vectors.
  48. scale_norm: Whether to apply normalization before the output projection.
  49. proj_bias: Whether to use bias in the output projection.
  50. attn_drop: Dropout rate applied to the attention weights.
  51. proj_drop: Dropout rate applied after the output projection.
  52. norm_layer: Normalization layer constructor (defaults to RmsNorm).
  53. depth: Block depth index, used to compute depth-dependent lambda_init.
  54. dual_lambda: If True, use simplified dual scalar lambda parameterization
  55. (2 params). If False, use the paper's original formulation with
  56. lambda_q/k vectors (4 * head_dim params).
  57. """
  58. super().__init__()
  59. dd = {'device': device, 'dtype': dtype}
  60. assert dim % num_heads == 0, 'dim should be divisible by num_heads'
  61. if norm_layer is None:
  62. norm_layer = RmsNorm
  63. self.num_heads = num_heads
  64. self.head_dim = dim // num_heads // 2
  65. self.scale = self.head_dim ** -0.5
  66. self.fused_attn = use_fused_attn()
  67. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  68. self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  69. self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  70. self.attn_drop = nn.Dropout(attn_drop)
  71. self.attn_drop_p = attn_drop
  72. self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
  73. self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
  74. self.proj_drop = nn.Dropout(proj_drop)
  75. self.dual_lambda = dual_lambda
  76. if dual_lambda:
  77. self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
  78. self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
  79. self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None
  80. else:
  81. self.lambda_a = self.lambda_b = None
  82. self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
  83. self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
  84. self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
  85. self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
  86. self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd)
  87. self.lambda_init = 0.8
  88. self.set_lambda_init(depth)
  89. self.reset_parameters()
  90. def set_lambda_init(self, depth: int):
  91. self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
  92. def reset_parameters(self):
  93. if self.dual_lambda:
  94. nn.init.zeros_(self.lambda_a)
  95. nn.init.zeros_(self.lambda_b)
  96. else:
  97. nn.init.normal_(self.lambda_q1, mean=0, std=0.1)
  98. nn.init.normal_(self.lambda_k1, mean=0, std=0.1)
  99. nn.init.normal_(self.lambda_q2, mean=0, std=0.1)
  100. nn.init.normal_(self.lambda_k2, mean=0, std=0.1)
  101. def _compute_lambda(self) -> torch.Tensor:
  102. if self.lambda_a is not None:
  103. lambda_1 = torch.exp(self.lambda_a)
  104. lambda_2 = torch.exp(self.lambda_b)
  105. else:
  106. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
  107. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
  108. return lambda_1 - lambda_2 + self.lambda_init
  109. def forward(
  110. self,
  111. x: torch.Tensor,
  112. attn_mask: Optional[torch.Tensor] = None,
  113. is_causal: bool = False,
  114. ) -> torch.Tensor:
  115. B, N, C = x.shape
  116. q, k, v = self.qkv(x).chunk(3, dim=2)
  117. q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
  118. k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
  119. v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
  120. q, k = self.q_norm(q), self.k_norm(k)
  121. lambda_full = self._compute_lambda().type_as(q)
  122. if self.fused_attn:
  123. q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
  124. k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
  125. q1, q2 = q.unbind(2)
  126. k1, k2 = k.unbind(2)
  127. dropout_p = self.attn_drop_p if self.training else 0.0
  128. attn1 = F.scaled_dot_product_attention(
  129. q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
  130. attn2 = F.scaled_dot_product_attention(
  131. q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
  132. x = attn1 - lambda_full * attn2
  133. else:
  134. q = q * self.scale
  135. attn = q @ k.transpose(-2, -1)
  136. attn_bias = resolve_self_attn_mask(N, attn, attn_mask, is_causal=is_causal)
  137. attn = maybe_add_mask(attn, attn_bias)
  138. attn = attn.softmax(dim=-1)
  139. attn = self.attn_drop(attn)
  140. attn = attn.view(B, self.num_heads, 2, N, N)
  141. attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
  142. x = attn @ v
  143. x = self.sub_norm(x)
  144. x = x * (1 - self.lambda_init)
  145. x = x.transpose(1, 2).reshape(B, N, C)
  146. x = self.norm(x)
  147. x = self.proj(x)
  148. x = self.proj_drop(x)
  149. return x