modeling_mamba.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720
  1. # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch MAMBA model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...generation import GenerationMixin
  24. from ...integrations import lazy_load_kernel
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import (
  28. ModelOutput,
  29. auto_docstring,
  30. logging,
  31. )
  32. from ...utils.import_utils import (
  33. is_mambapy_available,
  34. is_torch_greater_or_equal,
  35. is_tracing,
  36. resolve_internal_import,
  37. )
  38. from .configuration_mamba import MambaConfig
  39. logger = logging.get_logger(__name__)
  40. if is_torch_greater_or_equal("2.9.0"):
  41. from torch._higher_order_ops.associative_scan import associative_scan
  42. else:
  43. associative_scan = None
  44. if is_mambapy_available():
  45. from mambapy.pscan import pscan
  46. else:
  47. pscan = None
  48. class MambaMixer(nn.Module):
  49. """
  50. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  51. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  52. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  53. and is why Mamba is called **selective** state spaces)
  54. """
  55. def __init__(self, config: MambaConfig, layer_idx: int, initialize_mixer_weights: bool = True):
  56. super().__init__()
  57. self.config = config
  58. self.hidden_size = config.hidden_size
  59. self.ssm_state_size = config.state_size
  60. self.conv_kernel_size = config.conv_kernel
  61. self.intermediate_size = config.intermediate_size
  62. self.time_step_rank = int(config.time_step_rank)
  63. self.layer_idx = layer_idx
  64. self.use_conv_bias = config.use_conv_bias
  65. self.conv1d = nn.Conv1d(
  66. in_channels=self.intermediate_size,
  67. out_channels=self.intermediate_size,
  68. bias=config.use_conv_bias,
  69. kernel_size=config.conv_kernel,
  70. groups=self.intermediate_size,
  71. padding=config.conv_kernel - 1,
  72. )
  73. self.activation = config.hidden_act
  74. self.act = ACT2FN[config.hidden_act]
  75. self.use_mambapy = config.use_mambapy
  76. self.use_associative_scan = config.use_associative_scan
  77. # projection of the input hidden states
  78. self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
  79. # selective projection used to make dt, B and C input dependent
  80. self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
  81. # time step projection (discretization)
  82. self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
  83. # S4D real initialization. These are not discretized!
  84. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  85. self.A_log = nn.Parameter(torch.empty(self.intermediate_size, self.ssm_state_size))
  86. self.D = nn.Parameter(torch.empty(self.intermediate_size))
  87. if initialize_mixer_weights and self.dt_proj.weight.device.type != "meta":
  88. self.init_mamba_weights()
  89. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
  90. self.use_bias = config.use_bias
  91. global causal_conv1d, causal_conv1d_update, causal_conv1d_fn
  92. causal_conv1d = lazy_load_kernel("causal-conv1d")
  93. causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
  94. causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
  95. global mamba_ssm, selective_state_update, selective_scan_fn, mamba_inner_fn
  96. mamba_ssm = lazy_load_kernel("mamba-ssm")
  97. selective_state_update = resolve_internal_import(
  98. mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
  99. )
  100. selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
  101. mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
  102. self.warn_slow_implementation()
  103. @torch.no_grad()
  104. def init_mamba_weights(self):
  105. A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32, device=self.A_log.device)[None, :]
  106. A = A.expand(self.intermediate_size, -1).contiguous()
  107. init.copy_(self.A_log, torch.log(A))
  108. init.ones_(self.D)
  109. dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
  110. if self.config.time_step_init_scheme == "constant":
  111. init.constant_(self.dt_proj.weight, dt_init_std)
  112. elif self.config.time_step_init_scheme == "random":
  113. init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
  114. dt = torch.exp(
  115. torch.rand(self.intermediate_size, device=self.dt_proj.bias.device, dtype=torch.float32)
  116. * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
  117. + math.log(self.config.time_step_min)
  118. ).clamp(min=self.config.time_step_floor)
  119. # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  120. inv_dt = dt + torch.log(-torch.expm1(-dt))
  121. init.copy_(self.dt_proj.bias, inv_dt)
  122. def warn_slow_implementation(self):
  123. is_fast_path_available = all(
  124. (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
  125. )
  126. if not is_fast_path_available:
  127. if self.use_mambapy:
  128. if is_mambapy_available():
  129. logger.warning_once(
  130. "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
  131. " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
  132. " install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d"
  133. )
  134. else:
  135. raise ImportError(
  136. "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
  137. )
  138. else:
  139. logger.warning_once(
  140. "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
  141. " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
  142. " install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
  143. )
  144. def cuda_kernels_forward(
  145. self,
  146. hidden_states: torch.Tensor,
  147. cache_params: Cache | None = None,
  148. attention_mask: torch.LongTensor | None = None,
  149. ):
  150. # 1. Gated MLP's linear projection
  151. projected_states = self.in_proj(hidden_states).transpose(1, 2)
  152. if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
  153. contextualized_states = mamba_inner_fn(
  154. projected_states,
  155. self.conv1d.weight,
  156. self.conv1d.bias if self.use_conv_bias else None,
  157. self.x_proj.weight,
  158. self.dt_proj.weight,
  159. self.out_proj.weight,
  160. self.out_proj.bias.float() if self.use_bias else None,
  161. -torch.exp(self.A_log.float()),
  162. None, # input-dependent B
  163. None, # input-dependent C
  164. self.D.float(),
  165. delta_bias=self.dt_proj.bias.float(),
  166. delta_softplus=True,
  167. )
  168. else:
  169. hidden_states, gate = projected_states.chunk(2, dim=1)
  170. if attention_mask is not None:
  171. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  172. is_decoding = cache_params is not None and cache_params.has_previous_state(self.layer_idx)
  173. # 2. Convolution sequence transformation
  174. conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
  175. if is_decoding:
  176. hidden_states = causal_conv1d_update(
  177. hidden_states.squeeze(-1),
  178. cache_params.layers[self.layer_idx].conv_states,
  179. conv_weights,
  180. self.conv1d.bias,
  181. self.activation,
  182. )
  183. hidden_states = hidden_states.unsqueeze(-1)
  184. else:
  185. if cache_params is not None:
  186. conv_states = nn.functional.pad(
  187. hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
  188. )
  189. cache_params.update_conv_state(conv_states, self.layer_idx)
  190. hidden_states = causal_conv1d_fn(
  191. hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
  192. )
  193. if attention_mask is not None:
  194. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  195. # 3. State Space Model sequence transformation
  196. # 3.a. input varying initialization of time_step, B and C
  197. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  198. time_step, B, C = torch.split(
  199. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  200. )
  201. discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
  202. A = -torch.exp(self.A_log.float())
  203. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  204. time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
  205. if is_decoding:
  206. scan_outputs = selective_state_update(
  207. cache_params.layers[self.layer_idx].recurrent_states,
  208. hidden_states[..., 0],
  209. discrete_time_step[..., 0],
  210. A,
  211. B[:, 0],
  212. C[:, 0],
  213. self.D,
  214. gate[..., 0],
  215. time_proj_bias,
  216. dt_softplus=True,
  217. ).unsqueeze(-1)
  218. else:
  219. scan_outputs, ssm_state = selective_scan_fn(
  220. hidden_states,
  221. discrete_time_step,
  222. A,
  223. B.transpose(1, 2),
  224. C.transpose(1, 2),
  225. self.D.float(),
  226. gate,
  227. time_proj_bias,
  228. delta_softplus=True,
  229. return_last_state=True,
  230. )
  231. if ssm_state is not None and cache_params is not None:
  232. cache_params.update_recurrent_state(ssm_state, self.layer_idx)
  233. # 4. Final linear projection
  234. contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
  235. return contextualized_states
  236. # fmt: off
  237. def slow_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.LongTensor | None = None):
  238. batch_size, seq_len, _ = input_states.shape
  239. dtype = input_states.dtype
  240. # 1. Gated MLP's linear projection
  241. projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
  242. hidden_states, gate = projected_states.chunk(2, dim=1)
  243. if attention_mask is not None:
  244. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  245. if cache_params is not None and cache_params.has_previous_state(self.layer_idx):
  246. ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone()
  247. else:
  248. ssm_state = torch.zeros(
  249. (batch_size, self.intermediate_size, self.ssm_state_size),
  250. device=hidden_states.device, dtype=dtype
  251. )
  252. # 2. Convolution sequence transformation
  253. if cache_params is not None:
  254. if not cache_params.has_previous_state(self.layer_idx):
  255. conv_state = nn.functional.pad(
  256. hidden_states,
  257. (self.conv_kernel_size - hidden_states.shape[-1], 0)
  258. )
  259. cache_params.update_conv_state(conv_state, self.layer_idx)
  260. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
  261. else:
  262. conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx)
  263. conv_state = conv_state.to(self.conv1d.weight.device)
  264. hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
  265. if self.use_conv_bias:
  266. hidden_states += self.conv1d.bias
  267. hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
  268. else:
  269. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
  270. if attention_mask is not None:
  271. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  272. # 3. State Space Model sequence transformation
  273. # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
  274. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  275. time_step, B, C = torch.split(
  276. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  277. )
  278. discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
  279. discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
  280. # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
  281. A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
  282. discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
  283. discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
  284. deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
  285. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  286. if self.use_mambapy and self.training and cache_params is None:
  287. hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size]
  288. scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
  289. scan_output = scan_output + hidden_states * self.D[None, :, None]
  290. scan_output = scan_output * self.act(gate)
  291. else:
  292. # Use associative_scan for parallel computation when available
  293. if self.use_associative_scan and associative_scan is not None and is_tracing(hidden_states) and cache_params is None:
  294. def combine_fn(left, right):
  295. a_left, b_left = left
  296. a_right, b_right = right
  297. return (a_left * a_right, a_right * b_left + b_right)
  298. combine_mode = "pointwise" if discrete_A.device.type in ("cuda", "xpu") else "generic"
  299. _, all_h = associative_scan(combine_fn, (discrete_A, deltaB_u), dim=2, combine_mode=combine_mode)
  300. # all_h: [B, D, S, N] -> output: [B, D, S]
  301. scan_output = torch.matmul(all_h.permute(0, 2, 1, 3).to(dtype), C.unsqueeze(-1)).squeeze(-1).permute(0, 2, 1)
  302. ssm_state = all_h[:, :, -1, :]
  303. else:
  304. # Sequential loop for decoding or when associative_scan unavailable
  305. scan_outputs = []
  306. for i in range(seq_len):
  307. ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
  308. scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
  309. scan_outputs.append(scan_output[:, :, 0])
  310. scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
  311. scan_output = scan_output + (hidden_states * self.D[None, :, None])
  312. scan_output = (scan_output * self.act(gate))
  313. if cache_params is not None:
  314. cache_params.update_recurrent_state(ssm_state, self.layer_idx)
  315. # 4. Final linear projection
  316. contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
  317. return contextualized_states
  318. # fmt: on
  319. def forward(
  320. self,
  321. hidden_states,
  322. cache_params: Cache | None = None,
  323. attention_mask: torch.LongTensor | None = None,
  324. **kwargs,
  325. ):
  326. is_fast_path_available = all(
  327. (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
  328. )
  329. if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not is_tracing(hidden_states):
  330. return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
  331. return self.slow_forward(hidden_states, cache_params, attention_mask)
  332. class MambaRMSNorm(nn.Module):
  333. def __init__(self, hidden_size, eps=1e-6):
  334. """
  335. MambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
  336. """
  337. super().__init__()
  338. self.weight = nn.Parameter(torch.ones(hidden_size))
  339. self.variance_epsilon = eps
  340. def forward(self, hidden_states):
  341. input_dtype = hidden_states.dtype
  342. hidden_states = hidden_states.to(torch.float32)
  343. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  344. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  345. return self.weight * hidden_states.to(input_dtype)
  346. def extra_repr(self):
  347. return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
  348. class MambaBlock(GradientCheckpointingLayer):
  349. def __init__(self, config, layer_idx):
  350. super().__init__()
  351. self.config = config
  352. self.layer_idx = layer_idx
  353. self.residual_in_fp32 = config.residual_in_fp32
  354. self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  355. self.mixer = MambaMixer(config, layer_idx=layer_idx, initialize_mixer_weights=False)
  356. def forward(
  357. self,
  358. hidden_states,
  359. cache_params: Cache | None = None,
  360. attention_mask: torch.LongTensor | None = None,
  361. **kwargs,
  362. ):
  363. residual = hidden_states
  364. hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
  365. if self.residual_in_fp32:
  366. residual = residual.to(torch.float32)
  367. hidden_states = self.mixer(hidden_states, cache_params=cache_params, attention_mask=attention_mask)
  368. hidden_states = residual + hidden_states
  369. return hidden_states
  370. @auto_docstring
  371. class MambaPreTrainedModel(PreTrainedModel):
  372. config: MambaConfig
  373. base_model_prefix = "backbone"
  374. _no_split_modules = ["MambaBlock", "MambaMixer"]
  375. supports_gradient_checkpointing = True
  376. _is_stateful = True
  377. @torch.no_grad()
  378. def _init_weights(self, module):
  379. """Initialize the weights."""
  380. std = self.config.initializer_range
  381. if isinstance(module, MambaMixer):
  382. # S4D real initialization. These are not discretized!
  383. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  384. module.init_mamba_weights()
  385. init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
  386. if module.conv1d.bias is not None:
  387. init.zeros_(module.conv1d.bias)
  388. init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
  389. if self.config.rescale_prenorm_residual:
  390. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  391. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  392. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  393. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  394. #
  395. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  396. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  397. # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
  398. # We need to reinit p since this code could be called multiple times
  399. # Having just p *= scale would repeatedly scale it down
  400. p = module.out_proj.weight
  401. p /= math.sqrt(self.config.num_hidden_layers)
  402. if isinstance(module, nn.Linear):
  403. init.normal_(module.weight, std=std)
  404. if module.bias is not None:
  405. init.zeros_(module.bias)
  406. elif isinstance(module, MambaRMSNorm):
  407. init.ones_(module.weight)
  408. elif isinstance(module, nn.Embedding):
  409. init.normal_(module.weight, std=std)
  410. @dataclass
  411. @auto_docstring(
  412. custom_intro="""
  413. Class for the MAMBA model outputs.
  414. """
  415. )
  416. class MambaOutput(ModelOutput):
  417. r"""
  418. cache_params (`Cache`):
  419. The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
  420. avoid providing the old `input_ids`.
  421. Includes both the State space model state matrices after the selective scan, and the Convolutional states
  422. """
  423. last_hidden_state: torch.FloatTensor | None = None
  424. cache_params: Cache | None = None
  425. hidden_states: tuple[torch.FloatTensor] | None = None
  426. @dataclass
  427. @auto_docstring(
  428. custom_intro="""
  429. Base class for causal language model (or autoregressive) outputs.
  430. """
  431. )
  432. class MambaCausalLMOutput(ModelOutput):
  433. r"""
  434. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  435. Language modeling loss (for next-token prediction).
  436. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  437. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  438. cache_params (`Cache`):
  439. The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
  440. avoid providing the old `input_ids`.
  441. Includes both the State space model state matrices after the selective scan, and the Convolutional states
  442. """
  443. loss: torch.FloatTensor | None = None
  444. logits: torch.FloatTensor | None = None
  445. cache_params: Cache | None = None
  446. hidden_states: tuple[torch.FloatTensor] | None = None
  447. @auto_docstring
  448. class MambaModel(MambaPreTrainedModel):
  449. def __init__(self, config):
  450. super().__init__(config)
  451. self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
  452. self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
  453. self.gradient_checkpointing = False
  454. self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  455. # Initialize weights and apply final processing
  456. self._register_load_state_dict_pre_hook(self.load_hook)
  457. self.post_init()
  458. def load_hook(self, state_dict, prefix, *args):
  459. for k in state_dict:
  460. if "embedding." in k:
  461. state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
  462. break
  463. def get_input_embeddings(self):
  464. return self.embeddings
  465. def set_input_embeddings(self, new_embeddings):
  466. self.embeddings = new_embeddings
  467. @auto_docstring
  468. def forward(
  469. self,
  470. input_ids: torch.LongTensor | None = None,
  471. inputs_embeds: torch.LongTensor | None = None,
  472. cache_params: Cache | None = None,
  473. use_cache: bool | None = None,
  474. output_hidden_states: bool | None = None,
  475. return_dict: bool | None = None,
  476. attention_mask: torch.LongTensor | None = None,
  477. **kwargs,
  478. ) -> tuple | MambaOutput:
  479. r"""
  480. cache_params (`Cache`, *optional*):
  481. If passed along, the model uses the previous state in all the blocks (which will give the output for the
  482. `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
  483. use_cache (`bool`, *optional*):
  484. If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
  485. """
  486. output_hidden_states = (
  487. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  488. )
  489. use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
  490. return_dict = return_dict if return_dict is not None else self.config.return_dict
  491. if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
  492. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  493. if inputs_embeds is None:
  494. inputs_embeds = self.embeddings(input_ids)
  495. if self.gradient_checkpointing and self.training and use_cache:
  496. use_cache = False
  497. if use_cache and cache_params is None:
  498. cache_params = DynamicCache(config=self.config)
  499. hidden_states = inputs_embeds
  500. all_hidden_states = () if output_hidden_states else None
  501. for mixer_block in self.layers:
  502. hidden_states = mixer_block(
  503. hidden_states,
  504. cache_params=cache_params,
  505. attention_mask=attention_mask,
  506. )
  507. if output_hidden_states:
  508. all_hidden_states = all_hidden_states + (hidden_states,)
  509. hidden_states = self.norm_f(hidden_states)
  510. if output_hidden_states:
  511. all_hidden_states = all_hidden_states + (hidden_states,)
  512. if not return_dict:
  513. return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
  514. return MambaOutput(
  515. last_hidden_state=hidden_states,
  516. cache_params=cache_params if use_cache else None,
  517. hidden_states=all_hidden_states,
  518. )
  519. @auto_docstring(
  520. custom_intro="""
  521. The MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
  522. embeddings).
  523. """
  524. )
  525. class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
  526. _tied_weights_keys = {"lm_head.weight": "backbone.embeddings.weight"}
  527. def __init__(self, config):
  528. super().__init__(config)
  529. self.backbone = MambaModel(config)
  530. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  531. # Initialize weights and apply final processing
  532. self.post_init()
  533. def get_input_embeddings(self):
  534. return self.backbone.get_input_embeddings()
  535. def set_input_embeddings(self, new_embeddings):
  536. return self.backbone.set_input_embeddings(new_embeddings)
  537. def prepare_inputs_for_generation(
  538. self,
  539. input_ids,
  540. inputs_embeds=None,
  541. use_cache=None,
  542. cache_params: Cache | None = None,
  543. attention_mask: torch.LongTensor | None = None,
  544. is_first_iteration: bool | None = False,
  545. **kwargs,
  546. ):
  547. model_inputs = super().prepare_inputs_for_generation(
  548. input_ids,
  549. inputs_embeds=inputs_embeds,
  550. use_cache=use_cache,
  551. cache_params=cache_params,
  552. attention_mask=attention_mask,
  553. is_first_iteration=is_first_iteration,
  554. **kwargs,
  555. )
  556. if use_cache and not is_first_iteration:
  557. model_inputs["attention_mask"] = None
  558. return model_inputs
  559. @auto_docstring
  560. def forward(
  561. self,
  562. input_ids: torch.LongTensor | None = None,
  563. attention_mask: torch.LongTensor | None = None,
  564. inputs_embeds: torch.FloatTensor | None = None,
  565. cache_params: Cache | None = None,
  566. labels: torch.LongTensor | None = None,
  567. output_hidden_states: bool | None = None,
  568. return_dict: bool | None = None,
  569. use_cache: bool | None = None,
  570. logits_to_keep: int | torch.Tensor = 0,
  571. **kwargs, # for now we need this for generation
  572. ) -> tuple | MambaCausalLMOutput:
  573. r"""
  574. cache_params (`Cache`, *optional*):
  575. If passed along, the model uses the previous state in all the blocks (which will give the output for the
  576. `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
  577. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  578. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  579. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  580. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  581. use_cache (`bool`, *optional*):
  582. If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
  583. """
  584. return_dict = return_dict if return_dict is not None else self.config.return_dict
  585. mamba_outputs = self.backbone(
  586. input_ids,
  587. cache_params=cache_params,
  588. inputs_embeds=inputs_embeds,
  589. output_hidden_states=output_hidden_states,
  590. return_dict=return_dict,
  591. use_cache=use_cache,
  592. attention_mask=attention_mask,
  593. )
  594. hidden_states = mamba_outputs[0]
  595. # Only compute necessary logits
  596. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  597. logits = self.lm_head(hidden_states[:, slice_indices, :].to(self.lm_head.weight.dtype)).float()
  598. loss = None
  599. if labels is not None:
  600. # move labels to correct device
  601. labels = labels.to(logits.device)
  602. # Shift so that tokens < n predict n
  603. shift_logits = logits[..., :-1, :].contiguous()
  604. shift_labels = labels[..., 1:].contiguous()
  605. # Flatten the tokens
  606. loss_fct = CrossEntropyLoss()
  607. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  608. if not return_dict:
  609. output = (logits,) + mamba_outputs[1:]
  610. return ((loss,) + output) if loss is not None else output
  611. return MambaCausalLMOutput(
  612. loss=loss,
  613. logits=logits,
  614. cache_params=mamba_outputs.cache_params,
  615. hidden_states=mamba_outputs.hidden_states,
  616. )
  617. __all__ = ["MambaForCausalLM", "MambaModel", "MambaPreTrainedModel"]