# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable

import torch
import torch.nn.functional as F
from torch import nn

from ... import initialization as init
from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import logging
from ...utils.generic import is_flash_attention_requested
from ..deepseek_v3.modeling_deepseek_v3 import (
    DeepseekV3Attention,
    DeepseekV3DecoderLayer,
    DeepseekV3MoE,
    DeepseekV3NaiveMoe,
    apply_rotary_pos_emb_interleave,
)
from ..llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..ministral3.modeling_ministral3 import get_llama_4_attn_scale
from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP
from .configuration_mistral4 import Mistral4Config


logger = logging.get_logger(__name__)


# The three classes below add nothing of their own: they re-export the parent
# implementation under a Mistral4 name.
class Mistral4RMSNorm(LlamaRMSNorm):
    pass


class Mistral4RotaryEmbedding(LlamaRotaryEmbedding):
    pass


class Mistral4MLP(Qwen2MoeMLP):
    pass


class Mistral4TopkRouter(nn.Module):
    """Bias-free linear router that emits one raw logit per routed expert.

    The softmax and the expert selection are not done here; see
    `Mistral4MoE.route_tokens_to_experts`.
    """

    def __init__(self, config: Mistral4Config):
        super().__init__()
        self.config = config
        self.n_routed_experts = config.n_routed_experts
        # Allocated uninitialised (`torch.empty`); `Mistral4PreTrainedModel._init_weights`
        # draws the values from a normal distribution.
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return router logits of shape `(num_tokens, n_routed_experts)`.

        Args:
            hidden_states: Tensor whose last dimension is `config.hidden_size`; every
                leading dimension is flattened into a single token axis.
        """
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        router_logits = F.linear(hidden_states, self.weight)
        return router_logits


class Mistral4NaiveMoe(DeepseekV3NaiveMoe):
    pass


class Mistral4MoE(DeepseekV3MoE):
    """DeepseekV3 MoE block with the routing rule replaced by a softmax, group-limited top-k."""

    def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Select `top_k` experts per token, restricted to the best-scoring expert groups.

        Reads `n_group`, `n_routed_experts`, `topk_group`, `top_k`, `norm_topk_prob` and
        `routed_scaling_factor` from `self` (presumably set by `DeepseekV3MoE.__init__`,
        which is not defined in this file).

        Args:
            router_logits: Raw router scores, one row per token and one column per
                routed expert.

        Returns:
            `(topk_indices, topk_weights)`: the chosen expert ids and their routing
            weights, both with `top_k` entries per token.
        """
        router_probs = router_logits.softmax(-1)
        experts_per_group = self.n_routed_experts // self.n_group

        # A group's score is the sum of its two highest expert probabilities.
        group_scores = router_probs.view(-1, self.n_group, experts_per_group).topk(2, dim=-1)[0].sum(dim=-1)
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]

        # Build a 0/1 mask over groups, then broadcast it to every expert of each group.
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, experts_per_group)
            .reshape(-1, self.n_routed_experts)
        )

        # Experts outside the selected groups get a score of 0.0 before the final top-k.
        scores_for_choice = router_probs.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]

        # Weights come from the unmasked probabilities of the chosen experts.
        topk_weights = router_probs.gather(1, topk_indices)
        if self.norm_topk_prob:
            # The small epsilon keeps the division finite if the selected weights sum to zero.
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights


