modeling_jamba.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/jamba/modular_jamba.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_jamba.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. from collections.abc import Callable
  26. import torch
  27. from torch import nn
  28. from ... import initialization as init
  29. from ...activations import ACT2FN
  30. from ...cache_utils import Cache, DynamicCache
  31. from ...generation import GenerationMixin
  32. from ...integrations import (
  33. lazy_load_kernel,
  34. use_experts_implementation,
  35. use_kernel_forward_from_hub,
  36. use_kernel_func_from_hub,
  37. use_kernelized_func,
  38. )
  39. from ...masking_utils import create_causal_mask
  40. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  41. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  42. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  45. from ...utils.generic import merge_with_config_defaults
  46. from ...utils.import_utils import resolve_internal_import
  47. from ...utils.output_capturing import OutputRecorder, capture_outputs
  48. from .configuration_jamba import JambaConfig
# Module-level logger from the package's `logging` utilities (imported from `...utils`); used below for
# one-time (`warning_once`) messages about the availability of the fused Mamba kernels.
logger = logging.get_logger(__name__)
  50. @use_kernel_forward_from_hub("RMSNorm")
  51. class JambaRMSNorm(nn.Module):
  52. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  53. """
  54. JambaRMSNorm is equivalent to T5LayerNorm
  55. """
  56. super().__init__()
  57. self.weight = nn.Parameter(torch.ones(hidden_size))
  58. self.variance_epsilon = eps
  59. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  60. input_dtype = hidden_states.dtype
  61. hidden_states = hidden_states.to(torch.float32)
  62. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  63. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  64. return self.weight * hidden_states.to(input_dtype)
  65. def extra_repr(self):
  66. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  67. def rotate_half(x):
  68. """Rotates half the hidden dims of the input."""
  69. x1 = x[..., : x.shape[-1] // 2]
  70. x2 = x[..., x.shape[-1] // 2 :]
  71. return torch.cat((-x2, x1), dim=-1)
  72. @use_kernel_func_from_hub("rotary_pos_emb")
  73. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  74. """Applies Rotary Position Embedding to the query and key tensors.
  75. Args:
  76. q (`torch.Tensor`): The query tensor.
  77. k (`torch.Tensor`): The key tensor.
  78. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  79. sin (`torch.Tensor`): The sine part of the rotary embedding.
  80. unsqueeze_dim (`int`, *optional*, defaults to 1):
  81. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  82. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  83. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  84. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  85. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  86. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  87. Returns:
  88. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  89. """
  90. cos = cos.unsqueeze(unsqueeze_dim)
  91. sin = sin.unsqueeze(unsqueeze_dim)
  92. q_embed = (q * cos) + (rotate_half(q) * sin)
  93. k_embed = (k * cos) + (rotate_half(k) * sin)
  94. return q_embed, k_embed
  95. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  96. """
  97. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  98. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  99. """
  100. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  101. if n_rep == 1:
  102. return hidden_states
  103. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  104. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  105. def eager_attention_forward(
  106. module: nn.Module,
  107. query: torch.Tensor,
  108. key: torch.Tensor,
  109. value: torch.Tensor,
  110. attention_mask: torch.Tensor | None,
  111. scaling: float,
  112. dropout: float = 0.0,
  113. **kwargs: Unpack[TransformersKwargs],
  114. ):
  115. key_states = repeat_kv(key, module.num_key_value_groups)
  116. value_states = repeat_kv(value, module.num_key_value_groups)
  117. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  118. if attention_mask is not None:
  119. attn_weights = attn_weights + attention_mask
  120. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  121. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  122. attn_output = torch.matmul(attn_weights, value_states)
  123. attn_output = attn_output.transpose(1, 2).contiguous()
  124. return attn_output, attn_weights
  125. @use_kernelized_func(apply_rotary_pos_emb)
  126. class JambaAttention(nn.Module):
  127. """Multi-headed attention from 'Attention Is All You Need' paper"""
  128. def __init__(self, config: JambaConfig, layer_idx: int):
  129. super().__init__()
  130. self.config = config
  131. self.layer_idx = layer_idx
  132. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  133. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  134. self.scaling = self.head_dim**-0.5
  135. self.attention_dropout = config.attention_dropout
  136. self.is_causal = True
  137. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  138. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  139. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  140. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  141. def forward(
  142. self,
  143. hidden_states: torch.Tensor,
  144. attention_mask: torch.Tensor | None = None,
  145. past_key_values: Cache | None = None,
  146. **kwargs: Unpack[TransformersKwargs],
  147. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  148. input_shape = hidden_states.shape[:-1]
  149. hidden_shape = (*input_shape, -1, self.head_dim)
  150. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  151. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  152. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  153. if past_key_values is not None:
  154. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  155. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  156. self.config._attn_implementation, eager_attention_forward
  157. )
  158. attn_output, attn_weights = attention_interface(
  159. self,
  160. query_states,
  161. key_states,
  162. value_states,
  163. attention_mask,
  164. dropout=0.0 if not self.training else self.attention_dropout,
  165. scaling=self.scaling,
  166. **kwargs,
  167. )
  168. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  169. attn_output = self.o_proj(attn_output)
  170. return attn_output, attn_weights
  171. class JambaMambaMixer(nn.Module):
  172. """
  173. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  174. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  175. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  176. and is why Mamba is called **selective** state spaces)
  177. """
  178. def __init__(self, config: JambaConfig, layer_idx):
  179. super().__init__()
  180. self.config = config
  181. self.layer_idx = layer_idx
  182. self.hidden_size = config.hidden_size
  183. self.ssm_state_size = config.mamba_d_state
  184. self.conv_kernel_size = config.mamba_d_conv
  185. self.intermediate_size = config.mamba_expand * config.hidden_size
  186. self.time_step_rank = config.mamba_dt_rank
  187. self.use_conv_bias = config.mamba_conv_bias
  188. self.use_bias = config.mamba_proj_bias
  189. self.conv1d = nn.Conv1d(
  190. in_channels=self.intermediate_size,
  191. out_channels=self.intermediate_size,
  192. bias=self.use_conv_bias,
  193. kernel_size=self.conv_kernel_size,
  194. groups=self.intermediate_size,
  195. padding=self.conv_kernel_size - 1,
  196. )
  197. self.activation = config.hidden_act
  198. self.act = ACT2FN[config.hidden_act]
  199. # projection of the input hidden states
  200. self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
  201. # selective projection used to make dt, B and C input dependent
  202. self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
  203. # time step projection (discretization)
  204. self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
  205. # S4D real initialization. These are not discretized!
  206. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  207. A = torch.arange(1, self.ssm_state_size + 1)[None, :]
  208. A = A.expand(self.intermediate_size, -1).contiguous()
  209. self.A_log = nn.Parameter(torch.log(A))
  210. self.D = nn.Parameter(torch.ones(self.intermediate_size))
  211. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
  212. self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
  213. self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
  214. self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
  215. global causal_conv1d_update, causal_conv1d_fn
  216. causal_conv1d = lazy_load_kernel("causal-conv1d")
  217. causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
  218. causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
  219. global selective_state_update, mamba_inner_fn, selective_scan_fn
  220. mamba_ssm = lazy_load_kernel("mamba-ssm")
  221. selective_state_update = resolve_internal_import(
  222. mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
  223. )
  224. selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
  225. mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
  226. global is_fast_path_available
  227. is_fast_path_available = all(
  228. (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
  229. )
  230. if not is_fast_path_available:
  231. logger.warning_once(
  232. "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
  233. " is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d."
  234. )
  235. def cuda_kernels_forward(
  236. self,
  237. hidden_states: torch.Tensor,
  238. cache_params: Cache | None = None,
  239. attention_mask: torch.LongTensor | None = None,
  240. ):
  241. batch_size, seq_len, _ = hidden_states.shape
  242. use_precomputed_states = (
  243. cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
  244. )
  245. # 1. Gated MLP's linear projection
  246. projected_states = self.in_proj(hidden_states).transpose(1, 2)
  247. # We can't use `mamba_inner_fn` even if in training and without cache params because we have the
  248. # inner layernorms which isn't supported by this fused kernel
  249. hidden_states, gate = projected_states.chunk(2, dim=1)
  250. if attention_mask is not None:
  251. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  252. # 2. Convolution sequence transformation
  253. conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
  254. if use_precomputed_states:
  255. hidden_states = causal_conv1d_update(
  256. hidden_states.squeeze(-1),
  257. cache_params.layers[self.layer_idx].conv_states,
  258. conv_weights,
  259. self.conv1d.bias,
  260. self.activation,
  261. )
  262. hidden_states = hidden_states.unsqueeze(-1)
  263. else:
  264. if cache_params is not None:
  265. conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
  266. cache_params.update_conv_state(conv_states, self.layer_idx)
  267. hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
  268. if attention_mask is not None:
  269. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  270. # 3. State Space Model sequence transformation
  271. # 3.a. input varying initialization of time_step, B and C
  272. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  273. time_step, B, C = torch.split(
  274. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  275. )
  276. time_step = self.dt_layernorm(time_step)
  277. B = self.b_layernorm(B)
  278. C = self.c_layernorm(C)
  279. # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
  280. # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
  281. # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
  282. # linear layers, and requires to call the forward pass directly.
  283. # Quantized model can't work with the original code:
  284. # ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
  285. time_proj_bias = self.dt_proj.bias.data
  286. with torch.no_grad():
  287. self.dt_proj.bias.data = torch.zeros_like(self.dt_proj.bias.data)
  288. discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
  289. with torch.no_grad():
  290. self.dt_proj.bias.data = time_proj_bias
  291. A = -torch.exp(self.A_log.float())
  292. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  293. time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
  294. if use_precomputed_states:
  295. scan_outputs = selective_state_update(
  296. cache_params.layers[self.layer_idx].recurrent_states,
  297. hidden_states[..., 0],
  298. discrete_time_step[..., 0],
  299. A,
  300. B[:, 0],
  301. C[:, 0],
  302. self.D,
  303. gate[..., 0],
  304. time_proj_bias,
  305. dt_softplus=True,
  306. ).unsqueeze(-1)
  307. else:
  308. scan_outputs, ssm_state = selective_scan_fn(
  309. hidden_states,
  310. discrete_time_step,
  311. A,
  312. B.transpose(1, 2),
  313. C.transpose(1, 2),
  314. self.D.float(),
  315. gate,
  316. time_proj_bias,
  317. delta_softplus=True,
  318. return_last_state=True,
  319. )
  320. if ssm_state is not None and cache_params is not None:
  321. cache_params.update_recurrent_state(ssm_state, self.layer_idx)
  322. # 4. Final linear projection
  323. contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
  324. return contextualized_states
  325. # fmt: off
  326. def slow_forward(self, input_states, cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None):
  327. batch_size, seq_len, _ = input_states.shape
  328. dtype = input_states.dtype
  329. # 1. Gated MLP's linear projection
  330. projected_states = self.in_proj(input_states).transpose(1, 2)
  331. hidden_states, gate = projected_states.chunk(2, dim=1)
  332. if attention_mask is not None:
  333. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  334. if cache_params is not None and cache_params.has_previous_state(self.layer_idx):
  335. # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
  336. ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone()
  337. else:
  338. ssm_state = torch.zeros(
  339. (batch_size, self.intermediate_size, self.ssm_state_size),
  340. device=hidden_states.device, dtype=dtype
  341. )
  342. # 2. Convolution sequence transformation
  343. if cache_params is not None:
  344. if cache_params.has_previous_state(self.layer_idx) and seq_len == 1:
  345. conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx)
  346. hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
  347. if self.use_conv_bias:
  348. hidden_states += self.conv1d.bias
  349. hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
  350. else:
  351. conv_state = nn.functional.pad(
  352. hidden_states,
  353. (self.conv_kernel_size - hidden_states.shape[-1], 0)
  354. )
  355. conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
  356. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
  357. else:
  358. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
  359. if attention_mask is not None:
  360. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  361. # 3. State Space Model sequence transformation
  362. # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
  363. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  364. time_step, B, C = torch.split(
  365. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  366. )
  367. time_step = self.dt_layernorm(time_step)
  368. B = self.b_layernorm(B)
  369. C = self.c_layernorm(C)
  370. discrete_time_step = self.dt_proj(time_step)
  371. discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)
  372. # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
  373. A = -torch.exp(self.A_log.float())
  374. discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
  375. discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
  376. deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
  377. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  378. scan_outputs = []
  379. for i in range(seq_len):
  380. ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
  381. scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
  382. scan_outputs.append(scan_output[:, :, 0])
  383. scan_output = torch.stack(scan_outputs, dim=-1)
  384. scan_output = scan_output + (hidden_states * self.D[None, :, None])
  385. scan_output = (scan_output * self.act(gate))
  386. if cache_params is not None:
  387. cache_params.update_recurrent_state(ssm_state, self.layer_idx)
  388. # 4. Final linear projection
  389. contextualized_states = self.out_proj(scan_output.transpose(1, 2))
  390. return contextualized_states
  391. # fmt: on
  392. def forward(
  393. self,
  394. hidden_states,
  395. cache_params: Cache | None = None,
  396. attention_mask: torch.LongTensor | None = None,
  397. ):
  398. if self.config.use_mamba_kernels and (
  399. not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type
  400. ):
  401. logger.warning_once(
  402. "Fast Mamba kernels are not available. Make sure that they are installed "
  403. "and that the mamba module is on a CUDA device. Turning off the fast path "
  404. "`config.use_mamba_kernels=False` and falling back to the slow path."
  405. )
  406. self.config.use_mamba_kernels = False
  407. if self.config.use_mamba_kernels:
  408. return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
  409. return self.slow_forward(hidden_states, cache_params, attention_mask)
  410. class JambaMLP(nn.Module):
  411. def __init__(self, config):
  412. super().__init__()
  413. self.config = config
  414. self.hidden_size = config.hidden_size
  415. self.intermediate_size = config.intermediate_size
  416. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  417. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  418. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  419. self.act_fn = ACT2FN[config.hidden_act]
  420. def forward(self, x):
  421. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  422. return down_proj
  423. @use_experts_implementation
  424. class JambaExperts(nn.Module):
  425. """Collection of expert weights stored as 3D tensors."""
  426. def __init__(self, config: JambaConfig):
  427. super().__init__()
  428. self.num_experts = config.num_local_experts
  429. self.hidden_dim = config.hidden_size
  430. self.intermediate_dim = config.intermediate_size
  431. self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  432. self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  433. self.act_fn = ACT2FN[config.hidden_act]
  434. def forward(
  435. self,
  436. hidden_states: torch.Tensor,
  437. top_k_index: torch.Tensor,
  438. top_k_weights: torch.Tensor,
  439. ) -> torch.Tensor:
  440. final_hidden_states = torch.zeros_like(hidden_states)
  441. with torch.no_grad():
  442. expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
  443. expert_mask = expert_mask.permute(2, 1, 0)
  444. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  445. for expert_idx in expert_hit:
  446. expert_idx = expert_idx[0]
  447. if expert_idx == self.num_experts:
  448. continue
  449. top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
  450. current_state = hidden_states[token_idx]
  451. gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
  452. current_hidden_states = self.act_fn(gate) * up
  453. current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
  454. current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
  455. final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
  456. return final_hidden_states
  457. class JambaSparseMoeBlock(nn.Module):
  458. """
  459. This implementation is
  460. strictly equivalent to standard MoE with full capacity (no
  461. dropped tokens). It's faster since it formulates MoE operations
  462. in terms of block-sparse operations to accommodate imbalanced
  463. assignments of tokens to experts, whereas standard MoE either
  464. (1) drop tokens at the cost of reduced performance or (2) set
  465. capacity factor to number of experts and thus waste computation
  466. and memory on padding.
  467. """
  468. def __init__(self, config: JambaConfig):
  469. super().__init__()
  470. self.hidden_dim = config.hidden_size
  471. self.ffn_dim = config.intermediate_size
  472. self.num_experts = config.num_experts
  473. self.top_k = config.num_experts_per_tok
  474. self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
  475. self.experts = JambaExperts(config)
  476. def route_tokens_to_experts(self, hidden_states, router_logits):
  477. routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
  478. top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
  479. return top_k_index, top_k_weights.to(hidden_states.dtype)
  480. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  481. batch_size, sequence_length, hidden_dim = hidden_states.shape
  482. hidden_states = hidden_states.view(-1, hidden_dim)
  483. router_logits = self.router(hidden_states)
  484. top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
  485. hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
  486. hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  487. return hidden_states
  488. class JambaAttentionDecoderLayer(GradientCheckpointingLayer):
  489. def __init__(self, config: JambaConfig, layer_idx: int):
  490. super().__init__()
  491. num_experts = config.layers_num_experts[layer_idx] if config.layers_num_experts else 1
  492. self.self_attn = JambaAttention(config, layer_idx)
  493. ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
  494. self.feed_forward = ffn_layer_class(config)
  495. self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  496. self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  497. def forward(
  498. self,
  499. hidden_states: torch.Tensor,
  500. attention_mask: torch.Tensor | None = None,
  501. position_ids: torch.LongTensor | None = None,
  502. past_key_values: Cache | None = None,
  503. use_cache: bool | None = False,
  504. **kwargs: Unpack[TransformersKwargs],
  505. ) -> torch.FloatTensor:
  506. residual = hidden_states
  507. hidden_states = self.input_layernorm(hidden_states)
  508. hidden_states, _ = self.self_attn(
  509. hidden_states=hidden_states,
  510. attention_mask=attention_mask,
  511. position_ids=position_ids,
  512. past_key_values=past_key_values,
  513. use_cache=use_cache,
  514. **kwargs,
  515. )
  516. hidden_states = residual + hidden_states
  517. residual = hidden_states
  518. hidden_states = self.pre_ff_layernorm(hidden_states)
  519. hidden_states = self.feed_forward(hidden_states)
  520. hidden_states = residual + hidden_states
  521. return hidden_states
  522. class JambaMambaDecoderLayer(GradientCheckpointingLayer):
  523. def __init__(self, config: JambaConfig, layer_idx: int):
  524. super().__init__()
  525. num_experts = config.layers_num_experts[layer_idx] if config.layers_num_experts else 1
  526. self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)
  527. ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
  528. self.feed_forward = ffn_layer_class(config)
  529. self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  530. self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  531. def forward(
  532. self,
  533. hidden_states: torch.Tensor,
  534. attention_mask: torch.Tensor | None = None,
  535. position_ids: torch.LongTensor | None = None,
  536. past_key_values: Cache | None = None,
  537. **kwargs: Unpack[TransformersKwargs],
  538. ) -> torch.FloatTensor:
  539. residual = hidden_states
  540. hidden_states = self.input_layernorm(hidden_states)
  541. hidden_states = self.mamba(
  542. hidden_states=hidden_states,
  543. cache_params=past_key_values,
  544. attention_mask=attention_mask,
  545. )
  546. hidden_states = residual + hidden_states
  547. residual = hidden_states
  548. hidden_states = self.pre_ff_layernorm(hidden_states)
  549. hidden_states = self.feed_forward(hidden_states)
  550. hidden_states = residual + hidden_states
  551. return hidden_states
  552. class JambaPreTrainedModel(PreTrainedModel):
  553. config: JambaConfig
  554. base_model_prefix = "model"
  555. supports_gradient_checkpointing = True
  556. _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
  557. _skip_keys_device_placement = "past_key_values"
  558. _supports_flash_attn = True
  559. _supports_sdpa = True
  560. _is_stateful = True
  561. _can_record_outputs = {
  562. "hidden_states": [JambaAttentionDecoderLayer, JambaMambaDecoderLayer],
  563. "attentions": JambaAttention,
  564. "router_logits": OutputRecorder(nn.Linear, layer_name="router"),
  565. }
  566. @torch.no_grad()
  567. def _init_weights(self, module):
  568. super()._init_weights(module)
  569. if isinstance(module, JambaMambaMixer):
  570. A = torch.arange(1, module.ssm_state_size + 1)[None, :]
  571. A = A.expand(module.intermediate_size, -1).contiguous()
  572. init.copy_(module.A_log, torch.log(A))
  573. init.ones_(module.D)
  574. elif isinstance(module, JambaExperts):
  575. init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  576. init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
  577. ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
  578. @auto_docstring
  579. class JambaModel(JambaPreTrainedModel):
  580. def __init__(self, config: JambaConfig):
  581. super().__init__(config)
  582. self.padding_idx = config.pad_token_id
  583. self.vocab_size = config.vocab_size
  584. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  585. decoder_layers = []
  586. for i in range(config.num_hidden_layers):
  587. layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
  588. decoder_layers.append(layer_class(config, layer_idx=i))
  589. self.layers = nn.ModuleList(decoder_layers)
  590. self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  591. self.gradient_checkpointing = False
  592. # Initialize weights and apply final processing
  593. self.post_init()
  594. @merge_with_config_defaults
  595. @capture_outputs
  596. @auto_docstring
  597. def forward(
  598. self,
  599. input_ids: torch.LongTensor | None = None,
  600. attention_mask: torch.Tensor | None = None,
  601. position_ids: torch.LongTensor | None = None,
  602. past_key_values: Cache | None = None,
  603. inputs_embeds: torch.FloatTensor | None = None,
  604. use_cache: bool | None = None,
  605. **kwargs: Unpack[TransformersKwargs],
  606. ) -> MoeModelOutputWithPast:
  607. if (input_ids is None) ^ (inputs_embeds is not None):
  608. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  609. if inputs_embeds is None:
  610. inputs_embeds = self.embed_tokens(input_ids)
  611. if use_cache and past_key_values is None:
  612. past_key_values = DynamicCache(config=self.config)
  613. if position_ids is None:
  614. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  615. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  616. position_ids = position_ids.unsqueeze(0)
  617. causal_mask = create_causal_mask(
  618. config=self.config,
  619. inputs_embeds=inputs_embeds,
  620. attention_mask=attention_mask,
  621. past_key_values=past_key_values,
  622. position_ids=position_ids,
  623. )
  624. mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
  625. hidden_states = inputs_embeds
  626. for decoder_layer in self.layers:
  627. layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask
  628. hidden_states = decoder_layer(
  629. hidden_states,
  630. attention_mask=layer_mask,
  631. position_ids=position_ids,
  632. past_key_values=past_key_values,
  633. use_cache=use_cache,
  634. **kwargs,
  635. )
  636. hidden_states = self.final_layernorm(hidden_states)
  637. return MoeModelOutputWithPast(
  638. last_hidden_state=hidden_states,
  639. past_key_values=past_key_values,
  640. )
  641. def _update_mamba_mask(self, attention_mask, past_key_values):
  642. """
  643. No need for zeroing states when
  644. 1. Cached forward
  645. 2. Attending to all inputs
  646. """
  647. mamba_mask = attention_mask
  648. if (past_key_values is not None and past_key_values.has_previous_state()) or (
  649. attention_mask is not None and torch.all(attention_mask == 1)
  650. ):
  651. mamba_mask = None
  652. return mamba_mask
  653. def load_balancing_loss_func(
  654. gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
  655. num_experts: int | None = None,
  656. top_k=2,
  657. attention_mask: torch.Tensor | None = None,
  658. ) -> torch.Tensor | int:
  659. r"""
  660. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  661. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  662. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  663. experts is too unbalanced.
  664. Args:
  665. gate_logits:
  666. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  667. shape [batch_size X sequence_length, num_experts].
  668. num_experts:
  669. Number of experts
  670. top_k:
  671. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  672. parameter.
  673. attention_mask (`torch.Tensor`, *optional*):
  674. The attention_mask used in forward function
  675. shape [batch_size X sequence_length] if not None.
  676. Returns:
  677. The auxiliary loss.
  678. """
  679. if gate_logits is None or not isinstance(gate_logits, tuple):
  680. return 0
  681. if isinstance(gate_logits, tuple):
  682. compute_device = gate_logits[0].device
  683. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  684. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  685. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  686. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  687. if attention_mask is None:
  688. # Compute the percentage of tokens routed to each experts
  689. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  690. # Compute the average probability of routing to these experts
  691. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  692. else:
  693. batch_size, sequence_length = attention_mask.shape
  694. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  695. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  696. expert_attention_mask = (
  697. attention_mask[None, :, :, None, None]
  698. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  699. .reshape(-1, top_k, num_experts)
  700. .to(compute_device)
  701. )
  702. # Compute the percentage of tokens routed to each experts
  703. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  704. expert_attention_mask, dim=0
  705. )
  706. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  707. router_per_expert_attention_mask = (
  708. attention_mask[None, :, :, None]
  709. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  710. .reshape(-1, num_experts)
  711. .to(compute_device)
  712. )
  713. # Compute the average probability of routing to these experts
  714. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  715. router_per_expert_attention_mask, dim=0
  716. )
  717. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  718. return overall_loss * num_experts
  719. @auto_docstring
  720. class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
  721. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  722. _tp_plan = {"lm_head": "colwise_gather_output"}
  723. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  724. def __init__(self, config: JambaConfig):
  725. super().__init__(config)
  726. self.model = JambaModel(config)
  727. self.vocab_size = config.vocab_size
  728. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  729. self.router_aux_loss_coef = config.router_aux_loss_coef
  730. self.num_experts = config.num_experts
  731. self.num_experts_per_tok = config.num_experts_per_tok
  732. # Initialize weights and apply final processing
  733. self.post_init()
  734. @can_return_tuple
  735. @auto_docstring
  736. def forward(
  737. self,
  738. input_ids: torch.LongTensor | None = None,
  739. attention_mask: torch.Tensor | None = None,
  740. position_ids: torch.LongTensor | None = None,
  741. past_key_values: Cache | None = None,
  742. inputs_embeds: torch.FloatTensor | None = None,
  743. labels: torch.LongTensor | None = None,
  744. use_cache: bool | None = None,
  745. output_router_logits: bool | None = None,
  746. logits_to_keep: int | torch.Tensor = 0,
  747. **kwargs: Unpack[TransformersKwargs],
  748. ) -> MoeCausalLMOutputWithPast:
  749. r"""
  750. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  751. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  752. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  753. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  754. Example:
  755. ```python
  756. >>> from transformers import AutoTokenizer, JambaForCausalLM
  757. >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
  758. >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
  759. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  760. >>> inputs = tokenizer(prompt, return_tensors="pt")
  761. >>> # Generate
  762. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  763. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  764. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  765. ```"""
  766. output_router_logits = (
  767. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  768. )
  769. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  770. outputs: MoeModelOutputWithPast = self.model(
  771. input_ids=input_ids,
  772. attention_mask=attention_mask,
  773. position_ids=position_ids,
  774. past_key_values=past_key_values,
  775. inputs_embeds=inputs_embeds,
  776. use_cache=use_cache,
  777. output_router_logits=output_router_logits,
  778. **kwargs,
  779. )
  780. hidden_states = outputs.last_hidden_state
  781. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  782. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  783. logits = self.lm_head(hidden_states[:, slice_indices, :])
  784. loss = None
  785. if labels is not None:
  786. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  787. aux_loss = None
  788. if output_router_logits:
  789. aux_loss = load_balancing_loss_func(
  790. outputs.router_logits,
  791. self.num_experts,
  792. self.num_experts_per_tok,
  793. attention_mask,
  794. )
  795. if labels is not None:
  796. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  797. return MoeCausalLMOutputWithPast(
  798. loss=loss,
  799. aux_loss=aux_loss,
  800. logits=logits,
  801. past_key_values=outputs.past_key_values,
  802. hidden_states=outputs.hidden_states,
  803. attentions=outputs.attentions,
  804. router_logits=outputs.router_logits,
  805. )
  806. class JambaForSequenceClassification(GenericForSequenceClassification, JambaPreTrainedModel):
  807. pass
  808. __all__ = ["JambaForCausalLM", "JambaForSequenceClassification", "JambaModel", "JambaPreTrainedModel"]