# modeling_mistral4.py (31 KB)
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/mistral4/modular_mistral4.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_mistral4.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub
  30. from ...masking_utils import create_causal_mask
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import (
  33. GenericForSequenceClassification,
  34. GenericForTokenClassification,
  35. GradientCheckpointingLayer,
  36. )
  37. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  38. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  39. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  40. from ...processing_utils import Unpack
  41. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  42. from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
  43. from ...utils.output_capturing import capture_outputs
  44. from .configuration_mistral4 import Mistral4Config
  45. @use_kernel_forward_from_hub("RMSNorm")
  46. class Mistral4RMSNorm(nn.Module):
  47. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  48. """
  49. Mistral4RMSNorm is equivalent to T5LayerNorm
  50. """
  51. super().__init__()
  52. self.weight = nn.Parameter(torch.ones(hidden_size))
  53. self.variance_epsilon = eps
  54. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  55. input_dtype = hidden_states.dtype
  56. hidden_states = hidden_states.to(torch.float32)
  57. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  58. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  59. return self.weight * hidden_states.to(input_dtype)
  60. def extra_repr(self):
  61. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  62. class Mistral4RotaryEmbedding(nn.Module):
  63. inv_freq: torch.Tensor # fix linting for `register_buffer`
  64. def __init__(self, config: Mistral4Config, device=None):
  65. super().__init__()
  66. self.max_seq_len_cached = config.max_position_embeddings
  67. self.original_max_seq_len = config.max_position_embeddings
  68. self.config = config
  69. self.rope_type = self.config.rope_parameters["rope_type"]
  70. rope_init_fn: Callable = self.compute_default_rope_parameters
  71. if self.rope_type != "default":
  72. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  73. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  74. self.register_buffer("inv_freq", inv_freq, persistent=False)
  75. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  76. @staticmethod
  77. def compute_default_rope_parameters(
  78. config: Mistral4Config | None = None,
  79. device: Optional["torch.device"] = None,
  80. seq_len: int | None = None,
  81. ) -> tuple["torch.Tensor", float]:
  82. """
  83. Computes the inverse frequencies according to the original RoPE implementation
  84. Args:
  85. config ([`~transformers.PreTrainedConfig`]):
  86. The model configuration.
  87. device (`torch.device`):
  88. The device to use for initialization of the inverse frequencies.
  89. seq_len (`int`, *optional*):
  90. The current sequence length. Unused for this type of RoPE.
  91. Returns:
  92. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  93. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  94. """
  95. base = config.rope_parameters["rope_theta"]
  96. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  97. attention_factor = 1.0 # Unused in this type of RoPE
  98. # Compute the inverse frequencies
  99. inv_freq = 1.0 / (
  100. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  101. )
  102. return inv_freq, attention_factor
  103. @torch.no_grad()
  104. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  105. def forward(self, x, position_ids):
  106. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  107. position_ids_expanded = position_ids[:, None, :].float()
  108. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  109. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  110. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  111. emb = torch.cat((freqs, freqs), dim=-1)
  112. cos = emb.cos() * self.attention_scaling
  113. sin = emb.sin() * self.attention_scaling
  114. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  115. class Mistral4MLP(nn.Module):
  116. def __init__(self, config, intermediate_size=None):
  117. super().__init__()
  118. self.config = config
  119. self.hidden_size = config.hidden_size
  120. self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
  121. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  122. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  123. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  124. self.act_fn = ACT2FN[config.hidden_act]
  125. def forward(self, x):
  126. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  127. return down_proj
  128. class Mistral4TopkRouter(nn.Module):
  129. def __init__(self, config):
  130. super().__init__()
  131. self.config = config
  132. self.n_routed_experts = config.n_routed_experts
  133. self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
  134. def forward(self, hidden_states):
  135. hidden_states = hidden_states.view(-1, self.config.hidden_size)
  136. router_logits = F.linear(hidden_states, self.weight)
  137. return router_logits
