modeling_bamba.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/bamba/modular_bamba.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_bamba.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. from collections.abc import Callable
  26. from typing import Optional, TypedDict
  27. import torch
  28. from torch import nn
  29. from ... import initialization as init
  30. from ...activations import ACT2FN
  31. from ...cache_utils import Cache, DynamicCache
  32. from ...generation import GenerationMixin
  33. from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  34. from ...integrations.hub_kernels import lazy_load_kernel
  35. from ...masking_utils import create_causal_mask
  36. from ...modeling_layers import GradientCheckpointingLayer
  37. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  38. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  39. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  40. from ...processing_utils import Unpack
  41. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
  42. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  43. from ...utils.import_utils import resolve_internal_import
  44. from ...utils.output_capturing import capture_outputs
  45. from .configuration_bamba import BambaConfig
# Module-level logger; BambaMixer uses it for its one-time fast-path availability warnings.
logger = logging.get_logger(__name__)
  47. class BambaFlashAttentionKwargs(TypedDict, total=False):
  48. """
  49. Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
  50. Use cases include padding-free training and fewer `torch.compile` graph breaks.
  51. cu_seq_lens_q (`torch.LongTensor`):
  52. Gets cumulative sequence length for query state.
  53. cu_seq_lens_k (`torch.LongTensor`):
  54. Gets cumulative sequence length for key state.
  55. max_length_q (`int`):
  56. Maximum sequence length for query state.
  57. max_length_k (`int`):
  58. Maximum sequence length for key state.
  59. seq_idx (`torch.IntTensor`):
  60. Index of each packed sequence.
  61. """
  62. cu_seq_lens_q: torch.LongTensor
  63. cu_seq_lens_k: torch.LongTensor
  64. max_length_q: int
  65. max_length_k: int
  66. seq_idx: torch.IntTensor
  67. class BambaRotaryEmbedding(nn.Module):
  68. inv_freq: torch.Tensor # fix linting for `register_buffer`
  69. def __init__(self, config: BambaConfig, device=None):
  70. super().__init__()
  71. self.max_seq_len_cached = config.max_position_embeddings
  72. self.original_max_seq_len = config.max_position_embeddings
  73. self.config = config
  74. self.rope_type = self.config.rope_parameters["rope_type"]
  75. rope_init_fn: Callable = self.compute_default_rope_parameters
  76. if self.rope_type != "default":
  77. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  78. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  79. self.register_buffer("inv_freq", inv_freq, persistent=False)
  80. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  81. @staticmethod
  82. def compute_default_rope_parameters(
  83. config: BambaConfig | None = None,
  84. device: Optional["torch.device"] = None,
  85. seq_len: int | None = None,
  86. ) -> tuple["torch.Tensor", float]:
  87. """
  88. Computes the inverse frequencies according to the original RoPE implementation
  89. Args:
  90. config ([`~transformers.PreTrainedConfig`]):
  91. The model configuration.
  92. device (`torch.device`):
  93. The device to use for initialization of the inverse frequencies.
  94. seq_len (`int`, *optional*):
  95. The current sequence length. Unused for this type of RoPE.
  96. Returns:
  97. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  98. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  99. """
  100. base = config.rope_parameters["rope_theta"]
  101. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  102. attention_factor = 1.0 # Unused in this type of RoPE
  103. # Compute the inverse frequencies
  104. inv_freq = 1.0 / (
  105. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  106. )
  107. return inv_freq, attention_factor
  108. @torch.no_grad()
  109. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  110. def forward(self, x, position_ids):
  111. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  112. position_ids_expanded = position_ids[:, None, :].float()
  113. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  114. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  115. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  116. emb = torch.cat((freqs, freqs), dim=-1)
  117. cos = emb.cos() * self.attention_scaling
  118. sin = emb.sin() * self.attention_scaling
  119. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  120. def rotate_half(x):
  121. """Rotates half the hidden dims of the input."""
  122. x1 = x[..., : x.shape[-1] // 2]
  123. x2 = x[..., x.shape[-1] // 2 :]
  124. return torch.cat((-x2, x1), dim=-1)
  125. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  126. """
  127. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  128. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  129. """
  130. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  131. if n_rep == 1:
  132. return hidden_states
  133. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  134. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  135. def eager_attention_forward(
  136. module: nn.Module,
  137. query: torch.Tensor,
  138. key: torch.Tensor,
  139. value: torch.Tensor,
  140. attention_mask: torch.Tensor | None,
  141. scaling: float,
  142. dropout: float = 0.0,
  143. **kwargs: Unpack[TransformersKwargs],
  144. ):
  145. key_states = repeat_kv(key, module.num_key_value_groups)
  146. value_states = repeat_kv(value, module.num_key_value_groups)
  147. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  148. if attention_mask is not None:
  149. attn_weights = attn_weights + attention_mask
  150. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  151. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  152. attn_output = torch.matmul(attn_weights, value_states)
  153. attn_output = attn_output.transpose(1, 2).contiguous()
  154. return attn_output, attn_weights
  155. # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
  156. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  157. """Applies Rotary Position Embedding to the query and key tensors.
  158. Removes the interleaving of cos and sin from GLM
  159. Args:
  160. q (`torch.Tensor`): The query tensor.
  161. k (`torch.Tensor`): The key tensor.
  162. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  163. sin (`torch.Tensor`): The sine part of the rotary embedding.
  164. unsqueeze_dim (`int`, *optional*, defaults to 1):
  165. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  166. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  167. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  168. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  169. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  170. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  171. Returns:
  172. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  173. """
  174. cos = cos.unsqueeze(unsqueeze_dim)
  175. sin = sin.unsqueeze(unsqueeze_dim)
  176. # Keep half or full tensor for later concatenation
  177. rotary_dim = cos.shape[-1]
  178. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  179. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  180. # Apply rotary embeddings on the first half or full tensor
  181. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  182. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  183. # Concatenate back to full shape
  184. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  185. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  186. return q_embed, k_embed
  187. @use_kernelized_func(apply_rotary_pos_emb)
  188. class BambaAttention(nn.Module):
  189. """Multi-headed attention from 'Attention Is All You Need' paper"""
  190. def __init__(self, config: BambaConfig, layer_idx: int):
  191. super().__init__()
  192. self.config = config
  193. self.layer_idx = layer_idx
  194. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  195. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  196. self.scaling = self.head_dim**-0.5
  197. self.attention_dropout = config.attention_dropout
  198. self.is_causal = True
  199. self.q_proj = nn.Linear(
  200. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  201. )
  202. self.k_proj = nn.Linear(
  203. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  204. )
  205. self.v_proj = nn.Linear(
  206. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  207. )
  208. self.o_proj = nn.Linear(
  209. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  210. )
  211. def forward(
  212. self,
  213. hidden_states: torch.Tensor,
  214. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  215. attention_mask: torch.Tensor | None = None,
  216. past_key_values: Cache | None = None,
  217. **kwargs: Unpack[TransformersKwargs],
  218. ) -> tuple[torch.Tensor, torch.Tensor]:
  219. input_shape = hidden_states.shape[:-1]
  220. hidden_shape = (*input_shape, -1, self.head_dim)
  221. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  222. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  223. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  224. cos, sin = position_embeddings
  225. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  226. if past_key_values is not None:
  227. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  228. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  229. self.config._attn_implementation, eager_attention_forward
  230. )
  231. attn_output, attn_weights = attention_interface(
  232. self,
  233. query_states,
  234. key_states,
  235. value_states,
  236. attention_mask,
  237. dropout=0.0 if not self.training else self.attention_dropout,
  238. scaling=self.scaling,
  239. **kwargs,
  240. )
  241. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  242. attn_output = self.o_proj(attn_output)
  243. return attn_output, attn_weights
  244. class BambaRMSNormGated(torch.nn.Module):
  245. def __init__(self, hidden_size, eps=1e-6):
  246. super().__init__()
  247. self.weight = nn.Parameter(torch.ones(hidden_size))
  248. self.variance_epsilon = eps
  249. def forward(self, hidden_states, gate=None):
  250. input_dtype = hidden_states.dtype
  251. hidden_states = hidden_states.to(torch.float32)
  252. if gate is not None:
  253. hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
  254. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  255. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  256. return self.weight * hidden_states.to(input_dtype)
  257. # Helper methods for segment sum computation
  258. def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
  259. """
  260. Padding x tensor with `pad_size` on the seq_len dim (dim=1)
  261. Assumes that we only have tensors of either size 4 or 3
  262. """
  263. pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
  264. return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
  265. def reshape_into_chunks(input_tensor, pad_size, chunk_size):
  266. """
  267. Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
  268. simultaneously splitting it into chunk sequences.
  269. Assumes that we only have tensors of either size 4 or 3
  270. """
  271. # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
  272. input_tensor = pad_tensor_by_size(input_tensor, pad_size)
  273. if len(input_tensor.shape) == 3:
  274. # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
  275. return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
  276. else:
  277. # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
  278. return input_tensor.reshape(
  279. input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
  280. )
  281. def segment_sum(input_tensor):
  282. """
  283. More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
  284. """
  285. chunk_size = input_tensor.size(-1)
  286. # 1. expand input tensor to have an additional dimension and repeat along that dimension
  287. # [..., chunk_size] -> [..., chunk_size, chunk_size]
  288. input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
  289. # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
  290. mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
  291. input_tensor = input_tensor.masked_fill(~mask, 0)
  292. # 3. compute actual cumsum
  293. tensor_segsum = torch.cumsum(input_tensor, dim=-2)
  294. # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
  295. mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
  296. tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
  297. return tensor_segsum
  298. def apply_mask_to_padding_states(hidden_states, attention_mask):
  299. """
  300. Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
  301. """
  302. # NOTE: attention mask is a 2D boolean tensor
  303. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  304. dtype = hidden_states.dtype
  305. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  306. return hidden_states
  307. # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
  308. class BambaMixer(nn.Module):
  309. """
  310. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  311. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  312. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  313. and is why Mamba is called **selective** state spaces)
  314. The are a few differences between this and Mamba2Mixer:
  315. - The variable use_precomputed_states is slightly different due to the hybrid cache structure
  316. - There's a few non-obvious bugs fixed with batching in the slow path that exist in main
  317. - Some extra variables that our layer doesn't need have been removed
  318. - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged
  319. """
  320. def __init__(self, config: BambaConfig, layer_idx: int):
  321. super().__init__()
  322. self.num_heads = config.mamba_n_heads
  323. self.hidden_size = config.hidden_size
  324. self.ssm_state_size = config.mamba_d_state
  325. self.conv_kernel_size = config.mamba_d_conv
  326. self.intermediate_size = int(config.mamba_expand * self.hidden_size)
  327. self.layer_idx = layer_idx
  328. self.use_conv_bias = config.mamba_conv_bias
  329. self.activation = config.hidden_act
  330. self.act = ACT2FN[config.hidden_act]
  331. self.use_bias = config.mamba_proj_bias
  332. self.layer_norm_epsilon = config.rms_norm_eps
  333. self.n_groups = config.mamba_n_groups
  334. self.head_dim = config.mamba_d_head
  335. self.chunk_size = config.mamba_chunk_size
  336. self.time_step_limit = config.time_step_limit
  337. self.time_step_min = config.time_step_min
  338. self.time_step_max = config.time_step_max
  339. self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
  340. self.conv1d = nn.Conv1d(
  341. in_channels=self.conv_dim,
  342. out_channels=self.conv_dim,
  343. bias=config.mamba_conv_bias,
  344. kernel_size=self.conv_kernel_size,
  345. groups=self.conv_dim,
  346. padding=self.conv_kernel_size - 1,
  347. )
  348. # projection of the input hidden states
  349. projection_size = self.intermediate_size + self.conv_dim + self.num_heads
  350. self.in_proj = nn.Linear(
  351. self.hidden_size,
  352. projection_size,
  353. bias=self.use_bias,
  354. )
  355. # selective projection used to make dt, B and C input dependent
  356. # time step projection (discretization)
  357. # instantiate once and copy inv_dt in init_weights of PretrainedModel
  358. self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
  359. # S4D real initialization. These are not discretized!
  360. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  361. A = torch.arange(1, self.num_heads + 1)
  362. self.A_log = nn.Parameter(torch.log(A))
  363. self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
  364. self.D = nn.Parameter(torch.ones(self.num_heads))
  365. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
  366. global causal_conv1d_update, causal_conv1d_fn
  367. causal_conv1d = lazy_load_kernel("causal-conv1d")
  368. causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
  369. causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
  370. global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
  371. mamba_ssm = lazy_load_kernel("mamba-ssm")
  372. selective_state_update = resolve_internal_import(
  373. mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
  374. )
  375. mamba_chunk_scan_combined = resolve_internal_import(
  376. mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
  377. )
  378. mamba_split_conv1d_scan_combined = resolve_internal_import(
  379. mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
  380. )
  381. global is_fast_path_available
  382. is_fast_path_available = all(
  383. (
  384. selective_state_update,
  385. mamba_chunk_scan_combined,
  386. mamba_split_conv1d_scan_combined,
  387. causal_conv1d_fn,
  388. causal_conv1d_update,
  389. )
  390. )
  391. if not is_fast_path_available:
  392. logger.warning_once(
  393. "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
  394. " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
  395. " https://github.com/Dao-AILab/causal-conv1d"
  396. )
  397. else:
  398. logger.warning_once("The fast path for Bamba will be used when running the model on a GPU")
    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        seq_idx: torch.IntTensor | None = None,
    ):
        """
        Mixer forward pass built on the fused kernels that `__init__` binds as module globals
        (`causal_conv1d_update` / `causal_conv1d_fn`, `selective_state_update`, `mamba_chunk_scan_combined`,
        `mamba_split_conv1d_scan_combined`).

        Three paths:
          1. Single-token decode (`cache_params` has a previous state for this layer and `seq_len == 1`): one
             conv step and one SSM step against this layer's cached `conv_states` / `recurrent_states`. No
             cache update call is made here, so the kernels presumably update those tensors in place (confirm
             against the kernel docs).
          2. Training without a cache: a single `mamba_split_conv1d_scan_combined` call fuses conv1d, the SSM
             scan, the gated RMSNorm and `out_proj`.
          3. Otherwise (e.g. prefill): conv1d (`causal_conv1d_fn` for silu/swish, else `self.conv1d` + `self.act`),
             chunked scan with `mamba_chunk_scan_combined`, gated norm, then `out_proj`. If `cache_params` is
             given, the last `conv_kernel_size` conv inputs and the final SSM state are written to it.

        Args:
            hidden_states: mixer input, unpacked as `(batch_size, seq_len, _)`.
            cache_params: optional cache; read in path 1, written in path 3.
            attention_mask: optional padding mask, applied via `apply_mask_to_padding_states` before `in_proj`
                (all paths) and again after the convolution (path 3).
            seq_idx: optional packed-sequence index, forwarded to the conv/scan kernels in paths 2 and 3.

        Returns:
            The `out_proj` output. In path 1 a length-1 sequence dim is re-inserted via `[:, None, ...]`.
        """
        # 1. Gated MLP's linear projection
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        projected_states = self.in_proj(hidden_states)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size

        # Decode step: only valid when this layer already has cached conv/SSM state and exactly one new token.
        use_precomputed_states = (
            cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
        )

        # Path 1: single-token step using the cached states of this layer
        if use_precomputed_states:
            # drop the length-1 seq dim, then split into gate / xBC / dt
            gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
                [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
            )

            # 2. Convolution sequence transformation
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                cache_params.layers[self.layer_idx].conv_states,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )

            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                dim=-1,
            )

            # 3. SSM transformation
            # Broadcast the per-head A, dt, dt_bias and D to the per-(head, head_dim[, state]) layout passed
            # to selective_state_update.
            A = -torch.exp(self.A_log.float())  # (nheads,)
            A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
            hidden_states = selective_state_update(
                cache_params.layers[self.layer_idx].recurrent_states,
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
            hidden_states = self.norm(hidden_states, gate)

            # 4. Final linear projection
            out = self.out_proj(hidden_states)[:, None, ...]
        # Paths 2 and 3: full-sequence processing (fully fused, or step by step when a cache must be filled / not training)
        else:
            A = -torch.exp(self.A_log.float())  # (num_heads,) — A_log is created with shape (num_heads,) in __init__
            # Only pass dt_limit to the kernels when a non-trivial clamp range is configured.
            dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

            # 2-4. Fused kernel for conv1d, SSM, and the final projection
            if self.training and cache_params is None:
                out = mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=seq_idx,
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=False,
                    **dt_limit_kwargs,
                )

            else:
                gate, hidden_states_B_C, dt = projected_states.split(
                    [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
                )

                # 2. Convolution sequence transformation
                # Init cache
                if cache_params is not None:
                    # Store the last `conv_kernel_size` conv inputs as the conv state.
                    # Slicing hidden_states_B_C[:, :, -self.conv_kernel_size :] would be too short if
                    # seq_len < conv_kernel_size; F.pad instead left-pads with zeros in that case and
                    # truncates from the left (negative pad) otherwise.
                    hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                    conv_states = nn.functional.pad(
                        hidden_states_B_C_transposed,
                        (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
                    )
                    conv_states = cache_params.update_conv_state(conv_states, self.layer_idx)

                if self.activation not in ["silu", "swish"]:
                    # Generic activation: PyTorch conv, truncated to seq_len to drop the right-hand padding.
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                    )
                else:
                    # silu/swish: fused causal conv kernel applies the activation itself.
                    hidden_states_B_C = causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                        seq_idx=seq_idx,
                    ).transpose(1, 2)

                hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                    dim=-1,
                )

                # 3. SSM transformation
                scan_output, ssm_state = mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    dt,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=seq_idx,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )

                # Store the final SSM state so later decode steps (path 1) can continue from it
                if ssm_state is not None and cache_params is not None:
                    ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx)

                scan_output = scan_output.view(batch_size, seq_len, -1)
                # Multiply "gate" branch and apply extra normalization layer
                scan_output = self.norm(scan_output, gate)

                # 4. Final linear projection
                out = self.out_proj(scan_output)
        return out
  543. # fmt: off
    def torch_forward(
        self,
        input_states,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """
        Pure-PyTorch forward of the Mamba2 mixer, used when the fused CUDA kernels are not.

        Two paths are selected by `use_precomputed_states`:

        * cached single-token decoding (`cache_params` given, `cache_params.has_previous_state(self.layer_idx)`
          and `seq_len == 1`): the conv window and the SSM state held in the cache are advanced by one step;
        * otherwise: a convolution over the whole sequence followed by a chunked ("naive", einsum-free) SSD
          scan that starts from a zero SSM state. When `cache_params` is given, the conv window and the final
          SSM state are written back to it.

        Args:
            input_states: input hidden states, unpacked as `(batch_size, seq_len, _)`.
            cache_params: optional cache holding this layer's conv and recurrent (SSM) states.
            attention_mask: optional mask passed to `apply_mask_to_padding_states` before the input projection
                and again after the convolution.

        Returns:
            `self.out_proj` applied to the gated, normalized scan output (cast to the input dtype first).
        """
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # 1. Gated MLP's linear projection
        input_states = apply_mask_to_padding_states(input_states, attention_mask)
        projected_states = self.in_proj(input_states)
        # Split into the gate (consumed by self.norm at the end), the conv input (x, B, C) and per-head dt
        gate, hidden_states_B_C, dt = projected_states.split(
            [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )
        # [bsz, seq_len, conv_dim] -> [bsz, conv_dim, seq_len] (channels-first, the layout fed to self.conv1d)
        hidden_states_B_C = hidden_states_B_C.transpose(1,2)
        # The one-step recurrent update is only valid when decoding a single token on top of a filled cache
        use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
        # 2. Convolution sequence transformation
        if use_precomputed_states:
            # NOTE(review): presumably update_conv_state shifts the new token into the cached window and
            # returns that window; the conv is then a sum over the kernel axis of window * weights
            conv_states = cache_params.update_conv_state(hidden_states_B_C, self.layer_idx)
            hidden_states_B_C = torch.sum(
                conv_states * self.conv1d.weight.squeeze(1), dim=-1
            )
            if self.use_conv_bias:
                hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
            hidden_states_B_C = self.act(hidden_states_B_C)
        else:
            # Init cache
            if cache_params is not None:
                # Left-pad the time axis to exactly conv_kernel_size steps; when seq_len > conv_kernel_size the
                # pad amount is negative, which crops and keeps only the last conv_kernel_size steps
                conv_states = nn.functional.pad(
                    hidden_states_B_C, (self.conv_kernel_size - hidden_states_B_C.shape[-1], 0)
                )
                conv_states = cache_params.update_conv_state(conv_states, self.layer_idx)
            # Keep the first seq_len conv outputs (assumes self.conv1d pads so this is causal - confirm where it
            # is constructed) and go back to [bsz, seq_len, conv_dim]
            hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C)[..., :seq_len].transpose(1, 2))
        hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
        hidden_states, B, C = torch.split(
            hidden_states_B_C,
            [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
            dim=-1
        )
        # 3. SSM transformation
        A = -torch.exp(self.A_log.float()) # [num_heads]
        if use_precomputed_states:
            # We need to guarantee that anything regarding the cache is on the same device
            cache_device = cache_params.layers[self.layer_idx].recurrent_states.device
            # Note: there is no need to pad parameter matrices here, as there is just one new token
            # for batched generation
            # dt: [bsz, 1, num_heads] -> [bsz, num_heads, 1] -> broadcast to [bsz, num_heads, head_dim]
            dt = dt[:, 0, :][:, None, ...]
            dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
            dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            # [num_heads] -> [num_heads, head_dim, state_size]
            A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            # [bsz, num_heads, head_dim, state_size]
            dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
            # Discretize B
            # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
            # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
            B = B.reshape(batch_size, -1, B.shape[-1])
            # [bsz, num_heads, head_dim, state_size]
            dB = dt[..., None] * B[..., None, :]
            # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = (dB * hidden_states[..., None]).to(device=cache_device)
            # State calculation: h = h_prev * dA + dB * x; the value returned by update_recurrent_state is
            # used as the current state below
            ssm_states = cache_params.layers[self.layer_idx].recurrent_states * dA + dBx
            ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx)
            # Subsequent output
            # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
            C = C.reshape(batch_size, -1, C.shape[-1])
            # [bsz, num_heads, head_dim]
            ssm_states = ssm_states.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
            C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
            # y = h @ C for every (batch, head) pair
            y = torch.bmm(ssm_states_reshaped, C_reshaped)
            y = y.view(batch_size, self.num_heads, self.head_dim)
            # D skip connection
            # [num_heads] -> [num_heads, head_dim]
            D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
            y = (y + hidden_states * D).to(y.dtype)
            # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # begin ssd naive implementation without einsums
            # dt: [bsz, seq_len, num_heads]; bias is per head
            dt = nn.functional.softplus(dt + self.dt_bias)
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            # x: [bsz, seq_len, -1, head_dim] (the -1 presumably resolves to num_heads); computed in float32
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
            # B, C: [bsz, seq_len, n_groups, state_size], then each group repeated up to num_heads on dim 2
            B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
            C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
            # Number of positions needed to pad seq_len up to a multiple of chunk_size (0 if already aligned)
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
            # D skip term uses the undiscretized x, padded to the same length as the chunked output
            D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
            # Discretize x and A
            hidden_states = hidden_states * dt[..., None]
            A = A.to(hidden_states.dtype) * dt
            # Rearrange into blocks/chunks
            hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.permute(0, 3, 1, 2)
            A_cumsum = torch.cumsum(A, dim=-1)
            # 1. Compute the output for each intra-chunk (diagonal blocks)
            # This is the analog of a causal mask
            L = torch.exp(segment_sum(A))
            # Contraction of C and B to get G (attention-weights like)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
            G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
            # Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(dim=-1)
            # Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
            # 2. Compute the state for each intra-chunk
            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
            B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
            states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
            # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
            # (middle term of factorization of off-diag blocks; A terms)
            # This path always starts from a zero initial state (it never resumes from a cached SSM state)
            previous_states = torch.zeros_like(states[:, :1])
            states = torch.cat([previous_states, states], dim=1)
            decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
            decay_chunk = decay_chunk.transpose(1, 3)
            new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
            # Last entry is the state after the final chunk (saved to the cache below)
            states, ssm_state = new_states[:, :-1], new_states[:, -1]
            # 4. Compute state -> output conversion per chunk
            # (left term of low-rank factorization of off-diagonal blocks; C terms)
            state_decay_out = torch.exp(A_cumsum)
            C_times_states = (C[..., None, :] * states[:, :, None, ...])
            state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
            Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
            y = Y_diag + Y_off
            # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
            y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
            y = y + D_residual
            # Cutting off padded chunks
            if pad_size > 0:
                y = y[:, :seq_len, :, :]
            y = y.reshape(batch_size, seq_len, -1)
            # Init cache: store the final SSM state for subsequent single-token decoding
            if ssm_state is not None and cache_params is not None:
                ssm_state = cache_params.update_recurrent_state(ssm_state, self.layer_idx)
        # Gated normalization: self.norm combines the scan output with the gate branch
        scan_output = self.norm(y, gate)
        # end ssd naive
        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
        return contextualized_states
  698. # fmt: on
  699. def forward(
  700. self,
  701. hidden_states,
  702. cache_params: Cache | None = None,
  703. attention_mask: torch.Tensor | None = None,
  704. seq_idx: torch.IntTensor | None = None,
  705. **kwargs,
  706. ):
  707. if is_fast_path_available and "cuda" in self.in_proj.weight.device.type and not is_torchdynamo_compiling():
  708. return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask, seq_idx)
  709. if seq_idx is not None:
  710. raise NotImplementedError(
  711. "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
  712. )
  713. dtype = hidden_states.dtype
  714. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  715. # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
  716. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  717. return self.torch_forward(hidden_states, cache_params, attention_mask)
  718. class BambaMLP(nn.Module):
  719. def __init__(self, config):
  720. super().__init__()
  721. self.config = config
  722. self.hidden_size = config.hidden_size
  723. self.intermediate_size = config.intermediate_size
  724. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  725. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  726. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  727. self.act_fn = ACT2FN[config.hidden_act]
  728. def forward(self, x):
  729. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  730. return down_proj
  731. @use_kernel_forward_from_hub("RMSNorm")
  732. class BambaRMSNorm(nn.Module):
  733. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  734. """
  735. BambaRMSNorm is equivalent to T5LayerNorm
  736. """
  737. super().__init__()
  738. self.weight = nn.Parameter(torch.ones(hidden_size))
  739. self.variance_epsilon = eps
  740. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  741. input_dtype = hidden_states.dtype
  742. hidden_states = hidden_states.to(torch.float32)
  743. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  744. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  745. return self.weight * hidden_states.to(input_dtype)
  746. def extra_repr(self):
  747. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  748. class BambaDecoderLayer(GradientCheckpointingLayer):
  749. def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
  750. super().__init__()
  751. num_experts = 1
  752. ffn_layer_class = BambaMLP if num_experts == 1 else None
  753. self.feed_forward = ffn_layer_class(config)
  754. self.input_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  755. self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  756. self.layer_type = layer_type
  757. if layer_type == "mamba":
  758. self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
  759. elif layer_type == "attention":
  760. self.self_attn = BambaAttention(config, layer_idx)
  761. else:
  762. raise ValueError("Invalid layer_type")
  763. def forward(
  764. self,
  765. hidden_states: torch.Tensor,
  766. attention_mask: torch.Tensor | None = None,
  767. position_ids: torch.LongTensor | None = None,
  768. past_key_values: Cache | None = None,
  769. use_cache: bool | None = False,
  770. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  771. **kwargs: Unpack[BambaFlashAttentionKwargs],
  772. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  773. residual = hidden_states
  774. hidden_states = self.input_layernorm(hidden_states)
  775. if self.layer_type == "mamba":
  776. hidden_states = self.mamba(
  777. hidden_states=hidden_states,
  778. cache_params=past_key_values,
  779. attention_mask=attention_mask,
  780. **kwargs,
  781. )
  782. self_attn_weights = None
  783. elif self.layer_type == "attention":
  784. hidden_states, self_attn_weights = self.self_attn(
  785. hidden_states=hidden_states,
  786. attention_mask=attention_mask,
  787. position_ids=position_ids,
  788. past_key_values=past_key_values,
  789. use_cache=use_cache,
  790. position_embeddings=position_embeddings,
  791. **kwargs,
  792. )
  793. hidden_states = residual + hidden_states
  794. residual = hidden_states
  795. hidden_states = self.pre_ff_layernorm(hidden_states)
  796. hidden_states = self.feed_forward(hidden_states)
  797. hidden_states = residual + hidden_states
  798. return hidden_states, self_attn_weights
  799. @auto_docstring
  800. class BambaPreTrainedModel(PreTrainedModel):
  801. config: BambaConfig
  802. base_model_prefix = "model"
  803. supports_gradient_checkpointing = True
  804. _no_split_modules = ["BambaDecoderLayer"]
  805. _skip_keys_device_placement = "past_key_values"
  806. _supports_flash_attn = True
  807. _supports_sdpa = True
  808. _is_stateful = True
  809. _can_record_outputs = {
  810. "hidden_states": BambaDecoderLayer,
  811. "attentions": BambaAttention,
  812. }
  813. @torch.no_grad()
  814. def _init_weights(self, module):
  815. super()._init_weights(module)
  816. if isinstance(module, BambaMixer):
  817. init.ones_(module.dt_bias)
  818. init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
  819. init.ones_(module.D)
  820. @auto_docstring
  821. class BambaModel(BambaPreTrainedModel):
  822. def __init__(self, config: BambaConfig):
  823. super().__init__(config)
  824. self.padding_idx = config.pad_token_id
  825. self.vocab_size = config.vocab_size
  826. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  827. decoder_layers = []
  828. for i in range(config.num_hidden_layers):
  829. decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i]))
  830. self.layers = nn.ModuleList(decoder_layers)
  831. self._attn_implementation = config._attn_implementation
  832. self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  833. self.rotary_emb = BambaRotaryEmbedding(config=config)
  834. self.gradient_checkpointing = False
  835. # Initialize weights and apply final processing
  836. self.post_init()
  837. @merge_with_config_defaults
  838. @capture_outputs
  839. @auto_docstring
  840. def forward(
  841. self,
  842. input_ids: torch.LongTensor | None = None,
  843. attention_mask: torch.Tensor | None = None,
  844. position_ids: torch.LongTensor | None = None,
  845. past_key_values: Cache | None = None,
  846. inputs_embeds: torch.FloatTensor | None = None,
  847. use_cache: bool | None = None,
  848. **kwargs: Unpack[BambaFlashAttentionKwargs],
  849. ) -> BaseModelOutputWithPast:
  850. if (input_ids is None) ^ (inputs_embeds is not None):
  851. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  852. if inputs_embeds is None:
  853. inputs_embeds = self.embed_tokens(input_ids)
  854. hidden_states = inputs_embeds
  855. if use_cache and past_key_values is None:
  856. past_key_values = DynamicCache(config=self.config)
  857. if position_ids is None:
  858. position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
  859. causal_mask = create_causal_mask(
  860. config=self.config,
  861. inputs_embeds=inputs_embeds,
  862. attention_mask=attention_mask,
  863. past_key_values=past_key_values,
  864. position_ids=position_ids,
  865. )
  866. mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
  867. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  868. for i, decoder_layer in enumerate(self.layers):
  869. layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask
  870. hidden_states, attn_weights = decoder_layer(
  871. hidden_states,
  872. attention_mask=layer_mask,
  873. position_ids=position_ids,
  874. past_key_values=past_key_values,
  875. use_cache=use_cache,
  876. position_embeddings=position_embeddings,
  877. **kwargs,
  878. )
  879. hidden_states = self.final_layernorm(hidden_states)
  880. return BaseModelOutputWithPast(
  881. last_hidden_state=hidden_states,
  882. past_key_values=past_key_values,
  883. )
  884. def _update_mamba_mask(self, attention_mask, past_key_values):
  885. """
  886. No need for zeroing states when
  887. 1. Cached forward
  888. 2. Attending to all inputs
  889. """
  890. mamba_mask = attention_mask
  891. if (past_key_values is not None and past_key_values.has_previous_state()) or (
  892. attention_mask is not None and torch.all(attention_mask == 1)
  893. ):
  894. mamba_mask = None
  895. return mamba_mask
  896. @auto_docstring
  897. class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
  898. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  899. _tp_plan = {"lm_head": "colwise_gather_output"}
  900. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  901. def __init__(self, config):
  902. super().__init__(config)
  903. self.model = BambaModel(config)
  904. self.vocab_size = config.vocab_size
  905. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  906. self.z_loss_coefficient = config.z_loss_coefficient
  907. # Initialize weights and apply final processing
  908. self.post_init()
  909. @can_return_tuple
  910. @auto_docstring
  911. def forward(
  912. self,
  913. input_ids: torch.LongTensor | None = None,
  914. attention_mask: torch.Tensor | None = None,
  915. position_ids: torch.LongTensor | None = None,
  916. past_key_values: Cache | None = None,
  917. inputs_embeds: torch.FloatTensor | None = None,
  918. labels: torch.LongTensor | None = None,
  919. use_cache: bool | None = None,
  920. logits_to_keep: int | torch.Tensor = 0,
  921. **kwargs,
  922. ) -> CausalLMOutputWithPast:
  923. r"""
  924. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  925. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  926. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  927. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  928. Example:
  929. ```python
  930. >>> from transformers import AutoTokenizer, BambaForCausalLM
  931. >>> model = BambaForCausalLM.from_pretrained("...")
  932. >>> tokenizer = AutoTokenizer.from_pretrained("...")
  933. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  934. >>> inputs = tokenizer(prompt, return_tensors="pt")
  935. >>> # Generate
  936. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  937. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  938. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  939. ```"""
  940. outputs: BaseModelOutputWithPast = self.model(
  941. input_ids=input_ids,
  942. attention_mask=attention_mask,
  943. position_ids=position_ids,
  944. past_key_values=past_key_values,
  945. inputs_embeds=inputs_embeds,
  946. use_cache=use_cache,
  947. **kwargs,
  948. )
  949. hidden_states = outputs.last_hidden_state
  950. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  951. logits = self.lm_head(hidden_states[:, slice_indices, :])
  952. loss = None
  953. if labels is not None:
  954. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  955. if self.z_loss_coefficient > 0:
  956. z_loss = logits.logsumexp(dim=-1).to(dtype=loss.dtype).pow(2).mean()
  957. loss = loss + self.z_loss_coefficient * z_loss
  958. return CausalLMOutputWithPast(
  959. loss=loss,
  960. logits=logits,
  961. past_key_values=outputs.past_key_values,
  962. hidden_states=outputs.hidden_states,
  963. attentions=outputs.attentions,
  964. )
  965. def prepare_inputs_for_generation(
  966. self,
  967. input_ids,
  968. past_key_values=None,
  969. attention_mask=None,
  970. inputs_embeds=None,
  971. position_ids=None,
  972. use_cache=True,
  973. is_first_iteration=False,
  974. **kwargs,
  975. ):
  976. kwargs["logits_to_keep"] = self.config.num_logits_to_keep
  977. model_inputs = super().prepare_inputs_for_generation(
  978. input_ids,
  979. past_key_values=past_key_values,
  980. attention_mask=attention_mask,
  981. inputs_embeds=inputs_embeds,
  982. position_ids=position_ids,
  983. use_cache=use_cache,
  984. is_first_iteration=is_first_iteration,
  985. **kwargs,
  986. )
  987. return model_inputs
  988. __all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"]