modular_jamba.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. from collections.abc import Callable
  20. import torch
  21. from torch import nn
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache
  25. from ...integrations import lazy_load_kernel
  26. from ...masking_utils import create_causal_mask
  27. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  28. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, logging
  32. from ...utils.generic import merge_with_config_defaults
  33. from ...utils.import_utils import resolve_internal_import
  34. from ...utils.output_capturing import OutputRecorder, capture_outputs
  35. from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, eager_attention_forward
  36. from ..mistral.modeling_mistral import MistralMLP
  37. from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM
  38. from .configuration_jamba import JambaConfig
# Module-level logger (used for the one-time fast-path warnings emitted by `JambaMambaMixer`).
logger = logging.get_logger(__name__)
  40. class JambaRMSNorm(LlamaRMSNorm):
  41. pass
  42. class JambaAttention(LlamaAttention):
  43. def __init__(self, config: JambaConfig, layer_idx: int):
  44. super().__init__(config, layer_idx)
  45. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  46. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  47. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  48. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  49. def forward(
  50. self,
  51. hidden_states: torch.Tensor,
  52. attention_mask: torch.Tensor | None = None,
  53. past_key_values: Cache | None = None,
  54. **kwargs: Unpack[TransformersKwargs],
  55. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  56. input_shape = hidden_states.shape[:-1]
  57. hidden_shape = (*input_shape, -1, self.head_dim)
  58. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  59. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  60. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  61. if past_key_values is not None:
  62. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  63. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  64. self.config._attn_implementation, eager_attention_forward
  65. )
  66. attn_output, attn_weights = attention_interface(
  67. self,
  68. query_states,
  69. key_states,
  70. value_states,
  71. attention_mask,
  72. dropout=0.0 if not self.training else self.attention_dropout,
  73. scaling=self.scaling,
  74. **kwargs,
  75. )
  76. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  77. attn_output = self.o_proj(attn_output)
  78. return attn_output, attn_weights
  79. class JambaMambaMixer(nn.Module):
  80. """
  81. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  82. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  83. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  84. and is why Mamba is called **selective** state spaces)
  85. """
  86. def __init__(self, config: JambaConfig, layer_idx):
  87. super().__init__()
  88. self.config = config
  89. self.layer_idx = layer_idx
  90. self.hidden_size = config.hidden_size
  91. self.ssm_state_size = config.mamba_d_state
  92. self.conv_kernel_size = config.mamba_d_conv
  93. self.intermediate_size = config.mamba_expand * config.hidden_size
  94. self.time_step_rank = config.mamba_dt_rank
  95. self.use_conv_bias = config.mamba_conv_bias
  96. self.use_bias = config.mamba_proj_bias
  97. self.conv1d = nn.Conv1d(
  98. in_channels=self.intermediate_size,
  99. out_channels=self.intermediate_size,
  100. bias=self.use_conv_bias,
  101. kernel_size=self.conv_kernel_size,
  102. groups=self.intermediate_size,
  103. padding=self.conv_kernel_size - 1,
  104. )
  105. self.activation = config.hidden_act
  106. self.act = ACT2FN[config.hidden_act]
  107. # projection of the input hidden states
  108. self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
  109. # selective projection used to make dt, B and C input dependent
  110. self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
  111. # time step projection (discretization)
  112. self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
  113. # S4D real initialization. These are not discretized!
  114. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  115. A = torch.arange(1, self.ssm_state_size + 1)[None, :]
  116. A = A.expand(self.intermediate_size, -1).contiguous()
  117. self.A_log = nn.Parameter(torch.log(A))
  118. self.D = nn.Parameter(torch.ones(self.intermediate_size))
  119. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
  120. self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
  121. self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
  122. self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
  123. global causal_conv1d_update, causal_conv1d_fn
  124. causal_conv1d = lazy_load_kernel("causal-conv1d")
  125. causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
  126. causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
  127. global selective_state_update, mamba_inner_fn, selective_scan_fn
  128. mamba_ssm = lazy_load_kernel("mamba-ssm")
  129. selective_state_update = resolve_internal_import(
  130. mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
  131. )
  132. selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
  133. mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)
  134. global is_fast_path_available
  135. is_fast_path_available = all(
  136. (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
  137. )
  138. if not is_fast_path_available:
  139. logger.warning_once(
  140. "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
  141. " is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d."
  142. )
  143. def cuda_kernels_forward(
  144. self,
  145. hidden_states: torch.Tensor,
  146. cache_params: Cache | None = None,
  147. attention_mask: torch.LongTensor | None = None,
  148. ):
  149. batch_size, seq_len, _ = hidden_states.shape
  150. use_precomputed_states = (
  151. cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
  152. )
  153. # 1. Gated MLP's linear projection
  154. projected_states = self.in_proj(hidden_states).transpose(1, 2)
  155. # We can't use `mamba_inner_fn` even if in training and without cache params because we have the
  156. # inner layernorms which isn't supported by this fused kernel
  157. hidden_states, gate = projected_states.chunk(2, dim=1)
  158. if attention_mask is not None:
  159. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  160. # 2. Convolution sequence transformation
  161. conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
  162. if use_precomputed_states:
  163. hidden_states = causal_conv1d_update(
  164. hidden_states.squeeze(-1),
  165. cache_params.layers[self.layer_idx].conv_states,
  166. conv_weights,
  167. self.conv1d.bias,
  168. self.activation,
  169. )
  170. hidden_states = hidden_states.unsqueeze(-1)
  171. else:
  172. if cache_params is not None:
  173. conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
  174. cache_params.update_conv_state(conv_states, self.layer_idx)
  175. hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
  176. if attention_mask is not None:
  177. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  178. # 3. State Space Model sequence transformation
  179. # 3.a. input varying initialization of time_step, B and C
  180. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  181. time_step, B, C = torch.split(
  182. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  183. )
  184. time_step = self.dt_layernorm(time_step)
  185. B = self.b_layernorm(B)
  186. C = self.c_layernorm(C)
  187. # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
  188. # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
  189. # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
  190. # linear layers, and requires to call the forward pass directly.
  191. # Quantized model can't work with the original code:
  192. # ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
  193. time_proj_bias = self.dt_proj.bias.data
  194. with torch.no_grad():
  195. self.dt_proj.bias.data = torch.zeros_like(self.dt_proj.bias.data)
  196. discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
  197. with torch.no_grad():
  198. self.dt_proj.bias.data = time_proj_bias
  199. A = -torch.exp(self.A_log.float())
  200. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  201. time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
  202. if use_precomputed_states:
  203. scan_outputs = selective_state_update(
  204. cache_params.layers[self.layer_idx].recurrent_states,
  205. hidden_states[..., 0],
  206. discrete_time_step[..., 0],
  207. A,
  208. B[:, 0],
  209. C[:, 0],
  210. self.D,
  211. gate[..., 0],
  212. time_proj_bias,
  213. dt_softplus=True,
  214. ).unsqueeze(-1)
  215. else:
  216. scan_outputs, ssm_state = selective_scan_fn(
  217. hidden_states,
  218. discrete_time_step,
  219. A,
  220. B.transpose(1, 2),
  221. C.transpose(1, 2),
  222. self.D.float(),
  223. gate,
  224. time_proj_bias,
  225. delta_softplus=True,
  226. return_last_state=True,
  227. )
  228. if ssm_state is not None and cache_params is not None:
  229. cache_params.update_recurrent_state(ssm_state, self.layer_idx)
  230. # 4. Final linear projection
  231. contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
  232. return contextualized_states
  233. # fmt: off
  234. def slow_forward(self, input_states, cache_params: Cache | None = None, attention_mask: torch.LongTensor | None = None):
  235. batch_size, seq_len, _ = input_states.shape
  236. dtype = input_states.dtype
  237. # 1. Gated MLP's linear projection
  238. projected_states = self.in_proj(input_states).transpose(1, 2)
  239. hidden_states, gate = projected_states.chunk(2, dim=1)
  240. if attention_mask is not None:
  241. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  242. if cache_params is not None and cache_params.has_previous_state(self.layer_idx):
  243. # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
  244. ssm_state = cache_params.layers[self.layer_idx].recurrent_states.clone()
  245. else:
  246. ssm_state = torch.zeros(
  247. (batch_size, self.intermediate_size, self.ssm_state_size),
  248. device=hidden_states.device, dtype=dtype
  249. )
  250. # 2. Convolution sequence transformation
  251. if cache_params is not None:
  252. if cache_params.has_previous_state(self.layer_idx) and seq_len == 1:
  253. conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx)
  254. hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
  255. if self.use_conv_bias:
  256. hidden_states += self.conv1d.bias
  257. hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
  258. else:
  259. conv_state = nn.functional.pad(
  260. hidden_states,
  261. (self.conv_kernel_size - hidden_states.shape[-1], 0)
  262. )
  263. conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
  264. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
  265. else:
  266. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
  267. if attention_mask is not None:
  268. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  269. # 3. State Space Model sequence transformation
  270. # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
  271. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  272. time_step, B, C = torch.split(
  273. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  274. )
  275. time_step = self.dt_layernorm(time_step)
  276. B = self.b_layernorm(B)
  277. C = self.c_layernorm(C)
  278. discrete_time_step = self.dt_proj(time_step)
  279. discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)
  280. # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
  281. A = -torch.exp(self.A_log.float())
  282. discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
  283. discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
  284. deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
  285. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  286. scan_outputs = []
  287. for i in range(seq_len):
  288. ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
  289. scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
  290. scan_outputs.append(scan_output[:, :, 0])
  291. scan_output = torch.stack(scan_outputs, dim=-1)
  292. scan_output = scan_output + (hidden_states * self.D[None, :, None])
  293. scan_output = (scan_output * self.act(gate))
  294. if cache_params is not None:
  295. cache_params.update_recurrent_state(ssm_state, self.layer_idx)
  296. # 4. Final linear projection
  297. contextualized_states = self.out_proj(scan_output.transpose(1, 2))
  298. return contextualized_states
  299. # fmt: on
  300. def forward(
  301. self,
  302. hidden_states,
  303. cache_params: Cache | None = None,
  304. attention_mask: torch.LongTensor | None = None,
  305. ):
  306. if self.config.use_mamba_kernels and (
  307. not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type
  308. ):
  309. logger.warning_once(
  310. "Fast Mamba kernels are not available. Make sure that they are installed "
  311. "and that the mamba module is on a CUDA device. Turning off the fast path "
  312. "`config.use_mamba_kernels=False` and falling back to the slow path."
  313. )
  314. self.config.use_mamba_kernels = False
  315. if self.config.use_mamba_kernels:
  316. return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
  317. return self.slow_forward(hidden_states, cache_params, attention_mask)
  318. class JambaMLP(MistralMLP):
  319. pass
  320. class JambaExperts(MixtralExperts):
  321. pass
  322. class JambaSparseMoeBlock(nn.Module):
  323. """
  324. This implementation is
  325. strictly equivalent to standard MoE with full capacity (no
  326. dropped tokens). It's faster since it formulates MoE operations
  327. in terms of block-sparse operations to accommodate imbalanced
  328. assignments of tokens to experts, whereas standard MoE either
  329. (1) drop tokens at the cost of reduced performance or (2) set
  330. capacity factor to number of experts and thus waste computation
  331. and memory on padding.
  332. """
  333. def __init__(self, config: JambaConfig):
  334. super().__init__()
  335. self.hidden_dim = config.hidden_size
  336. self.ffn_dim = config.intermediate_size
  337. self.num_experts = config.num_experts
  338. self.top_k = config.num_experts_per_tok
  339. self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
  340. self.experts = JambaExperts(config)
  341. def route_tokens_to_experts(self, hidden_states, router_logits):
  342. routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
  343. top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
  344. return top_k_index, top_k_weights.to(hidden_states.dtype)
  345. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  346. batch_size, sequence_length, hidden_dim = hidden_states.shape
  347. hidden_states = hidden_states.view(-1, hidden_dim)
  348. router_logits = self.router(hidden_states)
  349. top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
  350. hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
  351. hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  352. return hidden_states
  353. class JambaAttentionDecoderLayer(GradientCheckpointingLayer):
  354. def __init__(self, config: JambaConfig, layer_idx: int):
  355. super().__init__()
  356. num_experts = config.layers_num_experts[layer_idx] if config.layers_num_experts else 1
  357. self.self_attn = JambaAttention(config, layer_idx)
  358. ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
  359. self.feed_forward = ffn_layer_class(config)
  360. self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  361. self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  362. def forward(
  363. self,
  364. hidden_states: torch.Tensor,
  365. attention_mask: torch.Tensor | None = None,
  366. position_ids: torch.LongTensor | None = None,
  367. past_key_values: Cache | None = None,
  368. use_cache: bool | None = False,
  369. **kwargs: Unpack[TransformersKwargs],
  370. ) -> torch.FloatTensor:
  371. residual = hidden_states
  372. hidden_states = self.input_layernorm(hidden_states)
  373. hidden_states, _ = self.self_attn(
  374. hidden_states=hidden_states,
  375. attention_mask=attention_mask,
  376. position_ids=position_ids,
  377. past_key_values=past_key_values,
  378. use_cache=use_cache,
  379. **kwargs,
  380. )
  381. hidden_states = residual + hidden_states
  382. residual = hidden_states
  383. hidden_states = self.pre_ff_layernorm(hidden_states)
  384. hidden_states = self.feed_forward(hidden_states)
  385. hidden_states = residual + hidden_states
  386. return hidden_states
  387. class JambaMambaDecoderLayer(GradientCheckpointingLayer):
  388. def __init__(self, config: JambaConfig, layer_idx: int):
  389. super().__init__()
  390. num_experts = config.layers_num_experts[layer_idx] if config.layers_num_experts else 1
  391. self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)
  392. ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
  393. self.feed_forward = ffn_layer_class(config)
  394. self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  395. self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  396. def forward(
  397. self,
  398. hidden_states: torch.Tensor,
  399. attention_mask: torch.Tensor | None = None,
  400. position_ids: torch.LongTensor | None = None,
  401. past_key_values: Cache | None = None,
  402. **kwargs: Unpack[TransformersKwargs],
  403. ) -> torch.FloatTensor:
  404. residual = hidden_states
  405. hidden_states = self.input_layernorm(hidden_states)
  406. hidden_states = self.mamba(
  407. hidden_states=hidden_states,
  408. cache_params=past_key_values,
  409. attention_mask=attention_mask,
  410. )
  411. hidden_states = residual + hidden_states
  412. residual = hidden_states
  413. hidden_states = self.pre_ff_layernorm(hidden_states)
  414. hidden_states = self.feed_forward(hidden_states)
  415. hidden_states = residual + hidden_states
  416. return hidden_states
  417. ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
  418. class JambaPreTrainedModel(PreTrainedModel):
  419. config: JambaConfig
  420. base_model_prefix = "model"
  421. supports_gradient_checkpointing = True
  422. _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
  423. _skip_keys_device_placement = "past_key_values"
  424. _supports_flash_attn = True
  425. _supports_sdpa = True
  426. _is_stateful = True
  427. _can_record_outputs = {
  428. "hidden_states": [JambaAttentionDecoderLayer, JambaMambaDecoderLayer],
  429. "attentions": JambaAttention,
  430. "router_logits": OutputRecorder(nn.Linear, layer_name="router"),
  431. }
  432. @torch.no_grad()
  433. def _init_weights(self, module):
  434. super()._init_weights(module)
  435. if isinstance(module, JambaMambaMixer):
  436. A = torch.arange(1, module.ssm_state_size + 1)[None, :]
  437. A = A.expand(module.intermediate_size, -1).contiguous()
  438. init.copy_(module.A_log, torch.log(A))
  439. init.ones_(module.D)
  440. elif isinstance(module, JambaExperts):
  441. init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  442. init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
  443. @auto_docstring
  444. class JambaModel(JambaPreTrainedModel):
  445. def __init__(self, config: JambaConfig):
  446. super().__init__(config)
  447. self.padding_idx = config.pad_token_id
  448. self.vocab_size = config.vocab_size
  449. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  450. decoder_layers = []
  451. for i in range(config.num_hidden_layers):
  452. layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
  453. decoder_layers.append(layer_class(config, layer_idx=i))
  454. self.layers = nn.ModuleList(decoder_layers)
  455. self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  456. self.gradient_checkpointing = False
  457. # Initialize weights and apply final processing
  458. self.post_init()
  459. @merge_with_config_defaults
  460. @capture_outputs
  461. @auto_docstring
  462. def forward(
  463. self,
  464. input_ids: torch.LongTensor | None = None,
  465. attention_mask: torch.Tensor | None = None,
  466. position_ids: torch.LongTensor | None = None,
  467. past_key_values: Cache | None = None,
  468. inputs_embeds: torch.FloatTensor | None = None,
  469. use_cache: bool | None = None,
  470. **kwargs: Unpack[TransformersKwargs],
  471. ) -> MoeModelOutputWithPast:
  472. if (input_ids is None) ^ (inputs_embeds is not None):
  473. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  474. if inputs_embeds is None:
  475. inputs_embeds = self.embed_tokens(input_ids)
  476. if use_cache and past_key_values is None:
  477. past_key_values = DynamicCache(config=self.config)
  478. if position_ids is None:
  479. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  480. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  481. position_ids = position_ids.unsqueeze(0)
  482. causal_mask = create_causal_mask(
  483. config=self.config,
  484. inputs_embeds=inputs_embeds,
  485. attention_mask=attention_mask,
  486. past_key_values=past_key_values,
  487. position_ids=position_ids,
  488. )
  489. mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
  490. hidden_states = inputs_embeds
  491. for decoder_layer in self.layers:
  492. layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask
  493. hidden_states = decoder_layer(
  494. hidden_states,
  495. attention_mask=layer_mask,
  496. position_ids=position_ids,
  497. past_key_values=past_key_values,
  498. use_cache=use_cache,
  499. **kwargs,
  500. )
  501. hidden_states = self.final_layernorm(hidden_states)
  502. return MoeModelOutputWithPast(
  503. last_hidden_state=hidden_states,
  504. past_key_values=past_key_values,
  505. )
  506. def _update_mamba_mask(self, attention_mask, past_key_values):
  507. """
  508. No need for zeroing states when
  509. 1. Cached forward
  510. 2. Attending to all inputs
  511. """
  512. mamba_mask = attention_mask
  513. if (past_key_values is not None and past_key_values.has_previous_state()) or (
  514. attention_mask is not None and torch.all(attention_mask == 1)
  515. ):
  516. mamba_mask = None
  517. return mamba_mask
  518. class JambaForCausalLM(MixtralForCausalLM):
  519. def __init__(self, config: JambaConfig):
  520. super().__init__(config)
  521. self.num_experts = config.num_experts
  522. def forward(
  523. self,
  524. input_ids: torch.LongTensor | None = None,
  525. attention_mask: torch.Tensor | None = None,
  526. position_ids: torch.LongTensor | None = None,
  527. past_key_values: Cache | None = None,
  528. inputs_embeds: torch.FloatTensor | None = None,
  529. labels: torch.LongTensor | None = None,
  530. use_cache: bool | None = None,
  531. output_router_logits: bool | None = None,
  532. logits_to_keep: int | torch.Tensor = 0,
  533. **kwargs: Unpack[TransformersKwargs],
  534. ) -> MoeCausalLMOutputWithPast:
  535. r"""
  536. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  537. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  538. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  539. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  540. Example:
  541. ```python
  542. >>> from transformers import AutoTokenizer, JambaForCausalLM
  543. >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
  544. >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
  545. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  546. >>> inputs = tokenizer(prompt, return_tensors="pt")
  547. >>> # Generate
  548. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  549. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  550. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  551. ```"""
  552. return super().forward(
  553. input_ids,
  554. attention_mask,
  555. position_ids,
  556. past_key_values,
  557. inputs_embeds,
  558. labels,
  559. use_cache,
  560. logits_to_keep,
  561. **kwargs,
  562. )
  563. class JambaForSequenceClassification(GenericForSequenceClassification, JambaPreTrainedModel):
  564. pass
# Public API exported by this module.
__all__ = ["JambaForCausalLM", "JambaForSequenceClassification", "JambaModel", "JambaPreTrainedModel"]