mha.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from einops import rearrange
  7. try:
  8. from flash_attn import flash_attn_with_kvcache
  9. except ImportError:
  10. flash_attn_with_kvcache = None
  11. try:
  12. from flash_attn.layers.rotary import RotaryEmbedding
  13. except ImportError:
  14. RotaryEmbedding = None
  15. try:
  16. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  17. except ImportError:
  18. causal_conv1d_fn, causal_conv1d_update = None, None
  19. def _update_kv_cache(kv, inference_params, layer_idx):
  20. """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
  21. # Pre-allocate memory for key-values for inference.
  22. num_heads, head_dim = kv.shape[-2:]
  23. assert layer_idx in inference_params.key_value_memory_dict
  24. kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
  25. # Adjust key and value for inference
  26. batch_start = inference_params.batch_size_offset
  27. batch_end = batch_start + kv.shape[0]
  28. sequence_start = inference_params.seqlen_offset
  29. sequence_end = sequence_start + kv.shape[1]
  30. assert batch_end <= kv_cache.shape[0]
  31. assert sequence_end <= kv_cache.shape[1]
  32. assert kv_cache is not None
  33. kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
  34. return kv_cache[batch_start:batch_end, :sequence_end, ...]
  35. class MHA(nn.Module):
  36. """Multi-head self-attention and cross-attention"""
  37. def __init__(
  38. self,
  39. embed_dim,
  40. num_heads,
  41. num_heads_kv=None,
  42. head_dim=None, # If None, use embed_dim // num_heads
  43. mlp_dim=0,
  44. qkv_proj_bias=True,
  45. out_proj_bias=True,
  46. softmax_scale=None,
  47. causal=False,
  48. layer_idx=None,
  49. d_conv=0,
  50. rotary_emb_dim=0,
  51. rotary_emb_base=10000.0,
  52. rotary_emb_interleaved=False,
  53. device=None,
  54. dtype=None,
  55. ) -> None:
  56. """
  57. num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
  58. return_residual: whether to return the input x along with the output. This is for
  59. performance reason: for post-norm architecture, returning the input allows us
  60. to fuse the backward of nn.Linear with the residual connection.
  61. """
  62. factory_kwargs = {"device": device, "dtype": dtype}
  63. super().__init__()
  64. self.embed_dim = embed_dim
  65. self.layer_idx = layer_idx
  66. self.d_conv = d_conv
  67. self.rotary_emb_dim = rotary_emb_dim
  68. self.softmax_scale = softmax_scale
  69. self.causal = causal
  70. self.num_heads = num_heads
  71. self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
  72. assert (
  73. self.num_heads % self.num_heads_kv == 0
  74. ), "num_heads must be divisible by num_heads_kv"
  75. if head_dim is None:
  76. assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
  77. self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
  78. self.mlp_dim = math.ceil(mlp_dim / 256) * 256
  79. qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
  80. out_dim = self.head_dim * self.num_heads
  81. if self.rotary_emb_dim > 0:
  82. assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
  83. self.rotary_emb = RotaryEmbedding(
  84. self.rotary_emb_dim,
  85. base=rotary_emb_base,
  86. interleaved=rotary_emb_interleaved,
  87. device=device,
  88. )
  89. self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
  90. if self.d_conv > 0:
  91. self.conv1d = nn.Conv1d(
  92. qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
  93. **factory_kwargs
  94. )
  95. self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
  96. def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
  97. dtype = self.out_proj.weight.dtype if dtype is None else dtype
  98. device = self.out_proj.weight.device
  99. if self.d_conv > 0:
  100. conv_state = torch.zeros(
  101. batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
  102. )
  103. else:
  104. conv_state = None
  105. kv_cache = torch.empty(
  106. batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
  107. )
  108. return kv_cache, conv_state
  109. def _update_kv_cache(self, kv, inference_params):
  110. """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
  111. assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
  112. return _update_kv_cache(kv, inference_params, self.layer_idx)
  113. def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
  114. """
  115. Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
  116. q: (batch_size, seqlen_q, nheads, head_dim)
  117. kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
  118. """
  119. assert inference_params is not None and inference_params.seqlen_offset > 0
  120. if self.rotary_emb_dim > 0:
  121. self.rotary_emb._update_cos_sin_cache(
  122. inference_params.max_seqlen, device=q.device, dtype=q.dtype
  123. )
  124. rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
  125. else:
  126. rotary_cos, rotary_sin = None, None
  127. batch = q.shape[0]
  128. kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
  129. kv_cache = kv_cache[:batch]
  130. cache_seqlens = (
  131. inference_params.lengths_per_sample[:batch]
  132. if inference_params.lengths_per_sample is not None
  133. else inference_params.seqlen_offset
  134. )
  135. assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
  136. context = flash_attn_with_kvcache(
  137. q,
  138. kv_cache[:, :, 0],
  139. kv_cache[:, :, 1],
  140. kv[:, :, 0],
  141. kv[:, :, 1],
  142. rotary_cos=rotary_cos,
  143. rotary_sin=rotary_sin,
  144. cache_seqlens=cache_seqlens,
  145. softmax_scale=self.softmax_scale,
  146. causal=self.causal,
  147. rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
  148. )
  149. return context
  150. def _update_kvcache_attention(self, q, kv, inference_params):
  151. """Write kv to inference_params, then do attention"""
  152. if (
  153. inference_params.seqlen_offset == 0
  154. or flash_attn_with_kvcache is None
  155. ):
  156. # TODO: this only uses seqlen_offset and not lengths_per_sample.
  157. kv = self._update_kv_cache(kv, inference_params)
  158. k, v = kv.unbind(dim=-3)
  159. k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
  160. v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
  161. return F.scaled_dot_product_attention(
  162. q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
  163. ).transpose(1, 2)
  164. else:
  165. batch = q.shape[0]
  166. kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
  167. kv_cache = kv_cache[:batch]
  168. cache_seqlens = (
  169. inference_params.lengths_per_sample[:batch]
  170. if inference_params.lengths_per_sample is not None
  171. else inference_params.seqlen_offset
  172. )
  173. return flash_attn_with_kvcache(
  174. q,
  175. kv_cache[:, :, 0],
  176. kv_cache[:, :, 1],
  177. kv[:, :, 0],
  178. kv[:, :, 1],
  179. cache_seqlens=cache_seqlens,
  180. softmax_scale=self.softmax_scale,
  181. causal=self.causal,
  182. )
  183. def forward(self, x, inference_params=None):
  184. """
  185. Arguments:
  186. x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
  187. cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
  188. is the is the sum of the sequence lengths in the batch.
  189. inference_params: for generation. Adapted from Megatron-LM (and Apex)
  190. https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
  191. """
  192. if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
  193. inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
  194. x.shape[0], inference_params.max_seqlen, dtype=x.dtype
  195. )
  196. seqlen_offset = (
  197. 0
  198. if inference_params is None
  199. else (
  200. inference_params.lengths_per_sample
  201. if inference_params.lengths_per_sample is not None
  202. else inference_params.seqlen_offset
  203. )
  204. )
  205. rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
  206. qkv = self.in_proj(x)
  207. if self.mlp_dim > 0:
  208. qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
  209. x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
  210. x_mlp = x_mlp_up * F.silu(x_mlp_gate)
  211. if self.d_conv > 0:
  212. # The inference code for conv1d is pretty messy, should clean it up
  213. if (inference_params is None or inference_params.seqlen_offset == 0):
  214. if causal_conv1d_fn is None:
  215. qkv = rearrange(
  216. self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
  217. ).contiguous()
  218. else:
  219. qkv = causal_conv1d_fn(
  220. qkv.transpose(1, 2),
  221. rearrange(self.conv1d.weight, "d 1 w -> d w"),
  222. self.conv1d.bias
  223. ).transpose(1, 2)
  224. if inference_params is not None:
  225. _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
  226. # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
  227. # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
  228. qkv_t = rearrange(qkv, "b l d -> b d l")
  229. conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
  230. else:
  231. _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
  232. assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
  233. qkv = qkv.squeeze(1)
  234. # Conv step
  235. if causal_conv1d_update is None:
  236. conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
  237. conv_state[:, :, -1] = qkv
  238. qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
  239. if self.conv1d.bias is not None:
  240. qkv = qkv + self.conv1d.bias
  241. else:
  242. qkv = causal_conv1d_update(
  243. qkv,
  244. conv_state,
  245. rearrange(self.conv1d.weight, "d 1 w -> d w"),
  246. self.conv1d.bias
  247. )
  248. qkv = qkv.unsqueeze(1)
  249. q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
  250. q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
  251. kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
  252. if (
  253. inference_params is None
  254. or inference_params.seqlen_offset == 0
  255. or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
  256. ):
  257. if self.rotary_emb_dim > 0:
  258. q, kv = self.rotary_emb(
  259. q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
  260. )
  261. if inference_params is None:
  262. k, v = kv.unbind(dim=-3)
  263. k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
  264. v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
  265. context = F.scaled_dot_product_attention(
  266. q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
  267. ).transpose(1, 2)
  268. else:
  269. context = self._update_kvcache_attention(q, kv, inference_params)
  270. else:
  271. context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
  272. context = rearrange(context, "... h d -> ... (h d)")
  273. if self.mlp_dim > 0:
  274. context = torch.cat([context, x_mlp], dim=-1)
  275. out = self.out_proj(context)
  276. return out