modeling_mamba2.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931
  1. # Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch MAMBA2 model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...integrations import lazy_load_kernel
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import ModelOutput, auto_docstring, is_torchdynamo_compiling, logging
  27. from ...utils.import_utils import resolve_internal_import
  28. from .configuration_mamba2 import Mamba2Config
# Module-level logger; `Mamba2Mixer.__init__` uses it for a one-time warning when the fast kernel path is unavailable.
logger = logging.get_logger(__name__)
  30. # Helper methods for segment sum computation
  31. def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
  32. """
  33. Padding x tensor with `pad_size` on the seq_len dim (dim=1)
  34. Assumes that we only have tensors of either size 4 or 3
  35. """
  36. pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
  37. return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
  38. def reshape_into_chunks(input_tensor, pad_size, chunk_size):
  39. """
  40. Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
  41. simultaneously splitting it into chunk sequences.
  42. Assumes that we only have tensors of either size 4 or 3
  43. """
  44. # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
  45. input_tensor = pad_tensor_by_size(input_tensor, pad_size)
  46. if len(input_tensor.shape) == 3:
  47. # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
  48. return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
  49. else:
  50. # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
  51. return input_tensor.reshape(
  52. input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
  53. )
  54. def segment_sum(input_tensor):
  55. """
  56. More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
  57. """
  58. chunk_size = input_tensor.size(-1)
  59. # 1. expand input tensor to have an additional dimension and repeat along that dimension
  60. # [..., chunk_size] -> [..., chunk_size, chunk_size]
  61. input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
  62. # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
  63. mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
  64. input_tensor = input_tensor.masked_fill(~mask, 0)
  65. # 3. compute actual cumsum
  66. tensor_segsum = torch.cumsum(input_tensor, dim=-2)
  67. # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
  68. mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
  69. tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
  70. return tensor_segsum
  71. def apply_mask_to_padding_states(hidden_states, attention_mask):
  72. """
  73. Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
  74. """
  75. # NOTE: attention mask is a 2D boolean tensor
  76. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  77. dtype = hidden_states.dtype
  78. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  79. return hidden_states
  80. class MambaRMSNormGated(torch.nn.Module):
  81. def __init__(self, hidden_size, eps=1e-6):
  82. super().__init__()
  83. self.weight = nn.Parameter(torch.ones(hidden_size))
  84. self.variance_epsilon = eps
  85. def forward(self, hidden_states, gate=None):
  86. input_dtype = hidden_states.dtype
  87. hidden_states = hidden_states.to(torch.float32)
  88. if gate is not None:
  89. hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
  90. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  91. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  92. return self.weight * hidden_states.to(input_dtype)
  93. class Mamba2Mixer(nn.Module):
  94. """
  95. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  96. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  97. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  98. and is why Mamba is called **selective** state spaces)
  99. """
  100. def __init__(self, config: Mamba2Config, layer_idx: int, initialize_mixer_weights: bool = True):
  101. super().__init__()
  102. self.num_heads = config.num_heads
  103. self.hidden_size = config.hidden_size
  104. self.ssm_state_size = config.state_size
  105. self.conv_kernel_size = config.conv_kernel
  106. self.intermediate_size = int(config.expand * self.hidden_size)
  107. self.time_step_rank = int(config.time_step_rank)
  108. self.layer_idx = layer_idx
  109. self.use_conv_bias = config.use_conv_bias
  110. self.activation = config.hidden_act
  111. self.act = ACT2FN[config.hidden_act]
  112. self.layer_norm_epsilon = config.layer_norm_epsilon
  113. self.rms_norm = config.rms_norm
  114. self.n_groups = config.n_groups
  115. self.head_dim = config.head_dim
  116. self.chunk_size = config.chunk_size
  117. self.time_step_limit = config.time_step_limit
  118. self.time_step_min = config.time_step_min
  119. self.time_step_max = config.time_step_max
  120. self.time_step_floor = config.time_step_floor
  121. self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
  122. self.conv1d = nn.Conv1d(
  123. in_channels=self.conv_dim,
  124. out_channels=self.conv_dim,
  125. bias=config.use_conv_bias,
  126. kernel_size=config.conv_kernel,
  127. groups=self.conv_dim,
  128. padding=config.conv_kernel - 1,
  129. )
  130. # projection of the input hidden states
  131. projection_size = self.intermediate_size + self.conv_dim + self.num_heads
  132. self.in_proj = nn.Linear(
  133. self.hidden_size,
  134. projection_size,
  135. bias=config.use_bias,
  136. )
  137. # selective projection used to make dt, B and C input dependent
  138. # time step projection (discretization)
  139. self.dt_bias = nn.Parameter(torch.empty(self.num_heads))
  140. # S4D real initialization. These are not discretized!
  141. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  142. self.A_log = nn.Parameter(torch.empty(self.num_heads))
  143. self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
  144. self.D = nn.Parameter(torch.empty(self.num_heads))
  145. if initialize_mixer_weights and self.dt_bias.device.type != "meta":
  146. self.init_mamba2_weights()
  147. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
  148. self.use_bias = config.use_bias
  149. global causal_conv1d_update, causal_conv1d_fn
  150. causal_conv1d = lazy_load_kernel("causal-conv1d")
  151. causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
  152. causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
  153. global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
  154. mamba_ssm = lazy_load_kernel("mamba-ssm")
  155. selective_state_update = resolve_internal_import(
  156. mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
  157. )
  158. mamba_chunk_scan_combined = resolve_internal_import(
  159. mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
  160. )
  161. mamba_split_conv1d_scan_combined = resolve_internal_import(
  162. mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
  163. )
  164. global is_fast_path_available
  165. is_fast_path_available = all(
  166. (
  167. selective_state_update,
  168. mamba_chunk_scan_combined,
  169. mamba_split_conv1d_scan_combined,
  170. causal_conv1d_fn,
  171. causal_conv1d_update,
  172. )
  173. )
  174. if not is_fast_path_available:
  175. logger.warning_once(
  176. "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
  177. " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
  178. " https://github.com/Dao-AILab/causal-conv1d"
  179. )
  180. @torch.no_grad()
  181. def init_mamba2_weights(self):
  182. A = torch.arange(1, self.num_heads + 1, device=self.A_log.device, dtype=torch.float32)
  183. init.copy_(self.A_log, torch.log(A))
  184. init.ones_(self.D)
  185. dt = torch.exp(
  186. torch.rand(self.num_heads, device=self.dt_bias.device, dtype=torch.float32)
  187. * (math.log(self.time_step_max) - math.log(self.time_step_min))
  188. + math.log(self.time_step_min)
  189. ).clamp(min=self.time_step_floor)
  190. # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  191. inv_dt = dt + torch.log(-torch.expm1(-dt))
  192. init.copy_(self.dt_bias, inv_dt)
  193. def cuda_kernels_forward(
  194. self,
  195. hidden_states: torch.Tensor,
  196. cache_params: Cache | None = None,
  197. attention_mask: torch.Tensor | None = None,
  198. ):
  199. # 1. Gated MLP's linear projection
  200. hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
  201. projected_states = self.in_proj(hidden_states)
  202. # Set up dimensions for reshapes later
  203. batch_size, seq_len, _ = hidden_states.shape
  204. groups_time_state_size = self.n_groups * self.ssm_state_size
  205. d_mlp = (
  206. projected_states.shape[-1]
  207. - 2 * self.intermediate_size
  208. - 2 * self.n_groups * self.ssm_state_size
  209. - self.num_heads
  210. ) // 2
  211. # Single step calculations via cache
  212. if cache_params is not None and cache_params.has_previous_state(self.layer_idx):
  213. _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
  214. [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  215. )
  216. # 2. Convolution sequence transformation
  217. hidden_states_B_C = causal_conv1d_update(
  218. hidden_states_B_C,
  219. cache_params.layers[self.layer_idx].conv_states,
  220. self.conv1d.weight.squeeze(1),
  221. self.conv1d.bias,
  222. self.activation,
  223. )
  224. hidden_states, B, C = torch.split(
  225. hidden_states_B_C,
  226. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  227. dim=-1,
  228. )
  229. # 3. SSM transformation
  230. A = -torch.exp(self.A_log.float()) # (nheads,)
  231. A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  232. dt = dt[:, :, None].expand(-1, -1, self.head_dim)
  233. dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
  234. D = self.D[:, None, ...].expand(-1, self.head_dim)
  235. B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
  236. C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
  237. hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
  238. hidden_states = selective_state_update(
  239. cache_params.layers[self.layer_idx].recurrent_states,
  240. hidden_states_reshaped,
  241. dt,
  242. A,
  243. B,
  244. C,
  245. D,
  246. z=None,
  247. dt_bias=dt_bias,
  248. dt_softplus=True,
  249. )
  250. hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
  251. hidden_states = self.norm(hidden_states, gate)
  252. # 4. Final linear projection
  253. out = self.out_proj(hidden_states)[:, None, ...]
  254. # Fused calculations or step by step if no initialized cache is found
  255. else:
  256. A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
  257. dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
  258. # 2-4. Fused kernel for conv1d, SSM, and the final projection
  259. if self.training and cache_params is None:
  260. out = mamba_split_conv1d_scan_combined(
  261. projected_states,
  262. self.conv1d.weight.squeeze(1),
  263. self.conv1d.bias,
  264. self.dt_bias,
  265. A,
  266. D=self.D,
  267. chunk_size=self.chunk_size,
  268. seq_idx=None, # was seq_idx
  269. activation=self.activation,
  270. rmsnorm_weight=self.norm.weight,
  271. rmsnorm_eps=self.norm.variance_epsilon,
  272. outproj_weight=self.out_proj.weight,
  273. outproj_bias=self.out_proj.bias,
  274. headdim=self.head_dim,
  275. ngroups=self.n_groups,
  276. norm_before_gate=False,
  277. return_final_states=False,
  278. **dt_limit_kwargs,
  279. )
  280. else:
  281. _, _, gate, hidden_states_B_C, dt = projected_states.split(
  282. [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  283. )
  284. # 2. Convolution sequence transformation
  285. # Init cache
  286. if cache_params is not None:
  287. hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
  288. conv_states = nn.functional.pad(
  289. hidden_states_B_C_transposed,
  290. (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
  291. )
  292. conv_states = cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx)
  293. if self.activation not in ["silu", "swish"]:
  294. hidden_states_B_C = self.act(
  295. self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
  296. )
  297. else:
  298. hidden_states_B_C = causal_conv1d_fn(
  299. x=hidden_states_B_C.transpose(1, 2),
  300. weight=self.conv1d.weight.squeeze(1),
  301. bias=self.conv1d.bias,
  302. activation=self.activation,
  303. ).transpose(1, 2)
  304. hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
  305. hidden_states, B, C = torch.split(
  306. hidden_states_B_C,
  307. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  308. dim=-1,
  309. )
  310. # 3. SSM transformation
  311. scan_output, ssm_state = mamba_chunk_scan_combined(
  312. hidden_states.view(batch_size, seq_len, -1, self.head_dim),
  313. dt,
  314. A,
  315. B.view(batch_size, seq_len, self.n_groups, -1),
  316. C.view(batch_size, seq_len, self.n_groups, -1),
  317. chunk_size=self.chunk_size,
  318. D=self.D,
  319. z=None,
  320. seq_idx=None,
  321. return_final_states=True,
  322. dt_bias=self.dt_bias,
  323. dt_softplus=True,
  324. **dt_limit_kwargs,
  325. )
  326. # Init cache
  327. if ssm_state is not None and cache_params is not None:
  328. cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx)
  329. scan_output = scan_output.view(batch_size, seq_len, -1)
  330. # Multiply "gate" branch and apply extra normalization layer
  331. scan_output = self.norm(scan_output, gate)
  332. # 4. Final linear projection
  333. out = self.out_proj(scan_output)
  334. return out
  335. # fmt: off
  336. def torch_forward(
  337. self,
  338. hidden_states: torch.Tensor,
  339. cache_params: Cache | None = None,
  340. attention_mask: torch.Tensor | None = None
  341. ):
  342. batch_size, seq_len, _ = hidden_states.shape
  343. dtype = hidden_states.dtype
  344. # 1. Gated MLP's linear projection
  345. hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
  346. projected_states = self.in_proj(hidden_states)
  347. d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
  348. _, _, gate, hidden_states_B_C, dt = projected_states.split(
  349. [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  350. )
  351. hidden_states_B_C = hidden_states_B_C.transpose(1,2)
  352. is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx)
  353. # 2. Convolution sequence transformation
  354. if is_decoding:
  355. conv_states = cache_params.update_conv_state(hidden_states_B_C, layer_idx=self.layer_idx)
  356. hidden_states_B_C = torch.sum(
  357. conv_states * self.conv1d.weight.squeeze(1), dim=-1
  358. )
  359. if self.use_conv_bias:
  360. hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
  361. hidden_states_B_C = self.act(hidden_states_B_C)
  362. else:
  363. # Init cache
  364. if cache_params is not None:
  365. conv_states = nn.functional.pad(
  366. hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0)
  367. )
  368. cache_params.update_conv_state(conv_states, layer_idx=self.layer_idx)
  369. hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2))
  370. hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
  371. hidden_states, B, C = torch.split(
  372. hidden_states_B_C,
  373. [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
  374. dim=-1
  375. )
  376. # 3. SSM transformation
  377. A = -torch.exp(self.A_log.float()) # [num_heads]
  378. if is_decoding:
  379. # We need to guarantee that anything regarding the cache is on the same device
  380. cache_device = cache_params.layers[self.layer_idx].device
  381. # Note: there is no need to pad parameter matrices here, as there is just one new token
  382. # for batched generation
  383. dt = dt[:, 0, :][:, None, ...]
  384. dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
  385. # [num_heads] -> [num_heads, head_dim]
  386. dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
  387. dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
  388. dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
  389. A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  390. # [bsz, num_heads, head_dim, state_size]
  391. dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
  392. # Discretize B
  393. # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
  394. # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
  395. B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
  396. B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
  397. B = B.reshape(batch_size, -1, B.shape[-1])
  398. # [bsz, num_heads, head_dim, state_size]
  399. dB = dt[..., None] * B[..., None, :]
  400. # Discretize x into dB
  401. # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
  402. hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
  403. dBx = (dB * hidden_states[..., None]).to(device=cache_device)
  404. # State calculation
  405. ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx
  406. ssm_states = cache_params.update_recurrent_state(ssm_states, layer_idx=self.layer_idx)
  407. # Subsequent output
  408. # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
  409. C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
  410. C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
  411. C = C.reshape(batch_size, -1, C.shape[-1])
  412. # [bsz, num_heads, head_dim]
  413. # Reshape ssm_states to merge the first two dimensions
  414. ssm_states = ssm_states.to(device=C.device, dtype=C.dtype)
  415. ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
  416. C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
  417. y = torch.bmm(ssm_states_reshaped, C_reshaped)
  418. y = y.view(batch_size, self.num_heads, self.head_dim)
  419. # D skip connection
  420. # [num_heads] -> [num_heads, head_dim]
  421. D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
  422. y = (y + hidden_states * D).to(y.dtype)
  423. # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
  424. y = y.reshape(batch_size, -1)[:, None, ...]
  425. else:
  426. # begin ssd naive implementation without einsums
  427. dt = nn.functional.softplus(dt + self.dt_bias)
  428. dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
  429. hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
  430. B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  431. C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  432. B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  433. C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  434. pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
  435. D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
  436. # Discretize x and A
  437. hidden_states = hidden_states * dt[..., None]
  438. A = A.to(hidden_states.dtype) * dt
  439. # Rearrange into blocks/chunks
  440. hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
  441. # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
  442. A = A.permute(0, 3, 1, 2)
  443. A_cumsum = torch.cumsum(A, dim=-1)
  444. # 1. Compute the output for each intra-chunk (diagonal blocks)
  445. # This is the analog of a causal mask
  446. L = torch.exp(segment_sum(A))
  447. # Contraction of C and B to get G (attention-weights like)
  448. G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
  449. G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
  450. # Compute M, equivalent to applying attention mask to weights
  451. M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
  452. M = M_intermediate.sum(dim=-1)
  453. # Compute Y_diag (apply to values)
  454. Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
  455. # 2. Compute the state for each intra-chunk
  456. # (right term of low-rank factorization of off-diagonal blocks; B terms)
  457. decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
  458. B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
  459. states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
  460. # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
  461. # (middle term of factorization of off-diag blocks; A terms)
  462. previous_states = torch.zeros_like(states[:, :1])
  463. states = torch.cat([previous_states, states], dim=1)
  464. decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
  465. decay_chunk = decay_chunk.transpose(1, 3)
  466. new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
  467. states, ssm_state = new_states[:, :-1], new_states[:, -1]
  468. # 4. Compute state -> output conversion per chunk
  469. # (left term of low-rank factorization of off-diagonal blocks; C terms)
  470. state_decay_out = torch.exp(A_cumsum)
  471. C_times_states = (C[..., None, :] * states[:, :, None, ...])
  472. state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
  473. Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
  474. # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
  475. y = Y_diag + Y_off
  476. # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
  477. y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
  478. y = y + D_residual
  479. # Cutting off padded chunks
  480. if pad_size > 0:
  481. y = y[:, :seq_len, :, :]
  482. y = y.reshape(batch_size, seq_len, -1)
  483. # Init cache
  484. if ssm_state is not None and cache_params is not None:
  485. cache_params.update_recurrent_state(ssm_state, layer_idx=self.layer_idx)
  486. scan_output = self.norm(y, gate)
  487. # end ssd naive
  488. # 4. Final linear projection
  489. contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
  490. return contextualized_states
  491. # fmt: on
  492. def forward(
  493. self,
  494. hidden_states,
  495. cache_params: Cache | None = None,
  496. attention_mask: torch.Tensor | None = None,
  497. **kwargs,
  498. ):
  499. if is_fast_path_available and "cuda" in self.in_proj.weight.device.type and not is_torchdynamo_compiling():
  500. return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
  501. return self.torch_forward(hidden_states, cache_params, attention_mask)
  502. class Mamba2RMSNorm(nn.Module):
  503. def __init__(self, hidden_size, eps=1e-6):
  504. """
  505. Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
  506. """
  507. super().__init__()
  508. self.weight = nn.Parameter(torch.ones(hidden_size))
  509. self.variance_epsilon = eps
  510. def forward(self, hidden_states):
  511. input_dtype = hidden_states.dtype
  512. hidden_states = hidden_states.to(torch.float32)
  513. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  514. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  515. return self.weight * hidden_states.to(input_dtype)
  516. class Mamba2Block(GradientCheckpointingLayer):
  517. def __init__(self, config, layer_idx):
  518. super().__init__()
  519. self.config = config
  520. self.layer_idx = layer_idx
  521. self.residual_in_fp32 = config.residual_in_fp32
  522. self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  523. self.mixer = Mamba2Mixer(config, layer_idx=layer_idx, initialize_mixer_weights=False)
  524. def forward(
  525. self,
  526. hidden_states,
  527. cache_params: Cache | None = None,
  528. attention_mask: torch.Tensor | None = None,
  529. **kwargs,
  530. ):
  531. residual = hidden_states
  532. hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
  533. if self.residual_in_fp32:
  534. residual = residual.to(torch.float32)
  535. hidden_states = self.mixer(hidden_states, cache_params=cache_params, attention_mask=attention_mask)
  536. hidden_states = residual + hidden_states
  537. return hidden_states
  538. @auto_docstring
  539. class Mamba2PreTrainedModel(PreTrainedModel):
  540. config: Mamba2Config
  541. base_model_prefix = "backbone"
  542. _no_split_modules = ["Mamba2Block"]
  543. supports_gradient_checkpointing = True
  544. _is_stateful = True
  545. @torch.no_grad()
  546. def _init_weights(self, module):
  547. """Initialize the weights."""
  548. std = self.config.initializer_range
  549. if isinstance(module, Mamba2Mixer):
  550. # S4D real initialization. These are not discretized!
  551. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  552. module.init_mamba2_weights()
  553. init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
  554. if module.conv1d.bias is not None:
  555. init.zeros_(module.conv1d.bias)
  556. init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
  557. if self.config.rescale_prenorm_residual:
  558. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  559. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  560. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  561. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  562. #
  563. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  564. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  565. # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
  566. # We need to reinit p since this code could be called multiple times
  567. # Having just p *= scale would repeatedly scale it down
  568. p = module.out_proj.weight
  569. p /= math.sqrt(self.config.num_hidden_layers)
  570. if isinstance(module, nn.Linear):
  571. init.normal_(module.weight, std=std)
  572. if module.bias is not None:
  573. init.zeros_(module.bias)
  574. elif isinstance(module, (Mamba2RMSNorm, MambaRMSNormGated)):
  575. init.ones_(module.weight)
  576. elif isinstance(module, nn.Embedding):
  577. init.normal_(module.weight, std=std)
# Output container returned by `Mamba2Model.forward` when `return_dict` is enabled.
@dataclass
@auto_docstring(
    custom_intro="""
Class for the MAMBA2 model outputs.
"""
)
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
class Mamba2Output(ModelOutput):
    r"""
    cache_params (`Cache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    """

    # Field order is the positional layout of this output: callers index it positionally (e.g. `outputs[0]` is the
    # last hidden state), so do not reorder these fields.
    # Output of the final norm (`norm_f`) in `Mamba2Model.forward`.
    last_hidden_state: torch.FloatTensor | None = None
    # Left as None by `Mamba2Model.forward` (return_dict path) when `use_cache` is not enabled.
    cache_params: Cache | None = None
    # Set only when `output_hidden_states` is enabled.
    hidden_states: tuple[torch.FloatTensor] | None = None
# Output container returned by `Mamba2ForCausalLM.forward` when `return_dict` is enabled.
@dataclass
@auto_docstring(
    custom_intro="""
Base class for causal language model (or autoregressive) outputs.
"""
)
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
class Mamba2CausalLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cache_params (`Cache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    """

    # Field order is the positional layout of this output; do not reorder these fields.
    # None unless `labels` were passed to `Mamba2ForCausalLM.forward`.
    loss: torch.FloatTensor | None = None
    # Cast to float32 by `Mamba2ForCausalLM.forward`; covers only the positions selected by `logits_to_keep`.
    logits: torch.FloatTensor | None = None
    # Forwarded unchanged from the backbone's `Mamba2Output.cache_params`.
    cache_params: Cache | None = None
    # Forwarded unchanged from the backbone's `Mamba2Output.hidden_states`.
    hidden_states: tuple[torch.FloatTensor] | None = None
  617. @auto_docstring
  618. class Mamba2Model(Mamba2PreTrainedModel):
  619. def __init__(self, config):
  620. super().__init__(config)
  621. self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
  622. self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
  623. self.gradient_checkpointing = False
  624. self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  625. # Initialize weights and apply final processing
  626. self._register_load_state_dict_pre_hook(self.load_hook)
  627. self.post_init()
  628. def load_hook(self, state_dict, prefix, *args):
  629. for k in state_dict:
  630. if "embedding." in k:
  631. state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
  632. break
  633. def get_input_embeddings(self):
  634. return self.embeddings
  635. def set_input_embeddings(self, new_embeddings):
  636. self.embeddings = new_embeddings
  637. @auto_docstring
  638. def forward(
  639. self,
  640. input_ids: torch.LongTensor | None = None,
  641. inputs_embeds: torch.LongTensor | None = None,
  642. cache_params: Cache | None = None,
  643. use_cache: bool | None = None,
  644. output_hidden_states: bool | None = None,
  645. return_dict: bool | None = None,
  646. attention_mask: torch.Tensor | None = None,
  647. **kwargs,
  648. ) -> tuple | Mamba2Output:
  649. r"""
  650. cache_params (`Cache`, *optional*):
  651. If passed along, the model uses the previous state in all the blocks (which will give the output for the
  652. `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
  653. use_cache (`bool`, *optional*):
  654. If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
  655. """
  656. output_hidden_states = (
  657. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  658. )
  659. use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
  660. return_dict = return_dict if return_dict is not None else self.config.return_dict
  661. if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
  662. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  663. if inputs_embeds is None:
  664. inputs_embeds = self.embeddings(input_ids)
  665. if self.gradient_checkpointing and self.training and use_cache:
  666. use_cache = False
  667. if use_cache and cache_params is None:
  668. cache_params = DynamicCache(config=self.config)
  669. hidden_states = inputs_embeds
  670. all_hidden_states = () if output_hidden_states else None
  671. for mixer_block in self.layers:
  672. hidden_states = mixer_block(
  673. hidden_states,
  674. cache_params=cache_params,
  675. attention_mask=attention_mask,
  676. )
  677. if output_hidden_states:
  678. all_hidden_states = all_hidden_states + (hidden_states,)
  679. hidden_states = self.norm_f(hidden_states)
  680. if output_hidden_states:
  681. all_hidden_states = all_hidden_states + (hidden_states,)
  682. if not return_dict:
  683. return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
  684. return Mamba2Output(
  685. last_hidden_state=hidden_states,
  686. cache_params=cache_params if use_cache else None,
  687. hidden_states=all_hidden_states,
  688. )
  689. @auto_docstring(
  690. custom_intro="""
  691. The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
  692. embeddings).
  693. """
  694. )
  695. class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
  696. _tied_weights_keys = {"lm_head.weight": "backbone.embeddings.weight"}
  697. def __init__(self, config):
  698. super().__init__(config)
  699. self.backbone = Mamba2Model(config)
  700. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  701. # Initialize weights and apply final processing
  702. self.post_init()
  703. def get_input_embeddings(self):
  704. return self.backbone.get_input_embeddings()
  705. def set_input_embeddings(self, new_embeddings):
  706. return self.backbone.set_input_embeddings(new_embeddings)
  707. def prepare_inputs_for_generation(
  708. self,
  709. input_ids,
  710. inputs_embeds=None,
  711. use_cache=None,
  712. cache_params: Cache | None = None,
  713. attention_mask: torch.Tensor | None = None,
  714. is_first_iteration: bool | None = False,
  715. **kwargs,
  716. ):
  717. model_inputs = super().prepare_inputs_for_generation(
  718. input_ids,
  719. inputs_embeds=inputs_embeds,
  720. use_cache=use_cache,
  721. cache_params=cache_params,
  722. attention_mask=attention_mask,
  723. is_first_iteration=is_first_iteration,
  724. **kwargs,
  725. )
  726. if use_cache and not is_first_iteration:
  727. model_inputs["attention_mask"] = None
  728. return model_inputs
  729. @auto_docstring
  730. def forward(
  731. self,
  732. input_ids: torch.LongTensor | None = None,
  733. inputs_embeds: torch.FloatTensor | None = None,
  734. cache_params: Cache | None = None,
  735. labels: torch.LongTensor | None = None,
  736. output_hidden_states: bool | None = None,
  737. return_dict: bool | None = None,
  738. use_cache: bool | None = None,
  739. attention_mask: torch.Tensor | None = None,
  740. logits_to_keep: int | torch.Tensor = 0,
  741. **kwargs, # for now we need this for generation and loss_function
  742. ) -> tuple | Mamba2CausalLMOutput:
  743. r"""
  744. cache_params (`Cache`, *optional*):
  745. If passed along, the model uses the previous state in all the blocks (which will give the output for the
  746. `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
  747. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  748. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  749. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  750. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  751. use_cache (`bool`, *optional*):
  752. If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
  753. """
  754. return_dict = return_dict if return_dict is not None else self.config.return_dict
  755. mamba2_outputs = self.backbone(
  756. input_ids,
  757. cache_params=cache_params,
  758. inputs_embeds=inputs_embeds,
  759. output_hidden_states=output_hidden_states,
  760. return_dict=return_dict,
  761. use_cache=use_cache,
  762. attention_mask=attention_mask,
  763. )
  764. hidden_states = mamba2_outputs[0]
  765. # Only compute necessary logits
  766. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  767. logits = self.lm_head(hidden_states[:, slice_indices, :].to(self.lm_head.weight.dtype)).float()
  768. loss = None
  769. if labels is not None:
  770. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  771. if not return_dict:
  772. output = (logits,) + mamba2_outputs[1:]
  773. return ((loss,) + output) if loss is not None else output
  774. return Mamba2CausalLMOutput(
  775. loss=loss,
  776. logits=logits,
  777. cache_params=mamba2_outputs.cache_params,
  778. hidden_states=mamba2_outputs.hidden_states,
  779. )
  780. __all__ = ["Mamba2ForCausalLM", "Mamba2Model", "Mamba2PreTrainedModel"]