# mamba_simple.py
  1. # Copyright (c) 2023, Tri Dao, Albert Gu.
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch import Tensor
  8. from einops import rearrange, repeat
  9. from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
  10. try:
  11. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  12. except ImportError:
  13. causal_conv1d_fn, causal_conv1d_update = None, None
  14. try:
  15. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  16. except ImportError:
  17. selective_state_update = None
  18. try:
  19. from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
  20. except ImportError:
  21. RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
  22. class Mamba(nn.Module):
  23. def __init__(
  24. self,
  25. d_model,
  26. d_state=16,
  27. d_conv=4,
  28. expand=2,
  29. dt_rank="auto",
  30. dt_min=0.001,
  31. dt_max=0.1,
  32. dt_init="random",
  33. dt_scale=1.0,
  34. dt_init_floor=1e-4,
  35. conv_bias=True,
  36. bias=False,
  37. use_fast_path=True, # Fused kernel options
  38. layer_idx=None,
  39. device=None,
  40. dtype=None,
  41. ):
  42. factory_kwargs = {"device": device, "dtype": dtype}
  43. super().__init__()
  44. self.d_model = d_model
  45. self.d_state = d_state
  46. self.d_conv = d_conv
  47. self.expand = expand
  48. self.d_inner = int(self.expand * self.d_model)
  49. self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
  50. self.use_fast_path = use_fast_path
  51. self.layer_idx = layer_idx
  52. self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
  53. self.conv1d = nn.Conv1d(
  54. in_channels=self.d_inner,
  55. out_channels=self.d_inner,
  56. bias=conv_bias,
  57. kernel_size=d_conv,
  58. groups=self.d_inner,
  59. padding=d_conv - 1,
  60. **factory_kwargs,
  61. )
  62. self.activation = "silu"
  63. self.act = nn.SiLU()
  64. self.x_proj = nn.Linear(
  65. self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
  66. )
  67. self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
  68. # Initialize special dt projection to preserve variance at initialization
  69. dt_init_std = self.dt_rank**-0.5 * dt_scale
  70. if dt_init == "constant":
  71. nn.init.constant_(self.dt_proj.weight, dt_init_std)
  72. elif dt_init == "random":
  73. nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
  74. else:
  75. raise NotImplementedError
  76. # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
  77. dt = torch.exp(
  78. torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
  79. + math.log(dt_min)
  80. ).clamp(min=dt_init_floor)
  81. # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  82. inv_dt = dt + torch.log(-torch.expm1(-dt))
  83. with torch.no_grad():
  84. self.dt_proj.bias.copy_(inv_dt)
  85. # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
  86. self.dt_proj.bias._no_reinit = True
  87. # S4D real initialization
  88. A = repeat(
  89. torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
  90. "n -> d n",
  91. d=self.d_inner,
  92. ).contiguous()
  93. A_log = torch.log(A) # Keep A_log in fp32
  94. self.A_log = nn.Parameter(A_log)
  95. self.A_log._no_weight_decay = True
  96. # D "skip" parameter
  97. self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
  98. self.D._no_weight_decay = True
  99. self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
  100. def forward(self, hidden_states, inference_params=None):
  101. """
  102. hidden_states: (B, L, D)
  103. Returns: same shape as hidden_states
  104. """
  105. batch, seqlen, dim = hidden_states.shape
  106. conv_state, ssm_state = None, None
  107. if inference_params is not None:
  108. conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
  109. if inference_params.seqlen_offset > 0:
  110. # The states are updated inplace
  111. out, _, _ = self.step(hidden_states, conv_state, ssm_state)
  112. return out
  113. # We do matmul and transpose BLH -> HBL at the same time
  114. xz = rearrange(
  115. self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
  116. "d (b l) -> b d l",
  117. l=seqlen,
  118. )
  119. if self.in_proj.bias is not None:
  120. xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
  121. A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
  122. # In the backward pass we write dx and dz next to each other to avoid torch.cat
  123. if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
  124. out = mamba_inner_fn(
  125. xz,
  126. self.conv1d.weight,
  127. self.conv1d.bias,
  128. self.x_proj.weight,
  129. self.dt_proj.weight,
  130. self.out_proj.weight,
  131. self.out_proj.bias,
  132. A,
  133. None, # input-dependent B
  134. None, # input-dependent C
  135. self.D.float(),
  136. delta_bias=self.dt_proj.bias.float(),
  137. delta_softplus=True,
  138. )
  139. else:
  140. x, z = xz.chunk(2, dim=1)
  141. # Compute short convolution
  142. if conv_state is not None:
  143. # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
  144. # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
  145. conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
  146. if causal_conv1d_fn is None:
  147. x = self.act(self.conv1d(x)[..., :seqlen])
  148. else:
  149. assert self.activation in ["silu", "swish"]
  150. x = causal_conv1d_fn(
  151. x=x,
  152. weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
  153. bias=self.conv1d.bias,
  154. activation=self.activation,
  155. )
  156. # We're careful here about the layout, to avoid extra transposes.
  157. # We want dt to have d as the slowest moving dimension
  158. # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
  159. x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
  160. dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
  161. dt = self.dt_proj.weight @ dt.t()
  162. dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
  163. B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
  164. C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
  165. assert self.activation in ["silu", "swish"]
  166. y = selective_scan_fn(
  167. x,
  168. dt,
  169. A,
  170. B,
  171. C,
  172. self.D.float(),
  173. z=z,
  174. delta_bias=self.dt_proj.bias.float(),
  175. delta_softplus=True,
  176. return_last_state=ssm_state is not None,
  177. )
  178. if ssm_state is not None:
  179. y, last_state = y
  180. ssm_state.copy_(last_state)
  181. y = rearrange(y, "b d l -> b l d")
  182. out = self.out_proj(y)
  183. return out
  184. def step(self, hidden_states, conv_state, ssm_state):
  185. dtype = hidden_states.dtype
  186. assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
  187. xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
  188. x, z = xz.chunk(2, dim=-1) # (B D)
  189. # Conv step
  190. if causal_conv1d_update is None:
  191. conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
  192. conv_state[:, :, -1] = x
  193. x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
  194. if self.conv1d.bias is not None:
  195. x = x + self.conv1d.bias
  196. x = self.act(x).to(dtype=dtype)
  197. else:
  198. x = causal_conv1d_update(
  199. x,
  200. conv_state,
  201. rearrange(self.conv1d.weight, "d 1 w -> d w"),
  202. self.conv1d.bias,
  203. self.activation,
  204. )
  205. x_db = self.x_proj(x) # (B dt_rank+2*d_state)
  206. dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
  207. # Don't add dt_bias here
  208. dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
  209. A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
  210. # SSM step
  211. if selective_state_update is None:
  212. # Discretize A and B
  213. dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
  214. dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
  215. dB = torch.einsum("bd,bn->bdn", dt, B)
  216. ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
  217. y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
  218. y = y + self.D.to(dtype) * x
  219. y = y * self.act(z) # (B D)
  220. else:
  221. y = selective_state_update(
  222. ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
  223. )
  224. out = self.out_proj(y)
  225. return out.unsqueeze(1), conv_state, ssm_state
  226. def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
  227. device = self.out_proj.weight.device
  228. conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
  229. conv_state = torch.zeros(
  230. batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
  231. )
  232. ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
  233. # ssm_dtype = torch.float32
  234. ssm_state = torch.zeros(
  235. batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
  236. )
  237. return conv_state, ssm_state
  238. def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
  239. assert self.layer_idx is not None
  240. if self.layer_idx not in inference_params.key_value_memory_dict:
  241. batch_shape = (batch_size,)
  242. conv_state = torch.zeros(
  243. batch_size,
  244. self.d_model * self.expand,
  245. self.d_conv,
  246. device=self.conv1d.weight.device,
  247. dtype=self.conv1d.weight.dtype,
  248. )
  249. ssm_state = torch.zeros(
  250. batch_size,
  251. self.d_model * self.expand,
  252. self.d_state,
  253. device=self.dt_proj.weight.device,
  254. dtype=self.dt_proj.weight.dtype,
  255. # dtype=torch.float32,
  256. )
  257. inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
  258. else:
  259. conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
  260. # TODO: What if batch size changes between generation, and we reuse the same states?
  261. if initialize_states:
  262. conv_state.zero_()
  263. ssm_state.zero_()
  264. return conv_state, ssm_state