@use_experts_implementation
class Mistral4NaiveMoe(nn.Module):
    """Collection of expert weights stored as 3D tensors.

    `forward` loops over the experts that received at least one token, runs the gated
    projection for just those tokens and accumulates the weighted result per token.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        # Gate and up projections are fused along dim 1 (hence `2 * intermediate_dim`); weights are
        # allocated uninitialised here.
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the selected experts to each token and sum their weighted outputs.

        Args:
            hidden_states: Token states, indexed along dim 0 by token.
                (assumes shape `(num_tokens, hidden_dim)` — TODO confirm against caller)
            top_k_index: Expert ids chosen per token.
                (assumes shape `(num_tokens, top_k)`; the 3-axis permute below relies on it being 2-D)
            top_k_weights: Routing weights, indexed as `[token, top_k_slot]`.

        Returns:
            A tensor shaped like `hidden_states` holding the accumulated expert outputs.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        # The routing mask is pure bookkeeping, so it is built without tracking gradients.
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            # Reverse the three axes so the expert axis comes first: (expert, top_k_slot, token).
            expert_mask = expert_mask.permute(2, 1, 0)
            # Indices of experts that were selected at least once.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # NOTE(review): with `num_classes=self.num_experts` the mask has no row at index
            # `num_experts`, so this guard does not appear reachable here — confirm it is only defensive.
            if expert_idx == self.num_experts:
                continue
            # Which top-k slot / which token routed to this expert.
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            # One fused matmul, then split the result into its gate and up halves.
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            # Scale each row by the routing weight of the (token, slot) pair that selected this expert.
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            # Accumulate in place: a token may receive contributions from several experts.
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states
  172. class Mistral4MoE(nn.Module):
  173. """
  174. A mixed expert module containing shared experts.
  175. """
  176. def __init__(self, config):
  177. super().__init__()
  178. self.config = config
  179. self.experts = Mistral4NaiveMoe(config)
  180. self.gate = Mistral4TopkRouter(config)
  181. self.shared_experts = Mistral4MLP(
  182. config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
  183. )
  184. self.n_routed_experts = config.n_routed_experts
  185. self.n_group = config.n_group
  186. self.topk_group = config.topk_group
  187. self.norm_topk_prob = config.norm_topk_prob
  188. self.routed_scaling_factor = config.routed_scaling_factor
  189. self.top_k = config.num_experts_per_tok
  190. def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  191. router_logits = router_logits.softmax(-1)
  192. group_scores = (
  193. router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
  194. )
  195. group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
  196. group_mask = torch.zeros_like(group_scores)
  197. group_mask.scatter_(1, group_idx, 1)
  198. score_mask = (
  199. group_mask.unsqueeze(-1)
  200. .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
  201. .reshape(-1, self.n_routed_experts)
  202. )
  203. scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
  204. topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
  205. topk_weights = router_logits.gather(1, topk_indices)
  206. if self.norm_topk_prob:
  207. denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
  208. topk_weights /= denominator
  209. topk_weights = topk_weights * self.routed_scaling_factor
  210. return topk_indices, topk_weights
  211. def forward(self, hidden_states):
  212. residuals = hidden_states
  213. orig_shape = hidden_states.shape
  214. router_logits = self.gate(hidden_states)
  215. topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
  216. hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
  217. hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
  218. hidden_states = hidden_states + self.shared_experts(residuals)
  219. return hidden_states
  220. def rotate_half(x):
  221. """Rotates half the hidden dims of the input."""
  222. x1 = x[..., : x.shape[-1] // 2]
  223. x2 = x[..., x.shape[-1] // 2 :]
  224. return torch.cat((-x2, x1), dim=-1)
  225. @use_kernel_func_from_hub("rotary_pos_emb")
  226. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  227. """Applies Rotary Position Embedding to the query and key tensors.
  228. Args:
  229. q (`torch.Tensor`): The query tensor.
  230. k (`torch.Tensor`): The key tensor.
  231. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  232. sin (`torch.Tensor`): The sine part of the rotary embedding.
  233. unsqueeze_dim (`int`, *optional*, defaults to 1):
  234. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  235. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  236. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  237. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  238. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  239. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  240. Returns:
  241. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  242. """
  243. cos = cos.unsqueeze(unsqueeze_dim)
  244. sin = sin.unsqueeze(unsqueeze_dim)
  245. q_embed = (q * cos) + (rotate_half(q) * sin)
  246. k_embed = (k * cos) + (rotate_half(k) * sin)
  247. return q_embed, k_embed
  248. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  249. """
  250. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  251. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  252. """
  253. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  254. if n_rep == 1:
  255. return hidden_states
  256. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  257. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  258. def eager_attention_forward(
  259. module: nn.Module,
  260. query: torch.Tensor,
  261. key: torch.Tensor,
  262. value: torch.Tensor,
  263. attention_mask: torch.Tensor | None,
  264. scaling: float,
  265. dropout: float = 0.0,
  266. **kwargs: Unpack[TransformersKwargs],
  267. ):
  268. key_states = repeat_kv(key, module.num_key_value_groups)
  269. value_states = repeat_kv(value, module.num_key_value_groups)
  270. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  271. if attention_mask is not None:
  272. attn_weights = attn_weights + attention_mask
  273. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  274. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  275. attn_output = torch.matmul(attn_weights, value_states)
  276. attn_output = attn_output.transpose(1, 2).contiguous()
  277. return attn_output, attn_weights
  278. def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  279. r"""
  280. TODO let's just use the original freqcis computation to not have the view
  281. transpose + reshape! This is not optimized!
  282. Applies Rotary Position Embedding to the query and key tensors.
  283. Args:
  284. q (`torch.Tensor`): The query tensor.
  285. k (`torch.Tensor`): The key tensor.
  286. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  287. sin (`torch.Tensor`): The sine part of the rotary embedding.
  288. position_ids (`torch.Tensor`):
  289. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  290. used to pass offsetted position ids when working with a KV-cache.
  291. unsqueeze_dim (`int`, *optional*, defaults to 1):
  292. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  293. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  294. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  295. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  296. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  297. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  298. Returns:
  299. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  300. """
  301. cos = cos.unsqueeze(unsqueeze_dim)
  302. sin = sin.unsqueeze(unsqueeze_dim)
  303. b, h, s, d = q.shape
  304. q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
  305. b, h, s, d = k.shape
  306. k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
  307. q_embed = (q * cos) + (rotate_half(q) * sin)
  308. k_embed = (k * cos) + (rotate_half(k) * sin)
  309. return q_embed, k_embed
  310. def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
  311. scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
  312. return scaling[:, None, :, None]
class Mistral4Attention(nn.Module):
    """Attention layer with low-rank query/key-value projections and a partially rotary head layout.

    Each query/key head is the concatenation of a non-rotary slice (`qk_nope_head_dim`) and a rotary slice
    (`qk_rope_head_dim`). Keys and values are produced from a compressed projection of width `kv_lora_rank`
    (`kv_a_proj_with_mqa` -> `kv_a_layernorm` -> `kv_b_proj`); the rotary key slice is projected once as a
    single head and expanded across all heads. Queries use either a plain `q_proj` or, when
    `config.q_lora_rank` is not None, the low-rank path `q_a_proj` -> `q_a_layernorm` -> `q_b_proj`.
    """

    def __init__(self, config: Mistral4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Read by `eager_attention_forward` to decide how often key/value heads are repeated.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # `forward` splits a head of this width into [qk_nope_head_dim, qk_rope_head_dim].
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True
        if self.q_lora_rank is None:
            # Direct query projection.
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            # Low-rank query projection: down-project, normalise, up-project.
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = Mistral4RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # One projection yields both the compressed KV latent and the single-head rotary key slice.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = Mistral4RMSNorm(self.kv_lora_rank)
        # Expands the latent into per-head non-rotary keys and values.
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.scaling = self.qk_head_dim ** (-0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        position_ids: torch.Tensor,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Input whose last dimension is the hidden size; the leading two dimensions are
                unpacked as `(batch_size, seq_length)`.
            position_embeddings: `(cos, sin)` tables passed to the rotary helpers.
            attention_mask: Mask forwarded unchanged to the attention backend.
            position_ids: Positions used to compute the per-position query scale.
            past_key_values: Optional cache; when given, it is updated with this layer's keys and values.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`: the projected attention output and whatever weights object the
            selected backend returned.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        # `kv_b_proj` emits the non-rotary key and the value side by side, hence this head width.
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key slice has a single head at this point.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        # Broadcast the single rotary key head to the same leading shape as the non-rotary keys.
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # Position-dependent query scaling.
        # NOTE(review): both `.get(...)` calls return None when the key is absent, which would fail inside
        # `get_llama_4_attn_scale` — presumably the config guarantees these keys; confirm.
        query_states = query_states * get_llama_4_attn_scale(
            position_ids,
            self.config.rope_parameters.get("llama_4_scaling_beta"),
            self.config.rope_parameters.get("original_max_position_embeddings"),
        ).to(query_states.dtype)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # When flash attention is requested and the value head is narrower than the query/key head,
        # zero-pad the values up to `qk_head_dim`; the padding is sliced off again after the call.
        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        # Use the backend named by the config; `eager_attention_forward` is the default passed to the lookup.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
  410. class Mistral4DecoderLayer(GradientCheckpointingLayer):
  411. def __init__(self, config: Mistral4Config, layer_idx: int):
  412. super().__init__()
  413. self.hidden_size = config.hidden_size
  414. self.self_attn = Mistral4Attention(config=config, layer_idx=layer_idx)
  415. if layer_idx >= config.first_k_dense_replace:
  416. self.mlp = Mistral4MoE(config)
  417. else:
  418. self.mlp = Mistral4MLP(config)
  419. self.input_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  420. self.post_attention_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  421. def forward(
  422. self,
  423. hidden_states: torch.Tensor,
  424. attention_mask: torch.Tensor | None = None,
  425. position_ids: torch.LongTensor | None = None,
  426. past_key_values: Cache | None = None,
  427. use_cache: bool | None = False,
  428. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  429. **kwargs: Unpack[TransformersKwargs],
  430. ) -> torch.Tensor:
  431. residual = hidden_states
  432. hidden_states = self.input_layernorm(hidden_states)
  433. # Self Attention
  434. hidden_states, _ = self.self_attn(
  435. hidden_states=hidden_states,
  436. attention_mask=attention_mask,
  437. position_ids=position_ids,
  438. past_key_values=past_key_values,
  439. use_cache=use_cache,
  440. position_embeddings=position_embeddings,
  441. **kwargs,
  442. )
  443. hidden_states = residual + hidden_states
  444. # Fully Connected
  445. residual = hidden_states
  446. hidden_states = self.post_attention_layernorm(hidden_states)
  447. hidden_states = self.mlp(hidden_states)
  448. hidden_states = residual + hidden_states
  449. return hidden_states
  450. class Mistral4PreTrainedModel(PreTrainedModel):
  451. config: Mistral4Config
  452. base_model_prefix = "model"
  453. supports_gradient_checkpointing = True
  454. _no_split_modules = ["Mistral4DecoderLayer"]
  455. _skip_keys_device_placement = ["past_key_values"]
  456. _supports_flash_attn = True
  457. _supports_sdpa = True
  458. _supports_flex_attn = True
  459. _can_compile_fullgraph = True
  460. _supports_attention_backend = True
  461. _can_record_outputs = {
  462. "hidden_states": Mistral4DecoderLayer,
  463. "attentions": Mistral4Attention,
  464. }
  465. _keep_in_fp32_modules_strict = []
  466. _keys_to_ignore_on_load_unexpected = []
  467. @torch.no_grad()
  468. def _init_weights(self, module):
  469. super()._init_weights(module)
  470. if isinstance(module, Mistral4TopkRouter):
  471. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  472. elif isinstance(module, Mistral4NaiveMoe):
  473. init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  474. init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
  475. @auto_docstring
  476. class Mistral4Model(Mistral4PreTrainedModel):
  477. def __init__(self, config: Mistral4Config):
  478. super().__init__(config)
  479. self.padding_idx = config.pad_token_id
  480. self.vocab_size = config.vocab_size
  481. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  482. self.layers = nn.ModuleList(
  483. [Mistral4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  484. )
  485. self.norm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  486. self.rotary_emb = Mistral4RotaryEmbedding(config=config)
  487. self.gradient_checkpointing = False
  488. # Initialize weights and apply final processing
  489. self.post_init()
  490. @merge_with_config_defaults
  491. @capture_outputs
  492. @auto_docstring
  493. def forward(
  494. self,
  495. input_ids: torch.LongTensor | None = None,
  496. attention_mask: torch.Tensor | None = None,
  497. position_ids: torch.LongTensor | None = None,
  498. past_key_values: Cache | None = None,
  499. inputs_embeds: torch.FloatTensor | None = None,
  500. use_cache: bool | None = None,
  501. **kwargs: Unpack[TransformersKwargs],
  502. ) -> BaseModelOutputWithPast:
  503. if (input_ids is None) ^ (inputs_embeds is not None):
  504. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  505. if inputs_embeds is None:
  506. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  507. if use_cache and past_key_values is None:
  508. past_key_values = DynamicCache(config=self.config)
  509. if position_ids is None:
  510. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  511. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  512. position_ids = position_ids.unsqueeze(0)
  513. causal_mask = create_causal_mask(
  514. config=self.config,
  515. inputs_embeds=inputs_embeds,
  516. attention_mask=attention_mask,
  517. past_key_values=past_key_values,
  518. position_ids=position_ids,
  519. )
  520. hidden_states = inputs_embeds
  521. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  522. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  523. hidden_states = decoder_layer(
  524. hidden_states,
  525. attention_mask=causal_mask,
  526. position_embeddings=position_embeddings,
  527. position_ids=position_ids,
  528. past_key_values=past_key_values,
  529. use_cache=use_cache,
  530. **kwargs,
  531. )
  532. hidden_states = self.norm(hidden_states)
  533. return BaseModelOutputWithPast(
  534. last_hidden_state=hidden_states,
  535. past_key_values=past_key_values,
  536. )
  537. @auto_docstring
  538. class Mistral4ForCausalLM(Mistral4PreTrainedModel, GenerationMixin):
  539. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  540. _tp_plan = {"lm_head": "colwise_gather_output"}
  541. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  542. def __init__(self, config):
  543. super().__init__(config)
  544. self.model = Mistral4Model(config)
  545. self.vocab_size = config.vocab_size
  546. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  547. # Initialize weights and apply final processing
  548. self.post_init()
  549. @can_return_tuple
  550. @auto_docstring
  551. def forward(
  552. self,
  553. input_ids: torch.LongTensor | None = None,
  554. attention_mask: torch.Tensor | None = None,
  555. position_ids: torch.LongTensor | None = None,
  556. past_key_values: Cache | None = None,
  557. inputs_embeds: torch.FloatTensor | None = None,
  558. labels: torch.LongTensor | None = None,
  559. use_cache: bool | None = None,
  560. logits_to_keep: int | torch.Tensor = 0,
  561. **kwargs: Unpack[TransformersKwargs],
  562. ) -> CausalLMOutputWithPast:
  563. r"""
  564. Example:
  565. ```python
  566. >>> from transformers import AutoTokenizer, Mistral4ForCausalLM
  567. >>> model = Mistral4ForCausalLM.from_pretrained("meta-mistral4/Mistral4-2-7b-hf")
  568. >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral4/Mistral4-2-7b-hf")
  569. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  570. >>> inputs = tokenizer(prompt, return_tensors="pt")
  571. >>> # Generate
  572. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  573. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  574. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  575. ```"""
  576. outputs: BaseModelOutputWithPast = self.model(
  577. input_ids=input_ids,
  578. attention_mask=attention_mask,
  579. position_ids=position_ids,
  580. past_key_values=past_key_values,
  581. inputs_embeds=inputs_embeds,
  582. use_cache=use_cache,
  583. **kwargs,
  584. )
  585. hidden_states = outputs.last_hidden_state
  586. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  587. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  588. logits = self.lm_head(hidden_states[:, slice_indices, :])
  589. loss = None
  590. if labels is not None:
  591. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  592. return CausalLMOutputWithPast(
  593. loss=loss,
  594. logits=logits,
  595. past_key_values=outputs.past_key_values,
  596. hidden_states=outputs.hidden_states,
  597. attentions=outputs.attentions,
  598. )
# Sequence-classification variant. The body is empty: everything is inherited from the two bases,
# with GenericForSequenceClassification listed first and Mistral4PreTrainedModel second.
class Mistral4ForSequenceClassification(GenericForSequenceClassification, Mistral4PreTrainedModel):
    pass
# Token-classification variant. The body is empty: everything is inherited from the two bases,
# with GenericForTokenClassification listed first and Mistral4PreTrainedModel second.
class Mistral4ForTokenClassification(GenericForTokenClassification, Mistral4PreTrainedModel):
    pass
# Names exported by a star-import of this module; the helper layers and functions above are not listed.
__all__ = [
    "Mistral4PreTrainedModel",
    "Mistral4Model",
    "Mistral4ForCausalLM",
    "Mistral4ForSequenceClassification",
    "Mistral4ForTokenClassification",
]