modular_mistral4.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. import torch
  16. import torch.nn.functional as F
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...cache_utils import Cache
  20. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  21. from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification
  22. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...utils import logging
  25. from ...utils.generic import is_flash_attention_requested
  26. from ..deepseek_v3.modeling_deepseek_v3 import (
  27. DeepseekV3Attention,
  28. DeepseekV3DecoderLayer,
  29. DeepseekV3MoE,
  30. DeepseekV3NaiveMoe,
  31. apply_rotary_pos_emb_interleave,
  32. )
  33. from ..llama.modeling_llama import (
  34. LlamaForCausalLM,
  35. LlamaModel,
  36. LlamaRMSNorm,
  37. LlamaRotaryEmbedding,
  38. apply_rotary_pos_emb,
  39. eager_attention_forward,
  40. )
  41. from ..ministral3.modeling_ministral3 import get_llama_4_attn_scale
  42. from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP
  43. from .configuration_mistral4 import Mistral4Config
# Module-level logger obtained from the library's logging utilities, keyed by this module's name.
logger = logging.get_logger(__name__)
  45. class Mistral4RMSNorm(LlamaRMSNorm):
  46. pass
  47. class Mistral4RotaryEmbedding(LlamaRotaryEmbedding):
  48. pass
  49. class Mistral4MLP(Qwen2MoeMLP):
  50. pass
  51. class Mistral4TopkRouter(nn.Module):
  52. def __init__(self, config):
  53. super().__init__()
  54. self.config = config
  55. self.n_routed_experts = config.n_routed_experts
  56. self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
  57. def forward(self, hidden_states):
  58. hidden_states = hidden_states.view(-1, self.config.hidden_size)
  59. router_logits = F.linear(hidden_states, self.weight)
  60. return router_logits
  61. class Mistral4NaiveMoe(DeepseekV3NaiveMoe):
  62. pass
  63. class Mistral4MoE(DeepseekV3MoE):
  64. def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  65. router_logits = router_logits.softmax(-1)
  66. group_scores = (
  67. router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
  68. )
  69. group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
  70. group_mask = torch.zeros_like(group_scores)
  71. group_mask.scatter_(1, group_idx, 1)
  72. score_mask = (
  73. group_mask.unsqueeze(-1)
  74. .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
  75. .reshape(-1, self.n_routed_experts)
  76. )
  77. scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
  78. topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
  79. topk_weights = router_logits.gather(1, topk_indices)
  80. if self.norm_topk_prob:
  81. denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
  82. topk_weights /= denominator
  83. topk_weights = topk_weights * self.routed_scaling_factor
  84. return topk_indices, topk_weights
  85. class Mistral4Attention(DeepseekV3Attention):
  86. def __init__(self, config: Mistral4Config, layer_idx: int):
  87. nn.Module.__init__(self)
  88. self.config = config
  89. self.layer_idx = layer_idx
  90. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  91. self.attention_dropout = config.attention_dropout
  92. self.num_heads = config.num_attention_heads
  93. self.q_lora_rank = config.q_lora_rank
  94. self.qk_rope_head_dim = config.qk_rope_head_dim
  95. self.kv_lora_rank = config.kv_lora_rank
  96. self.v_head_dim = config.v_head_dim
  97. self.qk_nope_head_dim = config.qk_nope_head_dim
  98. self.qk_head_dim = config.qk_head_dim
  99. self.is_causal = True
  100. if self.q_lora_rank is None:
  101. self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
  102. else:
  103. self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
  104. self.q_a_layernorm = Mistral4RMSNorm(config.q_lora_rank)
  105. self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
  106. self.kv_a_proj_with_mqa = nn.Linear(
  107. config.hidden_size,
  108. self.kv_lora_rank + self.qk_rope_head_dim,
  109. bias=config.attention_bias,
  110. )
  111. self.kv_a_layernorm = Mistral4RMSNorm(self.kv_lora_rank)
  112. self.kv_b_proj = nn.Linear(
  113. self.kv_lora_rank,
  114. self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
  115. bias=False,
  116. )
  117. self.o_proj = nn.Linear(
  118. self.num_heads * self.v_head_dim,
  119. config.hidden_size,
  120. bias=config.attention_bias,
  121. )
  122. self.scaling = self.qk_head_dim ** (-0.5)
  123. def forward(
  124. self,
  125. hidden_states: torch.Tensor,
  126. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  127. attention_mask: torch.Tensor | None,
  128. position_ids: torch.Tensor,
  129. past_key_values: Cache | None = None,
  130. **kwargs: Unpack[FlashAttentionKwargs],
  131. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  132. batch_size, seq_length = hidden_states.shape[:-1]
  133. query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
  134. key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
  135. if self.q_lora_rank is None:
  136. q_states = self.q_proj(hidden_states)
  137. else:
  138. q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
  139. q_states = q_states.view(query_shape).transpose(1, 2)
  140. q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
  141. compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
  142. k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  143. k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
  144. k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
  145. k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
  146. cos, sin = position_embeddings
  147. if self.config.rope_interleave: # support using interleaved weights for efficiency
  148. q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
  149. else:
  150. q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
  151. k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
  152. query_states = torch.cat((q_pass, q_rot), dim=-1)
  153. key_states = torch.cat((k_pass, k_rot), dim=-1)
  154. query_states = query_states * get_llama_4_attn_scale(
  155. position_ids,
  156. self.config.rope_parameters.get("llama_4_scaling_beta"),
  157. self.config.rope_parameters.get("original_max_position_embeddings"),
  158. ).to(query_states.dtype)
  159. if past_key_values is not None:
  160. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  161. if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
  162. value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
  163. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  164. self.config._attn_implementation, eager_attention_forward
  165. )
  166. attn_output, attn_weights = attention_interface(
  167. self,
  168. query_states,
  169. key_states,
  170. value_states,
  171. attention_mask,
  172. dropout=0.0 if not self.training else self.attention_dropout,
  173. scaling=self.scaling,
  174. **kwargs,
  175. )
  176. if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
  177. attn_output = attn_output[:, :, :, : self.v_head_dim]
  178. attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
  179. attn_output = self.o_proj(attn_output)
  180. return attn_output, attn_weights
  181. class Mistral4DecoderLayer(DeepseekV3DecoderLayer):
  182. def __init__(self, config: Mistral4Config, layer_idx: int):
  183. nn.Module.__init__(self)
  184. self.hidden_size = config.hidden_size
  185. self.self_attn = Mistral4Attention(config=config, layer_idx=layer_idx)
  186. if layer_idx >= config.first_k_dense_replace:
  187. self.mlp = Mistral4MoE(config)
  188. else:
  189. self.mlp = Mistral4MLP(config)
  190. self.input_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  191. self.post_attention_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  192. class Mistral4PreTrainedModel(PreTrainedModel):
  193. config: Mistral4Config
  194. base_model_prefix = "model"
  195. supports_gradient_checkpointing = True
  196. _no_split_modules = ["Mistral4DecoderLayer"]
  197. _skip_keys_device_placement = ["past_key_values"]
  198. _supports_flash_attn = True
  199. _supports_sdpa = True
  200. _supports_flex_attn = True
  201. _can_compile_fullgraph = True
  202. _supports_attention_backend = True
  203. _can_record_outputs = {
  204. "hidden_states": Mistral4DecoderLayer,
  205. "attentions": Mistral4Attention,
  206. }
  207. _keep_in_fp32_modules_strict = []
  208. _keys_to_ignore_on_load_unexpected = []
  209. @torch.no_grad()
  210. def _init_weights(self, module):
  211. super()._init_weights(module)
  212. if isinstance(module, Mistral4TopkRouter):
  213. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  214. elif isinstance(module, Mistral4NaiveMoe):
  215. init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  216. init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
  217. class Mistral4Model(LlamaModel):
  218. pass
  219. class Mistral4ForCausalLM(LlamaForCausalLM):
  220. pass
  221. class Mistral4ForSequenceClassification(GenericForSequenceClassification, Mistral4PreTrainedModel):
  222. pass
  223. class Mistral4ForTokenClassification(GenericForTokenClassification, Mistral4PreTrainedModel):
  224. pass
# Public names exported by this module.
__all__ = [
    "Mistral4PreTrainedModel",
    "Mistral4Model",
    "Mistral4ForCausalLM",
    "Mistral4ForSequenceClassification",
    "Mistral4ForTokenClassification",
]