class Mistral4Attention(DeepseekV3Attention):
    """Attention with low-rank query/key-value projections and a split rotary / non-rotary head.

    Each query/key head has `qk_head_dim` features which are split into a
    `qk_nope_head_dim` part that skips rotary embedding and a `qk_rope_head_dim` part
    that receives it. Values use `v_head_dim` features per head.
    """

    def __init__(self, config: Mistral4Config, layer_idx: int):
        # `nn.Module.__init__` is called directly, so `DeepseekV3Attention.__init__`
        # does not run; every attribute this class needs is defined below.
        nn.Module.__init__(self)
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True

        # Queries: one dense projection, or a down-projection -> RMSNorm -> up-projection
        # chain when a query LoRA rank is configured.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = Mistral4RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # One projection yields both the compressed KV latent (`kv_lora_rank` features)
        # and the rotary key part (`qk_rope_head_dim` features).
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = Mistral4RMSNorm(self.kv_lora_rank)
        # Expands the latent into per-head non-rotary keys and values.
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.scaling = self.qk_head_dim ** (-0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        position_ids: torch.Tensor,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Input of shape `(batch_size, seq_length, hidden_size)`.
            position_embeddings: `(cos, sin)` pair handed to the rotary helper.
            attention_mask: Mask forwarded unchanged to the attention backend.
            position_ids: Positions passed to `get_llama_4_attn_scale` to compute the
                query scale.
                NOTE(review): the shape expected by that helper is not visible in this
                file - confirm it broadcasts against `(batch, heads, seq, head_dim)`.
            past_key_values: Optional cache; when given, the keys/values of this call are
                merged into it under `self.layer_idx` and the merged tensors are used.
            **kwargs: Extra keyword arguments forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`: the output projected back to `hidden_size`,
            and whatever the selected backend returns as attention weights.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        # -1 lets `view` infer the number of heads.
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Normalise the latent, expand it per head, then separate keys from values.
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key part has a single head here; it is expanded to all heads below.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # Position-dependent query scale, cast to the query dtype before multiplying.
        query_states = query_states * get_llama_4_attn_scale(
            position_ids,
            self.config.rope_parameters.get("llama_4_scaling_beta"),
            self.config.rope_parameters.get("original_max_position_embeddings"),
        ).to(query_states.dtype)

        # The cache receives the assembled keys and the values *before* any padding.
        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Zero-pad the value head dim up to the query/key head dim for flash attention.
        # NOTE(review): presumably the flash kernels need equal head dims - confirm.
        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        # Use the configured backend, with the eager implementation as the default.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Drop the columns that correspond to the value padding added above.
        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Mistral4DecoderLayer(DeepseekV3DecoderLayer):
    """Decoder layer pairing `Mistral4Attention` with either a dense MLP or an MoE block."""

    def __init__(self, config: Mistral4Config, layer_idx: int):
        # As in `Mistral4Attention`, the parent constructor is skipped on purpose and the
        # sub-modules are created here with the Mistral4 classes.
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = Mistral4Attention(config=config, layer_idx=layer_idx)

        # Layers below `first_k_dense_replace` stay dense; the remaining ones use MoE.
        if layer_idx >= config.first_k_dense_replace:
            self.mlp = Mistral4MoE(config)
        else:
            self.mlp = Mistral4MLP(config)

        self.input_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class Mistral4PreTrainedModel(PreTrainedModel):
    """Base class carrying the Mistral4 configuration type, capability flags and weight init."""

    config: Mistral4Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Mistral4DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps an output name to the module class whose outputs are recorded for it.
    _can_record_outputs = {
        "hidden_states": Mistral4DecoderLayer,
        "attentions": Mistral4Attention,
    }
    _keep_in_fp32_modules_strict = []
    _keys_to_ignore_on_load_unexpected = []

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialise `module`, adding the parameters the base initialiser is not relied on for.

        After the parent initialisation, the router weight and the expert
        `gate_up_proj` / `down_proj` tensors are drawn from a normal distribution with
        mean 0 and standard deviation `config.initializer_range`.
        """
        super()._init_weights(module)
        if isinstance(module, Mistral4TopkRouter):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, Mistral4NaiveMoe):
            init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
            init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)


class Mistral4Model(LlamaModel):
    pass


class Mistral4ForCausalLM(LlamaForCausalLM):
    pass


class Mistral4ForSequenceClassification(GenericForSequenceClassification, Mistral4PreTrainedModel):
    pass


class Mistral4ForTokenClassification(GenericForTokenClassification, Mistral4PreTrainedModel):
    pass


__all__ = [
    "Mistral4PreTrainedModel",
    "Mistral4Model",
    "Mistral4ForCausalLM",
    "Mistral4ForSequenceClassification",
    "Mistral4ForTokenClassification",
]