modular_bamba.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898
  1. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """PyTorch Bamba model."""
  20. from typing import TypedDict
  21. import torch
  22. from torch import nn
  23. from ... import initialization as init
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...integrations.hub_kernels import lazy_load_kernel
  27. from ...masking_utils import create_causal_mask
  28. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  29. from ...modeling_utils import PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
  32. from ...utils.generic import merge_with_config_defaults
  33. from ...utils.import_utils import resolve_internal_import
  34. from ...utils.output_capturing import capture_outputs
  35. from ..jamba.modeling_jamba import JambaAttentionDecoderLayer
  36. from ..llama.modeling_llama import (
  37. LlamaAttention,
  38. LlamaForCausalLM,
  39. LlamaMLP,
  40. LlamaRMSNorm,
  41. LlamaRotaryEmbedding,
  42. rotate_half,
  43. )
  44. from ..mamba2.modeling_mamba2 import (
  45. MambaRMSNormGated,
  46. apply_mask_to_padding_states,
  47. pad_tensor_by_size,
  48. reshape_into_chunks,
  49. segment_sum,
  50. )
  51. from .configuration_bamba import BambaConfig
# Module-level logger; BambaMixer uses `logger.warning_once` to report whether the kernel fast path is available.
logger = logging.get_logger(__name__)
  53. class BambaFlashAttentionKwargs(TypedDict, total=False):
  54. """
  55. Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
  56. Use cases include padding-free training and fewer `torch.compile` graph breaks.
  57. cu_seq_lens_q (`torch.LongTensor`):
  58. Gets cumulative sequence length for query state.
  59. cu_seq_lens_k (`torch.LongTensor`):
  60. Gets cumulative sequence length for key state.
  61. max_length_q (`int`):
  62. Maximum sequence length for query state.
  63. max_length_k (`int`):
  64. Maximum sequence length for key state.
  65. seq_idx (`torch.IntTensor`):
  66. Index of each packed sequence.
  67. """
  68. cu_seq_lens_q: torch.LongTensor
  69. cu_seq_lens_k: torch.LongTensor
  70. max_length_q: int
  71. max_length_k: int
  72. seq_idx: torch.IntTensor
  73. class BambaRotaryEmbedding(LlamaRotaryEmbedding):
  74. pass
  75. # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
  76. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  77. """Applies Rotary Position Embedding to the query and key tensors.
  78. Removes the interleaving of cos and sin from GLM
  79. Args:
  80. q (`torch.Tensor`): The query tensor.
  81. k (`torch.Tensor`): The key tensor.
  82. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  83. sin (`torch.Tensor`): The sine part of the rotary embedding.
  84. unsqueeze_dim (`int`, *optional*, defaults to 1):
  85. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  86. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  87. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  88. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  89. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  90. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  91. Returns:
  92. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  93. """
  94. cos = cos.unsqueeze(unsqueeze_dim)
  95. sin = sin.unsqueeze(unsqueeze_dim)
  96. # Keep half or full tensor for later concatenation
  97. rotary_dim = cos.shape[-1]
  98. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  99. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  100. # Apply rotary embeddings on the first half or full tensor
  101. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  102. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  103. # Concatenate back to full shape
  104. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  105. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  106. return q_embed, k_embed
  107. class BambaAttention(LlamaAttention):
  108. pass
  109. class BambaRMSNormGated(MambaRMSNormGated):
  110. pass
# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
class BambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)

    There are a few differences between this and Mamba2Mixer:
    - The variable use_precomputed_states is slightly different due to the hybrid cache structure
    - A few non-obvious batching bugs in the slow path that exist in main are fixed here
    - Some extra variables that our layer doesn't need have been removed
    - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged

    Execution paths (selected in `forward`):
    - `cuda_kernels_forward`: used when every causal-conv1d / mamba-ssm kernel was found, the weights are on a
      CUDA device, and torchdynamo is not compiling.
    - `torch_forward`: pure PyTorch fallback; it does not accept `seq_idx`.

    Note: `__init__` looks the kernels up and rebinds them as module-level globals (together with
    `is_fast_path_available`), so the result of the most recent lookup is shared by all instances.
    """

    def __init__(self, config: BambaConfig, layer_idx: int):
        super().__init__()
        self.num_heads = config.mamba_n_heads
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = int(config.mamba_expand * self.hidden_size)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.mamba_conv_bias
        # `activation` (the name) is passed to the kernels; `act` (the callable) is used by the PyTorch paths.
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.use_bias = config.mamba_proj_bias

        self.layer_norm_epsilon = config.rms_norm_eps

        self.n_groups = config.mamba_n_groups
        self.head_dim = config.mamba_d_head
        self.chunk_size = config.mamba_chunk_size

        # (min, max) clamp applied to dt after softplus.
        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max

        # The depthwise conv runs over x, B and C together (hence the "B_C" names in the forward passes).
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=config.mamba_conv_bias,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )

        # projection of the input hidden states
        # Output is split in the forward passes into [gate (intermediate_size), x/B/C (conv_dim), dt (num_heads)].
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=self.use_bias,
        )
        # selective projection used to make dt, B and C input dependent

        # time step projection (discretization)
        # One bias per head, initialized to ones here; the model's `_init_weights` also resets it to ones.
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization. These are not discretized!
        # A_log = log([1, ..., num_heads]); the forward passes use A = -exp(A_log).
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
        # Per-head skip-connection weight.
        self.D = nn.Parameter(torch.ones(self.num_heads))

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

        # Kernel lookup: results are rebound as module-level globals used by `cuda_kernels_forward` and `forward`.
        # A missing causal-conv1d function becomes None via getattr's default.
        global causal_conv1d_update, causal_conv1d_fn
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
        mamba_ssm = lazy_load_kernel("mamba-ssm")
        selective_state_update = resolve_internal_import(
            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
        )
        mamba_chunk_scan_combined = resolve_internal_import(
            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
        )
        mamba_split_conv1d_scan_combined = resolve_internal_import(
            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
        )

        # The fast path requires all five kernels to be truthy.
        global is_fast_path_available
        is_fast_path_available = all(
            (
                selective_state_update,
                mamba_chunk_scan_combined,
                mamba_split_conv1d_scan_combined,
                causal_conv1d_fn,
                causal_conv1d_update,
            )
        )

        # NOTE(review): the warning below names only three of the five kernels checked above.
        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )
        else:
            logger.warning_once("The fast path for Bamba will be used when running the model on a GPU")

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        seq_idx: torch.IntTensor | None = None,
    ) -> torch.Tensor:
        """
        Mixer forward pass built on the causal-conv1d / mamba-ssm kernels.

        Sub-paths:
        - one new token with a populated cache (`use_precomputed_states`): `causal_conv1d_update` followed by
          `selective_state_update`, both operating on this layer's cached conv / recurrent states;
        - training without a cache: a single fused `mamba_split_conv1d_scan_combined` call (conv, SSM,
          gated RMSNorm and output projection);
        - otherwise: conv (`causal_conv1d_fn`, or `self.conv1d` for non-SiLU activations) plus
          `mamba_chunk_scan_combined`; conv and final SSM states are written to `cache_params` when given.

        Args:
            hidden_states: input unpacked as (batch_size, seq_len, hidden_size).
            cache_params: optional cache holding per-layer conv and recurrent states.
            attention_mask: passed to `apply_mask_to_padding_states` before the projection and after the conv.
            seq_idx: optional packed-sequence index forwarded to the conv and scan kernels.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        # 1. Gated MLP's linear projection
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        projected_states = self.in_proj(hidden_states)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size

        # Single-token step on top of an existing cache -> recurrent update instead of a full scan.
        use_precomputed_states = (
            cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
        )

        # getting projected states from cache if it exists
        if use_precomputed_states:
            gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
                [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
            )

            # 2. Convolution sequence transformation
            # NOTE(review): the cached conv_states tensor is handed to the kernel, presumably updated in place — confirm.
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                cache_params.layers[self.layer_idx].conv_states,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )

            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                dim=-1,
            )

            # 3. SSM transformation
            # Broadcast the per-head parameters to the per-(head, head_dim) layout passed to selective_state_update.
            A = -torch.exp(self.A_log.float())  # (nheads,)
            A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
            # NOTE(review): recurrent_states is passed as the kernel's state argument, presumably updated in place — confirm.
            hidden_states = selective_state_update(
                cache_params.layers[self.layer_idx].recurrent_states,
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
            hidden_states = self.norm(hidden_states, gate)

            # 4. Final linear projection
            # Re-insert the seq_len == 1 dimension.
            out = self.out_proj(hidden_states)[:, None, ...]
        # Fused calculations or step by step if no initialized cache is found
        else:
            A = -torch.exp(self.A_log.float())  # (num_heads,)
            # Only pass `dt_limit` to the kernels when it differs from the no-op (0, inf) range.
            dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

            # 2-4. Fused kernel for conv1d, SSM, and the final projection
            if self.training and cache_params is None:
                # Single fused kernel covering conv, SSM, gated RMSNorm and out_proj; no states are returned.
                out = mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=seq_idx,
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=False,
                    **dt_limit_kwargs,
                )

            else:
                gate, hidden_states_B_C, dt = projected_states.split(
                    [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
                )

                # 2. Convolution sequence transformation
                # Init cache
                if cache_params is not None:
                    # storing the states
                    # Slicing the last `conv_kernel_size` steps would fail when seq_len < conv_kernel_size;
                    # F.pad left-pads with zeros in that case and truncates otherwise (negative pad).
                    hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                    conv_states = nn.functional.pad(
                        hidden_states_B_C_transposed,
                        (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
                    )
                    conv_states = cache_params.update_conv_state(conv_states, self.layer_idx)

                # The causal_conv1d kernel is only used for SiLU/Swish; other activations use nn.Conv1d + self.act.
                if self.activation not in ["silu", "swish"]:
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                    )
                else:
                    hidden_states_B_C = causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                        seq_idx=seq_idx,
                    ).transpose(1, 2)

                hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                    dim=-1,
                )

                # 3. SSM transformation
                scan_output, ssm_state = mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    dt,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=seq_idx,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )

                # Init cache
                if ssm_state is not None and cache_params is not None:
                    ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx)

                scan_output = scan_output.view(batch_size, seq_len, -1)
                # Multiply "gate" branch and apply extra normalization layer
                scan_output = self.norm(scan_output, gate)

                # 4. Final linear projection
                out = self.out_proj(scan_output)
        return out

    # fmt: off
    def torch_forward(
        self,
        input_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Pure PyTorch mixer forward pass (no custom kernels; `seq_idx` is not supported).

        With a populated cache and a single new token, the conv and SSM states are advanced by one step.
        Otherwise a naive chunked SSD scan runs over the whole sequence (padded up to a multiple of
        `chunk_size`), and the conv and final SSM states are written to `cache_params` when it is given.

        Args:
            input_states: input unpacked as (batch_size, seq_len, hidden_size).
            cache_params: optional cache holding per-layer conv and recurrent states.
            attention_mask: passed to `apply_mask_to_padding_states` before the projection and after the conv.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size), cast back to the input dtype before out_proj.
        """
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # 1. Gated MLP's linear projection
        input_states = apply_mask_to_padding_states(input_states, attention_mask)
        projected_states = self.in_proj(input_states)
        gate, hidden_states_B_C, dt = projected_states.split(
                [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )
        # (batch, seq_len, conv_dim) -> (batch, conv_dim, seq_len), the layout nn.Conv1d expects
        hidden_states_B_C = hidden_states_B_C.transpose(1,2)

        use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1

        # 2. Convolution sequence transformation
        if use_precomputed_states:
            # Hand-written single-step depthwise conv over the cached window.
            # NOTE(review): assumes update_conv_state returns the updated (batch, conv_dim, kernel) window — confirm.
            conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx)
            hidden_states_B_C = torch.sum(
                conv_states * self.conv1d.weight.squeeze(1), dim=-1
            )
            if self.use_conv_bias:
                hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
            hidden_states_B_C = self.act(hidden_states_B_C)
        else:
            # Init cache
            if cache_params is not None:
                # Left-pad with zeros (or truncate) to exactly conv_kernel_size time steps.
                conv_states = nn.functional.pad(
                    hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0)
                )
                conv_states = cache_params.update_conv_state(conv_states, self.layer_idx)

            # Conv1d pads both sides by (kernel - 1); keeping the first seq_len outputs makes it causal.
            hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2))

        hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
        hidden_states, B, C = torch.split(
            hidden_states_B_C,
            [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
            dim=-1
        )

        # 3. SSM transformation
        A = -torch.exp(self.A_log.float())                            # [num_heads]
        if use_precomputed_states:
            # We need to guarantee that anything regarding the cache is on the same device
            cache_device = cache_params.layers[self.layer_idx].recurrent_states.device

            # Note: there is no need to pad parameter matrices here, as there is just one new token
            # for batched generation
            # [bsz, 1, num_heads] -> [bsz, num_heads, 1] -> [bsz, num_heads, head_dim]
            dt = dt[:, 0, :][:, None, ...]
            dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)

            dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            # [bsz, num_heads, head_dim, state_size]
            dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)

            # Discretize B
            # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
            # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
            B = B.reshape(batch_size, -1, B.shape[-1])
            # [bsz, num_heads, head_dim, state_size]
            dB = dt[..., None] * B[..., None, :]

            # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = (dB * hidden_states[..., None]).to(device=cache_device)

            # State calculation: h_t = h_{t-1} * dA + dB * x_t
            ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx
            ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx)

            # Subsequent output
            # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
            C = C.reshape(batch_size, -1, C.shape[-1])
            # [bsz, num_heads, head_dim]

            ssm_states = ssm_states.to(device=C.device, dtype=C.dtype)  # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)  # Shape: [b*h, d, n]
            C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1)  # Shape: [b*h, n, 1]
            y = torch.bmm(ssm_states_reshaped, C_reshaped)
            y = y.view(batch_size, self.num_heads, self.head_dim)

            # D skip connection
            # [num_heads] -> [num_heads, head_dim]
            D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
            y = (y + hidden_states * D).to(y.dtype)

            # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # begin ssd naive implementation without einsums
            dt = nn.functional.softplus(dt + self.dt_bias)
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
            B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            # Share each group's B/C across the num_heads // n_groups heads of that group.
            B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
            C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
            # Padding needed to make seq_len a multiple of chunk_size (0 if already aligned).
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

            D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

            # Discretize x and A
            hidden_states = hidden_states * dt[..., None]
            A = A.to(hidden_states.dtype) * dt

            # Rearrange into blocks/chunks
            hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]

            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.permute(0, 3, 1, 2)
            A_cumsum = torch.cumsum(A, dim=-1)

            # 1. Compute the output for each intra-chunk (diagonal blocks)
            # This is the analog of a causal mask
            L = torch.exp(segment_sum(A))

            # Contraction of C and B to get G (attention-weights like)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)
            G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)

            # Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(dim=-1)

            # Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)

            # 2. Compute the state for each intra-chunk
            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
            B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
            states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)

            # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
            # (middle term of factorization of off-diag blocks; A terms)
            # This branch never continues from a cached state, so the initial state is zero.
            previous_states = torch.zeros_like(states[:, :1])
            states = torch.cat([previous_states, states], dim=1)
            decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
            decay_chunk = decay_chunk.transpose(1, 3)
            new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
            # Last entry is the state after the final chunk; it is what gets written to the cache.
            states, ssm_state = new_states[:, :-1], new_states[:, -1]

            # 4. Compute state -> output conversion per chunk
            # (left term of low-rank factorization of off-diagonal blocks; C terms)
            state_decay_out = torch.exp(A_cumsum)
            C_times_states = (C[..., None, :] * states[:, :, None, ...])
            state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
            Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])

            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
            y = Y_diag + Y_off
            # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
            y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)

            y = y + D_residual
            # Cutting off padded chunks
            if pad_size > 0:
                y = y[:, :seq_len, :, :]
            y = y.reshape(batch_size, seq_len, -1)

            # Init cache
            if ssm_state is not None and cache_params is not None:
                ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx)

        scan_output = self.norm(y, gate)

        # end ssd naive

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.to(dtype))  # [batch, seq_len, hidden_size]
        return contextualized_states
    # fmt: on

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        seq_idx: torch.IntTensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Dispatch to `cuda_kernels_forward` or `torch_forward`.

        The kernel path is taken when every kernel was found in `__init__`, the weights are on a CUDA device,
        and torchdynamo is not compiling. The PyTorch fallback does not support `seq_idx` and raises
        `NotImplementedError` if it is given. Extra `kwargs` are accepted and ignored.
        """
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type and not is_torchdynamo_compiling():
            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask, seq_idx)
        if seq_idx is not None:
            raise NotImplementedError(
                "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
            )
        dtype = hidden_states.dtype
        # The mask is broadcast as (batch, seq_len, 1) against hidden_states, so a 2D mask is assumed here.
        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

        return self.torch_forward(hidden_states, cache_params, attention_mask)
  522. class BambaMLP(LlamaMLP):
  523. pass
  524. class BambaRMSNorm(LlamaRMSNorm):
  525. pass
  526. class BambaDecoderLayer(JambaAttentionDecoderLayer):
  527. def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
  528. super().__init__(config, layer_idx)
  529. del self.self_attn
  530. num_experts = 1
  531. ffn_layer_class = BambaMLP if num_experts == 1 else None
  532. self.feed_forward = ffn_layer_class(config)
  533. self.layer_type = layer_type
  534. if layer_type == "mamba":
  535. self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
  536. elif layer_type == "attention":
  537. self.self_attn = BambaAttention(config, layer_idx)
  538. else:
  539. raise ValueError("Invalid layer_type")
  540. def forward(
  541. self,
  542. hidden_states: torch.Tensor,
  543. attention_mask: torch.Tensor | None = None,
  544. position_ids: torch.LongTensor | None = None,
  545. past_key_values: Cache | None = None,
  546. use_cache: bool | None = False,
  547. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  548. **kwargs: Unpack[BambaFlashAttentionKwargs],
  549. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  550. residual = hidden_states
  551. hidden_states = self.input_layernorm(hidden_states)
  552. if self.layer_type == "mamba":
  553. hidden_states = self.mamba(
  554. hidden_states=hidden_states,
  555. cache_params=past_key_values,
  556. attention_mask=attention_mask,
  557. **kwargs,
  558. )
  559. self_attn_weights = None
  560. elif self.layer_type == "attention":
  561. hidden_states, self_attn_weights = self.self_attn(
  562. hidden_states=hidden_states,
  563. attention_mask=attention_mask,
  564. position_ids=position_ids,
  565. past_key_values=past_key_values,
  566. use_cache=use_cache,
  567. position_embeddings=position_embeddings,
  568. **kwargs,
  569. )
  570. hidden_states = residual + hidden_states
  571. residual = hidden_states
  572. hidden_states = self.pre_ff_layernorm(hidden_states)
  573. hidden_states = self.feed_forward(hidden_states)
  574. hidden_states = residual + hidden_states
  575. return hidden_states, self_attn_weights
  576. @auto_docstring
  577. class BambaPreTrainedModel(PreTrainedModel):
  578. config: BambaConfig
  579. base_model_prefix = "model"
  580. supports_gradient_checkpointing = True
  581. _no_split_modules = ["BambaDecoderLayer"]
  582. _skip_keys_device_placement = "past_key_values"
  583. _supports_flash_attn = True
  584. _supports_sdpa = True
  585. _is_stateful = True
  586. _can_record_outputs = {
  587. "hidden_states": BambaDecoderLayer,
  588. "attentions": BambaAttention,
  589. }
  590. @torch.no_grad()
  591. def _init_weights(self, module):
  592. super()._init_weights(module)
  593. if isinstance(module, BambaMixer):
  594. init.ones_(module.dt_bias)
  595. init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
  596. init.ones_(module.D)
  @auto_docstring
  class BambaModel(BambaPreTrainedModel):
      # Hybrid decoder stack: each layer is either a Mamba layer or an attention layer, as chosen
      # by `config.layers_block_type[i]`.

      def __init__(self, config: BambaConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
          # One decoder layer per hidden layer; its kind ("mamba" / "attention") comes from the config.
          self.layers = nn.ModuleList(
              [
                  BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i])
                  for i in range(config.num_hidden_layers)
              ]
          )
          self._attn_implementation = config._attn_implementation
          self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.rotary_emb = BambaRotaryEmbedding(config=config)
          self.gradient_checkpointing = False
          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[BambaFlashAttentionKwargs],
      ) -> BaseModelOutputWithPast:
          # The XOR is True when both or neither of `input_ids` / `inputs_embeds` are given.
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)
          hidden_states = inputs_embeds

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if position_ids is None:
              # Continue numbering after the tokens already held in the cache. Restarting at 0 on every
              # call would give cached (incremental) forward passes the wrong rotary positions and
              # mask positions.
              past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
              position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device) + past_seen_tokens
              position_ids = position_ids.unsqueeze(0)

          causal_mask = create_causal_mask(
              config=self.config,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              past_key_values=past_key_values,
              position_ids=position_ids,
          )
          # Mamba layers get their own (possibly `None`) mask; attention layers get the causal mask.
          mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
          position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

          for i, decoder_layer in enumerate(self.layers):
              layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask
              # The per-layer attention weights are not used here (presumably collected through
              # `capture_outputs` / `_can_record_outputs` instead — confirm).
              hidden_states, _ = decoder_layer(
                  hidden_states,
                  attention_mask=layer_mask,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  position_embeddings=position_embeddings,
                  **kwargs,
              )

          hidden_states = self.final_layernorm(hidden_states)
          return BaseModelOutputWithPast(
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
          )

      def _update_mamba_mask(self, attention_mask, past_key_values):
          """
          Return the mask to pass to Mamba layers, or `None` when no state zeroing is needed.

          No need for zeroing states when
              1. Cached forward (the cache reports a previous state)
              2. Attending to all inputs (`attention_mask` is all ones)
          """
          mamba_mask = attention_mask
          if (past_key_values is not None and past_key_values.has_previous_state()) or (
              attention_mask is not None and torch.all(attention_mask == 1)
          ):
              mamba_mask = None
          return mamba_mask
  class BambaForCausalLM(LlamaForCausalLM):
      # Causal-LM head on top of the Bamba decoder, with an optional auxiliary z-loss.

      def __init__(self, config):
          super().__init__(config)
          # Weight of the squared log-sum-exp ("z-loss") term added to the LM loss; off when <= 0.
          self.z_loss_coefficient = config.z_loss_coefficient
          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs,
      ) -> CausalLMOutputWithPast:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, BambaForCausalLM

          >>> model = BambaForCausalLM.from_pretrained("...")
          >>> tokenizer = AutoTokenizer.from_pretrained("...")

          >>> prompt = "Hey, are you conscious? Can you talk to me?"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
          ```"""
          model_outputs: BaseModelOutputWithPast = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              **kwargs,
          )

          if isinstance(logits_to_keep, int):
              # An int keeps the trailing `logits_to_keep` positions; 0 yields slice(0, None), i.e. all of them.
              positions = slice(-logits_to_keep, None)
          else:
              positions = logits_to_keep
          logits = self.lm_head(model_outputs.last_hidden_state[:, positions, :])

          if labels is None:
              loss = None
          else:
              loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
              if self.z_loss_coefficient > 0:
                  # Penalize the squared log-partition (log-sum-exp over the vocabulary), averaged over positions.
                  log_partition = logits.logsumexp(dim=-1).to(dtype=loss.dtype)
                  loss = loss + self.z_loss_coefficient * log_partition.pow(2).mean()

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=model_outputs.past_key_values,
              hidden_states=model_outputs.hidden_states,
              attentions=model_outputs.attentions,
          )

      def prepare_inputs_for_generation(
          self,
          input_ids,
          past_key_values=None,
          attention_mask=None,
          inputs_embeds=None,
          position_ids=None,
          use_cache=True,
          is_first_iteration=False,
          **kwargs,
      ):
          # Generation always uses the config's `num_logits_to_keep`, overriding any caller-supplied value.
          generation_kwargs = {**kwargs, "logits_to_keep": self.config.num_logits_to_keep}
          return super().prepare_inputs_for_generation(
              input_ids,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
              inputs_embeds=inputs_embeds,
              position_ids=position_ids,
              use_cache=use_cache,
              is_first_iteration=is_first_iteration,
              **generation_kwargs,
          )
  # Public names exported by this module.
  __all__ = [
      "BambaModel",
      "BambaForCausalLM",
      "BambaPreTrainedModel",
  ]