| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- # Copyright (c) 2024, Tri Dao, Albert Gu.
- import math
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- try:
- from flash_attn import flash_attn_with_kvcache
- except ImportError:
- flash_attn_with_kvcache = None
- try:
- from flash_attn.layers.rotary import RotaryEmbedding
- except ImportError:
- RotaryEmbedding = None
- try:
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
- except ImportError:
- causal_conv1d_fn, causal_conv1d_update = None, None
def _update_kv_cache(kv, inference_params, layer_idx):
    """Write ``kv`` into the pre-allocated cache of ``layer_idx`` and return the filled prefix.

    Arguments:
        kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
        inference_params: object exposing ``key_value_memory_dict`` (mapping layer_idx to a
            ``(kv_cache, ...)`` pair), ``batch_size_offset`` and ``seqlen_offset``.
        layer_idx: key of this layer's entry in ``key_value_memory_dict``.

    Returns:
        View of the cache for batch rows [batch_size_offset, batch_size_offset + batch_size)
        covering positions [0, seqlen_offset + seqlen), i.e. everything cached so far
        including the newly written ``kv``.
    """
    assert layer_idx in inference_params.key_value_memory_dict
    kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
    # Validate the cache exists before its shape is read below.
    assert kv_cache is not None
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
- class MHA(nn.Module):
- """Multi-head self-attention and cross-attention"""
- def __init__(
- self,
- embed_dim,
- num_heads,
- num_heads_kv=None,
- head_dim=None, # If None, use embed_dim // num_heads
- mlp_dim=0,
- qkv_proj_bias=True,
- out_proj_bias=True,
- softmax_scale=None,
- causal=False,
- layer_idx=None,
- d_conv=0,
- rotary_emb_dim=0,
- rotary_emb_base=10000.0,
- rotary_emb_interleaved=False,
- device=None,
- dtype=None,
- ) -> None:
- """
- num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
- return_residual: whether to return the input x along with the output. This is for
- performance reason: for post-norm architecture, returning the input allows us
- to fuse the backward of nn.Linear with the residual connection.
- """
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- self.embed_dim = embed_dim
- self.layer_idx = layer_idx
- self.d_conv = d_conv
- self.rotary_emb_dim = rotary_emb_dim
- self.softmax_scale = softmax_scale
- self.causal = causal
- self.num_heads = num_heads
- self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
- assert (
- self.num_heads % self.num_heads_kv == 0
- ), "num_heads must be divisible by num_heads_kv"
- if head_dim is None:
- assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
- self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
- self.mlp_dim = math.ceil(mlp_dim / 256) * 256
- qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
- out_dim = self.head_dim * self.num_heads
- if self.rotary_emb_dim > 0:
- assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
- self.rotary_emb = RotaryEmbedding(
- self.rotary_emb_dim,
- base=rotary_emb_base,
- interleaved=rotary_emb_interleaved,
- device=device,
- )
- self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
- if self.d_conv > 0:
- self.conv1d = nn.Conv1d(
- qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
- **factory_kwargs
- )
- self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
- dtype = self.out_proj.weight.dtype if dtype is None else dtype
- device = self.out_proj.weight.device
- if self.d_conv > 0:
- conv_state = torch.zeros(
- batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
- )
- else:
- conv_state = None
- kv_cache = torch.empty(
- batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
- )
- return kv_cache, conv_state
- def _update_kv_cache(self, kv, inference_params):
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
- assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
- return _update_kv_cache(kv, inference_params, self.layer_idx)
- def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
- """
- Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
- q: (batch_size, seqlen_q, nheads, head_dim)
- kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
- """
- assert inference_params is not None and inference_params.seqlen_offset > 0
- if self.rotary_emb_dim > 0:
- self.rotary_emb._update_cos_sin_cache(
- inference_params.max_seqlen, device=q.device, dtype=q.dtype
- )
- rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
- else:
- rotary_cos, rotary_sin = None, None
- batch = q.shape[0]
- kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
- kv_cache = kv_cache[:batch]
- cache_seqlens = (
- inference_params.lengths_per_sample[:batch]
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
- context = flash_attn_with_kvcache(
- q,
- kv_cache[:, :, 0],
- kv_cache[:, :, 1],
- kv[:, :, 0],
- kv[:, :, 1],
- rotary_cos=rotary_cos,
- rotary_sin=rotary_sin,
- cache_seqlens=cache_seqlens,
- softmax_scale=self.softmax_scale,
- causal=self.causal,
- rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
- )
- return context
- def _update_kvcache_attention(self, q, kv, inference_params):
- """Write kv to inference_params, then do attention"""
- if (
- inference_params.seqlen_offset == 0
- or flash_attn_with_kvcache is None
- ):
- # TODO: this only uses seqlen_offset and not lengths_per_sample.
- kv = self._update_kv_cache(kv, inference_params)
- k, v = kv.unbind(dim=-3)
- k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
- v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
- return F.scaled_dot_product_attention(
- q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
- ).transpose(1, 2)
- else:
- batch = q.shape[0]
- kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
- kv_cache = kv_cache[:batch]
- cache_seqlens = (
- inference_params.lengths_per_sample[:batch]
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- return flash_attn_with_kvcache(
- q,
- kv_cache[:, :, 0],
- kv_cache[:, :, 1],
- kv[:, :, 0],
- kv[:, :, 1],
- cache_seqlens=cache_seqlens,
- softmax_scale=self.softmax_scale,
- causal=self.causal,
- )
- def forward(self, x, inference_params=None):
- """
- Arguments:
- x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
- cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
- is the is the sum of the sequence lengths in the batch.
- inference_params: for generation. Adapted from Megatron-LM (and Apex)
- https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
- """
- if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
- inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
- x.shape[0], inference_params.max_seqlen, dtype=x.dtype
- )
- seqlen_offset = (
- 0
- if inference_params is None
- else (
- inference_params.lengths_per_sample
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- )
- rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
- qkv = self.in_proj(x)
- if self.mlp_dim > 0:
- qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
- x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
- x_mlp = x_mlp_up * F.silu(x_mlp_gate)
- if self.d_conv > 0:
- # The inference code for conv1d is pretty messy, should clean it up
- if (inference_params is None or inference_params.seqlen_offset == 0):
- if causal_conv1d_fn is None:
- qkv = rearrange(
- self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
- ).contiguous()
- else:
- qkv = causal_conv1d_fn(
- qkv.transpose(1, 2),
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
- self.conv1d.bias
- ).transpose(1, 2)
- if inference_params is not None:
- _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
- # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
- # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
- qkv_t = rearrange(qkv, "b l d -> b d l")
- conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
- else:
- _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
- assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
- qkv = qkv.squeeze(1)
- # Conv step
- if causal_conv1d_update is None:
- conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
- conv_state[:, :, -1] = qkv
- qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
- if self.conv1d.bias is not None:
- qkv = qkv + self.conv1d.bias
- else:
- qkv = causal_conv1d_update(
- qkv,
- conv_state,
- rearrange(self.conv1d.weight, "d 1 w -> d w"),
- self.conv1d.bias
- )
- qkv = qkv.unsqueeze(1)
- q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
- if (
- inference_params is None
- or inference_params.seqlen_offset == 0
- or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
- ):
- if self.rotary_emb_dim > 0:
- q, kv = self.rotary_emb(
- q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
- )
- if inference_params is None:
- k, v = kv.unbind(dim=-3)
- k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
- v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
- context = F.scaled_dot_product_attention(
- q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
- ).transpose(1, 2)
- else:
- context = self._update_kvcache_attention(q, kv, inference_params)
- else:
- context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
- context = rearrange(context, "... h d -> ... (h d)")
- if self.mlp_dim > 0:
- context = torch.cat([context, x_mlp], dim=-1)
- out = self.out_proj(context)
- return out
|