mamba2.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from einops import rearrange, repeat
  7. try:
  8. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  9. except ImportError:
  10. causal_conv1d_fn, causal_conv1d_update = None, None
  11. try:
  12. from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
  13. except ImportError:
  14. causal_conv1d_varlen_states = None
  15. try:
  16. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  17. except ImportError:
  18. selective_state_update = None
  19. from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
  20. from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
  21. from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter
  22. from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
  23. from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
  24. from huggingface_hub import PyTorchModelHubMixin
  class Mamba2(nn.Module, PyTorchModelHubMixin):
      """Mamba-2 mixer: in_proj -> causal depthwise conv1d -> SSD scan -> norm/gate -> out_proj.

      The input projection is split, in order, into ``[z0, x0, z, xBC, dt]``:

      * ``z0``, ``x0``: optional gated-MLP branch, each ``d_inner - d_ssm`` wide (empty when
        ``d_ssm`` is None); ``silu(z0) * x0`` is concatenated in front of the SSM output.
      * ``z``: gate for the SSM output (applied through the gated RMSNorm when ``rmsnorm=True``).
      * ``xBC``: ``x`` (``d_ssm`` wide) followed by ``B`` and ``C`` (``ngroups * d_state`` wide
        each); this slice is what goes through the causal conv1d.
      * ``dt``: one time-step value per head, ``nheads = d_ssm // headdim``.

      With a ``process_group``, ``d_inner``, ``d_ssm`` and ``ngroups`` hold this rank's share and
      the projections are tensor-parallel linears.
      """

      def __init__(
          self,
          d_model,
          d_state=128,
          d_conv=4,
          conv_init=None,
          expand=2,
          headdim=64,
          d_ssm=None,  # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
          ngroups=1,
          A_init_range=(1, 16),
          D_has_hdim=False,
          rmsnorm=True,
          norm_before_gate=False,
          dt_min=0.001,
          dt_max=0.1,
          dt_init_floor=1e-4,
          dt_limit=(0.0, float("inf")),
          bias=False,
          conv_bias=True,
          # Fused kernel and sharding options
          chunk_size=256,
          use_mem_eff_path=True,
          layer_idx=None,  # Absorb kwarg for general module
          process_group=None,
          sequence_parallel=True,
          device=None,
          dtype=None,
      ):
          factory_kwargs = {"device": device, "dtype": dtype}
          super().__init__()
          self.d_model = d_model
          self.d_state = d_state
          self.d_conv = d_conv
          self.conv_init = conv_init
          self.expand = expand
          self.process_group = process_group
          self.sequence_parallel = sequence_parallel
          self.world_size = 1 if process_group is None else process_group.size()
          self.local_rank = 0 if process_group is None else process_group.rank()
          # Per-rank widths: the expanded width, the SSM width and the group count are sharded.
          self.d_inner = (self.expand * self.d_model) // self.world_size
          assert self.d_inner * self.world_size == self.expand * self.d_model
          self.headdim = headdim
          self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
          assert ngroups % self.world_size == 0
          self.ngroups = ngroups // self.world_size
          assert self.d_ssm % self.headdim == 0
          self.nheads = self.d_ssm // self.headdim
          self.D_has_hdim = D_has_hdim
          self.rmsnorm = rmsnorm
          self.norm_before_gate = norm_before_gate
          self.dt_limit = dt_limit
          self.activation = "silu"
          self.chunk_size = chunk_size
          self.use_mem_eff_path = use_mem_eff_path
          self.layer_idx = layer_idx

          # Order: [z0, x0, z, x, B, C, dt]. z0/x0 (gated MLP, d_inner - d_ssm wide each) are
          # empty when d_ssm == d_inner.
          d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
          if self.process_group is None:
              self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
          else:
              self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
                                                  process_group=self.process_group, sequence_parallel=self.sequence_parallel,
                                                  **factory_kwargs)

          # Depthwise conv over x, B and C. nn.Conv1d pads d_conv - 1 on both sides; forward()
          # keeps only the leading outputs so the convolution is causal.
          conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
          self.conv1d = nn.Conv1d(
              in_channels=conv_dim,
              out_channels=conv_dim,
              bias=conv_bias,
              kernel_size=d_conv,
              groups=conv_dim,
              padding=d_conv - 1,
              **factory_kwargs,
          )
          if self.conv_init is not None:
              nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

          self.act = nn.SiLU()

          # Initialize log dt bias: sample dt log-uniformly in [dt_min, dt_max], floor it at
          # dt_init_floor, then store softplus^{-1}(dt) so that softplus(dt_bias) == dt.
          dt = torch.exp(
              torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
              + math.log(dt_min)
          )
          dt = torch.clamp(dt, min=dt_init_floor)
          # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
          inv_dt = dt + torch.log(-torch.expm1(-dt))
          self.dt_bias = nn.Parameter(inv_dt)
          # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
          # name.endswith("bias") in param_grouping.py
          self.dt_bias._no_weight_decay = True

          # A is drawn uniformly from A_init_range and stored as log(A); forward uses -exp(A_log).
          assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
          A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
          A_log = torch.log(A).to(dtype=dtype)
          self.A_log = nn.Parameter(A_log)
          self.A_log._no_weight_decay = True

          # D "skip" parameter: one value per SSM channel (D_has_hdim) or per head. It is created
          # without `dtype`, so it uses torch's default dtype.
          self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
          self.D._no_weight_decay = True

          if self.rmsnorm:
              assert RMSNormGated is not None
              # Each norm group spans d_ssm / ngroups channels. self.d_ssm is already the per-rank
              # width, so it must be divided by the per-rank group count. Dividing by the global
              # `ngroups` would make the groups world_size times too small under tensor
              # parallelism, and would disagree with the fused path, which gets ngroups=self.ngroups.
              self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
                                       group_size=self.d_ssm // self.ngroups, **factory_kwargs)

          if self.process_group is None:
              self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
          else:
              self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
                                                process_group=self.process_group, sequence_parallel=self.sequence_parallel,
                                                **factory_kwargs)

      def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
          """
          u: (batch, seqlen, hidden_dim) if seqlen=None.
              If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
              split u during sequence parallel, we split the batch * seqlen dimension
              (in case batch is small).
          seq_idx: optional sequence-index tensor forwarded to the conv and scan kernels
              (the pure-PyTorch conv fallback requires it to be None).
          cu_seqlens: optional cumulative sequence lengths for packed input, forwarded to the
              scan. When filling the inference cache it requires batch == 1 and the
              causal_conv1d package.
          inference_params: optional inference cache. If its seqlen_offset > 0, this call is a
              single decoding step (see `step`); otherwise the conv/SSM states are filled from u.
          Returns: same shape as u
          """
          seqlen_og = seqlen
          if seqlen is None:
              batch, seqlen, _ = u.shape
          else:
              batch_seqlen, _ = u.shape
              batch = batch_seqlen // seqlen

          conv_state, ssm_state = None, None
          if inference_params is not None:
              inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
              conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
              if inference_params.seqlen_offset > 0:
                  # The states are updated inplace
                  out, _, _ = self.step(u, conv_state, ssm_state)
                  return out

          zxbcdt = self.in_proj(u)  # (B, L, d_in_proj) or (B * L, d_in_proj)
          if seqlen_og is not None:
              zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
          # If the model is loaded in fp16, without the .float() here, A might be -inf
          A = -torch.exp(self.A_log.float())  # (nheads,)
          dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

          if self.use_mem_eff_path and inference_params is None:
              # Fused path: conv, scan, norm/gate and the output projection in a single call.
              out = mamba_split_conv1d_scan_combined(
                  zxbcdt,
                  rearrange(self.conv1d.weight, "d 1 w -> d w"),
                  self.conv1d.bias,
                  self.dt_bias,
                  A,
                  D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
                  chunk_size=self.chunk_size,
                  seq_idx=seq_idx,
                  activation=self.activation,
                  rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
                  rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
                  outproj_weight=self.out_proj.weight,
                  outproj_bias=self.out_proj.bias,
                  headdim=None if self.D_has_hdim else self.headdim,
                  ngroups=self.ngroups,
                  norm_before_gate=self.norm_before_gate,
                  **dt_limit_kwargs,
              )
              if seqlen_og is not None:
                  out = rearrange(out, "b l d -> (b l) d")
              if self.process_group is not None:
                  reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
                  out = reduce_fn(out, self.process_group)
          else:
              d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
              z0, x0, z, xBC, dt = torch.split(
                  zxbcdt,
                  [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
                  dim=-1
              )
              if conv_state is not None:
                  if cu_seqlens is None:
                      # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                      # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                      xBC_t = rearrange(xBC, "b l d -> b d l")
                      conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)))  # Update state (B D W)
                  else:
                      assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
                      assert batch == 1, "varlen inference only supports batch dimension 1"
                      conv_varlen_states = causal_conv1d_varlen_states(
                          xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
                      )
                      conv_state.copy_(conv_varlen_states)
              assert self.activation in ["silu", "swish"]
              if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                  assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
                  # nn.Conv1d yields L + d_conv - 1 outputs; keep the first L so the conv is causal.
                  # (Slicing with [:, :-(d_conv - 1)] would return an empty tensor when d_conv == 1.)
                  xBC_len = xBC.shape[1]
                  xBC = self.act(
                      self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :xBC_len]
                  )  # (B, L, self.d_ssm + 2 * ngroups * d_state)
              else:
                  xBC = causal_conv1d_fn(
                      xBC.transpose(1, 2),
                      rearrange(self.conv1d.weight, "d 1 w -> d w"),
                      bias=self.conv1d.bias,
                      activation=self.activation,
                      seq_idx=seq_idx,
                  ).transpose(1, 2)
              x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
              y = mamba_chunk_scan_combined(
                  rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                  dt,
                  A,
                  rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                  rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                  chunk_size=self.chunk_size,
                  D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
                  z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
                  dt_bias=self.dt_bias,
                  dt_softplus=True,
                  seq_idx=seq_idx,
                  cu_seqlens=cu_seqlens,
                  **dt_limit_kwargs,
                  return_final_states=ssm_state is not None,
                  return_varlen_states=cu_seqlens is not None and inference_params is not None,
              )
              if ssm_state is not None:
                  # The scan returns (y, final_state[, varlen_states]) when states were requested.
                  y, last_state, *rest = y
                  if cu_seqlens is None:
                      ssm_state.copy_(last_state)
                  else:
                      varlen_states = rest[0]
                      ssm_state.copy_(varlen_states)
              y = rearrange(y, "b l h p -> b l (h p)")
              if self.rmsnorm:
                  y = self.norm(y, z)
              if d_mlp > 0:
                  y = torch.cat([F.silu(z0) * x0, y], dim=-1)
              if seqlen_og is not None:
                  y = rearrange(y, "b l d -> (b l) d")
              out = self.out_proj(y)
          return out

      def step(self, hidden_states, conv_state, ssm_state):
          """Run one decoding step for a single new token.

          hidden_states: (batch, 1, hidden_dim). conv_state (batch, conv_dim, d_conv) and
          ssm_state (batch, nheads, headdim, d_state) are this layer's cached states and are
          updated in place. Returns (out, conv_state, ssm_state) with out of shape
          (batch, 1, d_model).
          """
          dtype = hidden_states.dtype
          assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
          zxbcdt = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
          d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
          z0, x0, z, xBC, dt = torch.split(
              zxbcdt,
              [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
              dim=-1
          )
          # forward() hands dt_limit to the scan kernels when it is not the default. Apply the same
          # limit here so decoding matches prefill (assumed semantics: clamp softplus(dt + dt_bias)).
          clamp_dt = self.dt_limit != (0.0, float("inf"))

          # Conv step
          if causal_conv1d_update is None:
              conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
              conv_state[:, :, -1] = xBC
              xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
              if self.conv1d.bias is not None:
                  xBC = xBC + self.conv1d.bias
              xBC = self.act(xBC).to(dtype=dtype)
          else:
              xBC = causal_conv1d_update(
                  xBC,
                  conv_state,
                  rearrange(self.conv1d.weight, "d 1 w -> d w"),
                  self.conv1d.bias,
                  self.activation,
              )

          x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
          A = -torch.exp(self.A_log.float())  # (nheads,)

          # SSM step
          if selective_state_update is None:
              assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
              # Discretize A and B
              dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)
              if clamp_dt:
                  dt = dt.clamp(min=self.dt_limit[0], max=self.dt_limit[1])
              dA = torch.exp(dt * A)  # (batch, nheads)
              x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
              dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
              ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
              y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
              # D is per channel (d_ssm entries) when D_has_hdim, otherwise per head.
              if self.D_has_hdim:
                  D = rearrange(self.D.to(dtype), "(h p) -> h p", p=self.headdim)
              else:
                  D = rearrange(self.D.to(dtype), "h -> h 1")
              y = y + D * x
              y = rearrange(y, "b h p -> b (h p)")
              if not self.rmsnorm:
                  y = y * self.act(z)  # (B D)
          else:
              A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
              dt_bias = self.dt_bias
              dt_softplus = True
              if clamp_dt:
                  # The kernel call below takes no dt limit, so apply softplus(dt + dt_bias) and
                  # the clamp up front and tell the kernel not to re-apply bias/softplus.
                  dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))
                  dt = dt.clamp(min=self.dt_limit[0], max=self.dt_limit[1])
                  dt_bias = None
                  dt_softplus = False
              dt = repeat(dt, "b h -> b h p", p=self.headdim)
              if dt_bias is not None:
                  dt_bias = repeat(dt_bias, "h -> h p", p=self.headdim)
              if self.D_has_hdim:
                  D = rearrange(self.D, "(h p) -> h p", p=self.headdim)
              else:
                  D = repeat(self.D, "h -> h p", p=self.headdim)
              B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
              C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
              x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
              if not self.rmsnorm:
                  z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
              y = selective_state_update(
                  ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
                  dt_bias=dt_bias, dt_softplus=dt_softplus
              )
              y = rearrange(y, "b h p -> b (h p)")

          if self.rmsnorm:
              y = self.norm(y, z)
          if d_mlp > 0:
              y = torch.cat([F.silu(z0) * x0, y], dim=-1)
          out = self.out_proj(y)
          return out.unsqueeze(1), conv_state, ssm_state

      def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
          """Allocate zeroed decoding states for `batch_size` sequences.

          Returns (conv_state, ssm_state) with shapes (batch, conv_dim, d_conv) and
          (batch, nheads, headdim, d_state). `dtype` overrides both state dtypes; `max_seqlen`
          and extra kwargs are accepted for interface compatibility and are not used.
          """
          device = self.out_proj.weight.device
          conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
          # Allocated as (B, W, D) and transposed to a (B, D, W) view.
          conv_state = torch.zeros(
              batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
          ).transpose(1, 2)
          ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
          ssm_state = torch.zeros(
              batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
          )
          return conv_state, ssm_state

      def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
          """Return this layer's (conv_state, ssm_state), creating zeroed ones on first use.

          States are stored in inference_params.key_value_memory_dict[self.layer_idx]. With
          initialize_states=True, the returned states are zeroed in place.
          """
          assert self.layer_idx is not None
          if self.layer_idx not in inference_params.key_value_memory_dict:
              conv_state = torch.zeros(
                  batch_size,
                  self.d_conv,
                  self.conv1d.weight.shape[0],
                  device=self.conv1d.weight.device,
                  dtype=self.conv1d.weight.dtype,
              ).transpose(1, 2)
              ssm_state = torch.zeros(
                  batch_size,
                  self.nheads,
                  self.headdim,
                  self.d_state,
                  device=self.in_proj.weight.device,
                  dtype=self.in_proj.weight.dtype,
              )
              inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
          else:
              conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
              # TODO: What if batch size changes between generation, and we reuse the same states?
          if initialize_states:
              conv_state.zero_()
              ssm_state.zero_()
          return conv_state, ssm_state