mamba2_simple.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from einops import rearrange, repeat
  7. try:
  8. from causal_conv1d import causal_conv1d_fn
  9. except ImportError:
  10. causal_conv1d_fn = None
  11. try:
  12. from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
  13. except ImportError:
  14. RMSNormGated, LayerNorm = None, None
  15. from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
  16. from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
  17. class Mamba2Simple(nn.Module):
  18. def __init__(
  19. self,
  20. d_model,
  21. d_state=64,
  22. d_conv=4,
  23. conv_init=None,
  24. expand=2,
  25. headdim=128,
  26. ngroups=1,
  27. A_init_range=(1, 16),
  28. dt_min=0.001,
  29. dt_max=0.1,
  30. dt_init_floor=1e-4,
  31. dt_limit=(0.0, float("inf")),
  32. learnable_init_states=False,
  33. activation="swish",
  34. bias=False,
  35. conv_bias=True,
  36. # Fused kernel and sharding options
  37. chunk_size=256,
  38. use_mem_eff_path=True,
  39. layer_idx=None, # Absorb kwarg for general module
  40. device=None,
  41. dtype=None,
  42. ):
  43. factory_kwargs = {"device": device, "dtype": dtype}
  44. super().__init__()
  45. self.d_model = d_model
  46. self.d_state = d_state
  47. self.d_conv = d_conv
  48. self.conv_init = conv_init
  49. self.expand = expand
  50. self.d_inner = self.expand * self.d_model
  51. self.headdim = headdim
  52. self.ngroups = ngroups
  53. assert self.d_inner % self.headdim == 0
  54. self.nheads = self.d_inner // self.headdim
  55. self.dt_limit = dt_limit
  56. self.learnable_init_states = learnable_init_states
  57. self.activation = activation
  58. self.chunk_size = chunk_size
  59. self.use_mem_eff_path = use_mem_eff_path
  60. self.layer_idx = layer_idx
  61. # Order: [z, x, B, C, dt]
  62. d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
  63. self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
  64. conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
  65. self.conv1d = nn.Conv1d(
  66. in_channels=conv_dim,
  67. out_channels=conv_dim,
  68. bias=conv_bias,
  69. kernel_size=d_conv,
  70. groups=conv_dim,
  71. padding=d_conv - 1,
  72. **factory_kwargs,
  73. )
  74. if self.conv_init is not None:
  75. nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
  76. # self.conv1d.weight._no_weight_decay = True
  77. if self.learnable_init_states:
  78. self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
  79. self.init_states._no_weight_decay = True
  80. self.act = nn.SiLU()
  81. # Initialize log dt bias
  82. dt = torch.exp(
  83. torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
  84. + math.log(dt_min)
  85. )
  86. dt = torch.clamp(dt, min=dt_init_floor)
  87. # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  88. inv_dt = dt + torch.log(-torch.expm1(-dt))
  89. self.dt_bias = nn.Parameter(inv_dt)
  90. # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
  91. # name.endswith("bias") in param_grouping.py
  92. self.dt_bias._no_weight_decay = True
  93. # A parameter
  94. assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
  95. A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
  96. A_log = torch.log(A).to(dtype=dtype)
  97. self.A_log = nn.Parameter(A_log)
  98. # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
  99. self.A_log._no_weight_decay = True
  100. # D "skip" parameter
  101. self.D = nn.Parameter(torch.ones(self.nheads, device=device))
  102. self.D._no_weight_decay = True
  103. # Extra normalization layer right before output projection
  104. assert RMSNormGated is not None
  105. self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
  106. self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
  107. def forward(self, u, seq_idx=None):
  108. """
  109. u: (B, L, D)
  110. Returns: same shape as u
  111. """
  112. batch, seqlen, dim = u.shape
  113. zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
  114. A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
  115. initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
  116. dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
  117. if self.use_mem_eff_path:
  118. # Fully fused path
  119. out = mamba_split_conv1d_scan_combined(
  120. zxbcdt,
  121. rearrange(self.conv1d.weight, "d 1 w -> d w"),
  122. self.conv1d.bias,
  123. self.dt_bias,
  124. A,
  125. D=self.D,
  126. chunk_size=self.chunk_size,
  127. seq_idx=seq_idx,
  128. activation=self.activation,
  129. rmsnorm_weight=self.norm.weight,
  130. rmsnorm_eps=self.norm.eps,
  131. outproj_weight=self.out_proj.weight,
  132. outproj_bias=self.out_proj.bias,
  133. headdim=self.headdim,
  134. ngroups=self.ngroups,
  135. norm_before_gate=False,
  136. initial_states=initial_states,
  137. **dt_limit_kwargs,
  138. )
  139. else:
  140. z, xBC, dt = torch.split(
  141. zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
  142. )
  143. dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
  144. assert self.activation in ["silu", "swish"]
  145. # 1D Convolution
  146. if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
  147. xBC = self.act(
  148. self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
  149. ) # (B, L, self.d_inner + 2 * ngroups * d_state)
  150. xBC = xBC[:, :seqlen, :]
  151. else:
  152. xBC = causal_conv1d_fn(
  153. x=xBC.transpose(1, 2),
  154. weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
  155. bias=self.conv1d.bias,
  156. activation=self.activation,
  157. ).transpose(1, 2)
  158. # Split into 3 main branches: X, B, C
  159. # These correspond to V, K, Q respectively in the SSM/attention duality
  160. x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
  161. y = mamba_chunk_scan_combined(
  162. rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
  163. dt,
  164. A,
  165. rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
  166. rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
  167. chunk_size=self.chunk_size,
  168. D=self.D,
  169. z=None,
  170. seq_idx=seq_idx,
  171. initial_states=initial_states,
  172. **dt_limit_kwargs,
  173. )
  174. y = rearrange(y, "b l h p -> b l (h p)")
  175. # Multiply "gate" branch and apply extra normalization layer
  176. y = self.norm(y, z)
  177. out = self.out_proj(y)
  178. return out