modular_ministral.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2025 Mistral AI and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from huggingface_hub.dataclasses import strict
  16. from torch import nn
  17. from ...cache_utils import Cache, DynamicCache
  18. from ...configuration_utils import PreTrainedConfig
  19. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  20. from ...modeling_outputs import BaseModelOutputWithPast
  21. from ...processing_utils import Unpack
  22. from ...utils import TransformersKwargs, auto_docstring
  23. from ...utils.generic import merge_with_config_defaults
  24. from ...utils.output_capturing import capture_outputs
  25. from ..mistral.configuration_mistral import MistralConfig
  26. from ..qwen2.modeling_qwen2 import (
  27. Qwen2Attention,
  28. Qwen2DecoderLayer,
  29. Qwen2ForCausalLM,
  30. Qwen2ForQuestionAnswering,
  31. Qwen2ForSequenceClassification,
  32. Qwen2ForTokenClassification,
  33. Qwen2MLP,
  34. Qwen2Model,
  35. Qwen2PreTrainedModel,
  36. Qwen2RMSNorm,
  37. Qwen2RotaryEmbedding,
  38. )
  @auto_docstring(checkpoint="mistralai/Ministral-8B-Instruct-2410")
  @strict
  class MinistralConfig(MistralConfig):
      r"""
      Example:

      ```python
      >>> from transformers import MinistralModel, MinistralConfig

      >>> # Initializing a Ministral 8B style configuration
      >>> configuration = MinistralConfig()

      >>> # Initializing a model from the Ministral 8B style configuration
      >>> model = MinistralModel(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      model_type = "ministral"
      # One entry per decoder layer ("sliding_attention" or "full_attention").
      # Left as None, it is filled in by `__post_init__` from `sliding_window`.
      layer_types: list[str] | None = None

      def __post_init__(self, **kwargs):
          # No explicit KV-head count: use as many KV heads as query heads.
          if self.num_key_value_heads is None:
              self.num_key_value_heads = self.num_attention_heads

          if self.layer_types is None:
              # Same attention kind for every layer: sliding when a window is
              # configured, full causal attention otherwise.
              uniform_kind = "full_attention" if self.sliding_window is None else "sliding_attention"
              self.layer_types = [uniform_kind for _ in range(self.num_hidden_layers)]

          # The base-config hook is invoked directly instead of via super(), so any
          # MistralConfig.__post_init__ is bypassed. Presumably intentional; confirm
          # before switching to super().
          PreTrainedConfig.__post_init__(self, **kwargs)
  class MinistralMLP(Qwen2MLP):
      # Identical to Qwen2MLP; subclassed only so the model exposes a Ministral-named MLP.
      pass
  class MinistralAttention(Qwen2Attention):
      """Qwen2 attention whose query/key/value projections carry no bias, matching Mistral."""

      def __init__(self, config, layer_idx: int):
          super().__init__(config, layer_idx)
          # Swap out the parent's q/k/v projections (presumably created with a bias by
          # Qwen2Attention) for bias-free ones. They are rebuilt in q, k, v order,
          # exactly as before, so initialization order is unchanged.
          in_features = config.hidden_size
          query_width = config.num_attention_heads * self.head_dim
          kv_width = config.num_key_value_heads * self.head_dim
          self.q_proj = nn.Linear(in_features, query_width, bias=False)
          self.k_proj = nn.Linear(in_features, kv_width, bias=False)
          self.v_proj = nn.Linear(in_features, kv_width, bias=False)
  class MinistralRMSNorm(Qwen2RMSNorm):
      # Identical to Qwen2RMSNorm; subclassed only to provide the Ministral name.
      pass
  class MinistralDecoderLayer(Qwen2DecoderLayer):
      # Identical to Qwen2DecoderLayer; subclassed only to provide the Ministral name.
      pass
  class MinistralPreTrainedModel(Qwen2PreTrainedModel):
      # Identical to Qwen2PreTrainedModel; subclassed only to provide the Ministral name.
      pass
  class MinistralRotaryEmbedding(Qwen2RotaryEmbedding):
      # Identical to Qwen2RotaryEmbedding; subclassed only to provide the Ministral name.
      pass
  class MinistralModel(Qwen2Model):
      def __init__(self, config: MinistralConfig):
          super().__init__(config)
          # The parent's `has_sliding_layers` flag is removed; `forward` instead reads
          # `config.layer_types` to decide which attention masks to build.
          del self.has_sliding_layers

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if position_ids is None:
              # Continue position numbering after the tokens already held in the cache.
              past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
              position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
              position_ids = position_ids.unsqueeze(0)

          # Attention kind of each layer that the loop below actually runs.
          layer_types = self.config.layer_types[: self.config.num_hidden_layers]

          # It may already have been prepared by e.g. `generate`
          if not isinstance(causal_mask_mapping := attention_mask, dict):
              # Prepare mask arguments
              mask_kwargs = {
                  "config": self.config,
                  "inputs_embeds": inputs_embeds,
                  "attention_mask": attention_mask,
                  "past_key_values": past_key_values,
                  "position_ids": position_ids,
              }
              # Build only the mask kinds some layer consumes. When `sliding_window` is
              # None, the config marks every layer "full_attention", so a sliding mask
              # would be wasted work. The sliding-window mask helper also presumably
              # requires `config.sliding_window` to be set (verify in masking_utils).
              causal_mask_mapping = {}
              if "full_attention" in layer_types:
                  causal_mask_mapping["full_attention"] = create_causal_mask(**mask_kwargs)
              if "sliding_attention" in layer_types:
                  causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

          hidden_states = inputs_embeds
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          # Index `layer_types` rather than zipping: a too-short list must still raise
          # instead of silently skipping layers.
          for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
              hidden_states = decoder_layer(
                  hidden_states,
                  attention_mask=causal_mask_mapping[layer_types[i]],
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  position_embeddings=position_embeddings,
                  **kwargs,
              )

          hidden_states = self.norm(hidden_states)
          return BaseModelOutputWithPast(
              last_hidden_state=hidden_states,
              past_key_values=past_key_values if use_cache else None,
          )
  class MinistralForCausalLM(Qwen2ForCausalLM):
      # Identical to Qwen2ForCausalLM; subclassed only to provide the Ministral name.
      pass
  class MinistralForSequenceClassification(Qwen2ForSequenceClassification):
      # Identical to Qwen2ForSequenceClassification; subclassed only to provide the Ministral name.
      pass
  class MinistralForTokenClassification(Qwen2ForTokenClassification):
      # Identical to Qwen2ForTokenClassification; subclassed only to provide the Ministral name.
      pass
  class MinistralForQuestionAnswering(Qwen2ForQuestionAnswering):
      # Identical to Qwen2ForQuestionAnswering; subclassed only to provide the Ministral name.
      pass
  # Public API of this module: the names exported by `from <module> import *`.
  # Internal building blocks (attention, MLP, norm, decoder layer, rotary
  # embedding) are intentionally not listed.
  __all__ = [
      "MinistralConfig",
      "MinistralPreTrainedModel",
      "MinistralModel",
      "MinistralForCausalLM",
      "MinistralForSequenceClassification",
      "MinistralForTokenClassification",
      "MinistralForQuestionAnswering",
  ]