modeling_mistral3.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/mistral3/modular_mistral3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_mistral3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from dataclasses import dataclass
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache
  26. from ...generation import GenerationMixin
  27. from ...integrations import use_kernel_forward_from_hub
  28. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
  29. from ...modeling_utils import PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check
  32. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  33. from ..auto import AutoModel
  34. from .configuration_mistral3 import Mistral3Config
  35. @use_kernel_forward_from_hub("RMSNorm")
  36. class Mistral3RMSNorm(nn.Module):
  37. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  38. """
  39. Mistral3RMSNorm is equivalent to T5LayerNorm
  40. """
  41. super().__init__()
  42. self.weight = nn.Parameter(torch.ones(hidden_size))
  43. self.variance_epsilon = eps
  44. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  45. input_dtype = hidden_states.dtype
  46. hidden_states = hidden_states.to(torch.float32)
  47. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  48. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  49. return self.weight * hidden_states.to(input_dtype)
  50. def extra_repr(self):
  51. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  52. class Mistral3PatchMerger(nn.Module):
  53. """
  54. Learned merging of spatial_merge_size ** 2 patches
  55. """
  56. def __init__(self, config: Mistral3Config):
  57. super().__init__()
  58. self.config = config
  59. hidden_size = config.vision_config.hidden_size
  60. self.spatial_merge_size = config.spatial_merge_size
  61. self.patch_size = self.config.vision_config.patch_size
  62. self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
  63. def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
  64. image_sizes = [
  65. (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
  66. ]
  67. tokens_per_image = [h * w for h, w in image_sizes]
  68. d = image_features.shape[-1]
  69. permuted_tensor = []
  70. for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
  71. # Reshape image_tokens into a 2D grid
  72. h, w = image_sizes[image_index]
  73. image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
  74. grid = torch.nn.functional.unfold(
  75. image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
  76. )
  77. grid = grid.view(d * self.spatial_merge_size**2, -1).t()
  78. permuted_tensor.append(grid)
  79. image_features = torch.cat(permuted_tensor, dim=0)
  80. image_features = self.merging_layer(image_features)
  81. return image_features
  82. class Mistral3MultiModalProjector(nn.Module):
  83. def __init__(self, config: Mistral3Config):
  84. super().__init__()
  85. self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps)
  86. self.patch_merger = Mistral3PatchMerger(config)
  87. # We have hidden_size * the number of vision feature layers
  88. self.num_feature_layers = (
  89. 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
  90. )
  91. self.linear_1 = nn.Linear(
  92. config.vision_config.hidden_size * self.num_feature_layers,
  93. config.text_config.hidden_size,
  94. bias=config.multimodal_projector_bias,
  95. )
  96. self.act = ACT2FN[config.projector_hidden_act]
  97. self.linear_2 = nn.Linear(
  98. config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
  99. )
  100. def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
  101. image_features = self.norm(image_features)
  102. image_features = self.patch_merger(image_features, image_sizes)
  103. hidden_states = self.linear_1(image_features)
  104. hidden_states = self.act(hidden_states)
  105. hidden_states = self.linear_2(hidden_states)
  106. return hidden_states
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Mistral3 causal language model (or autoregressive) outputs.
    """
)
class Mistral3CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        When `logits_to_keep` is non-zero, only the selected (by default the last) positions are kept along the
        sequence dimension.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of shape `(num_image_tokens, hidden_size)`. It holds the projected features of every
        image in the batch, concatenated along the first dimension (one row per image placeholder token).
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # NOTE(review): ModelOutput presumably derives tuple/positional access from field order — keep this order.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Mistral3 outputs, with hidden states and attentions.
    """
)
class Mistral3ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of shape `(num_image_tokens, hidden_size)`. It holds the projected features of every
        image in the batch, concatenated along the first dimension (one row per image placeholder token).
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Only set when `pixel_values` were given to `Mistral3Model.forward`; otherwise left as None.
    image_hidden_states: torch.FloatTensor | None = None
@auto_docstring
class Mistral3PreTrainedModel(PreTrainedModel):
    # Shared base for `Mistral3Model` and `Mistral3ForConditionalGeneration`. It only declares the config class
    # and the class-level flags consumed by the `PreTrainedModel` machinery.
    config: Mistral3Config
    # Attribute name of the base model inside task heads (`Mistral3ForConditionalGeneration.model`).
    base_model_prefix = "model"
    # Input modalities accepted by this model family.
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    # NOTE(review): presumably tells device-placement hooks to leave `past_key_values` where it is — confirm.
    _skip_keys_device_placement = "past_key_values"
    # Capability flags advertised to `PreTrainedModel` (attention backends and full-graph compilation).
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_flex_attn = True
    _supports_attention_backend = True
  162. @auto_docstring(
  163. custom_intro="""
  164. The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.
  165. """
  166. )
  167. class Mistral3Model(Mistral3PreTrainedModel):
  168. def __init__(self, config: Mistral3Config):
  169. super().__init__(config)
  170. self.vision_tower = AutoModel.from_config(config.vision_config)
  171. self.multi_modal_projector = Mistral3MultiModalProjector(config)
  172. self.language_model = AutoModel.from_config(config.text_config)
  173. self.post_init()
  174. def get_input_embeddings(self):
  175. return self.language_model.get_input_embeddings()
  176. def set_input_embeddings(self, value):
  177. self.language_model.set_input_embeddings(value)
  178. @merge_with_config_defaults
  179. @can_return_tuple
  180. @auto_docstring(
  181. custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
  182. )
  183. def get_image_features(
  184. self,
  185. pixel_values: torch.FloatTensor,
  186. image_sizes: torch.Tensor,
  187. vision_feature_layer: int | list[int] | list[int] | None = None,
  188. output_hidden_states: bool | None = None,
  189. **kwargs: Unpack[TransformersKwargs],
  190. ) -> tuple | BaseModelOutputWithPooling:
  191. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  192. # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
  193. image_outputs = self.vision_tower(
  194. pixel_values,
  195. image_sizes=image_sizes,
  196. output_hidden_states=True, # Ignore arg on purpose
  197. return_dict=True,
  198. **kwargs,
  199. )
  200. # If we have one vision feature layer, return the corresponding hidden states,
  201. # otherwise, select the hidden states of each feature layer and concatenate them
  202. if isinstance(vision_feature_layer, int):
  203. selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
  204. else:
  205. hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
  206. selected_image_feature = torch.cat(hs_pool, dim=-1)
  207. image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
  208. downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
  209. split_sizes = (
  210. (torch.as_tensor(image_sizes, device=image_features.device) // downsample_ratio).prod(dim=-1).tolist()
  211. )
  212. image_features = torch.split(image_features.squeeze(0), split_sizes)
  213. image_outputs.pooler_output = image_features
  214. return image_outputs
  215. def get_placeholder_mask(
  216. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  217. ):
  218. """
  219. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  220. equal to the length of multimodal features. If the lengths are different, an error is raised.
  221. """
  222. if input_ids is None:
  223. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  224. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  225. )
  226. special_image_mask = special_image_mask.all(-1)
  227. else:
  228. special_image_mask = input_ids == self.config.image_token_id
  229. n_image_tokens = special_image_mask.sum()
  230. n_image_features = image_features.shape[0] * image_features.shape[1]
  231. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  232. torch_compilable_check(
  233. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  234. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  235. )
  236. return special_image_mask
  237. @merge_with_config_defaults
  238. @can_return_tuple
  239. @auto_docstring
  240. def forward(
  241. self,
  242. input_ids: torch.LongTensor | None = None,
  243. pixel_values: torch.FloatTensor | None = None,
  244. attention_mask: torch.Tensor | None = None,
  245. position_ids: torch.LongTensor | None = None,
  246. past_key_values: Cache | None = None,
  247. inputs_embeds: torch.FloatTensor | None = None,
  248. vision_feature_layer: int | list[int] | list[int] | None = None,
  249. use_cache: bool | None = None,
  250. image_sizes: torch.Tensor | None = None,
  251. **kwargs: Unpack[TransformersKwargs],
  252. ) -> tuple | Mistral3ModelOutputWithPast:
  253. if (input_ids is None) ^ (inputs_embeds is not None):
  254. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  255. if inputs_embeds is None:
  256. inputs_embeds = self.get_input_embeddings()(input_ids)
  257. if pixel_values is not None:
  258. image_features = self.get_image_features(
  259. pixel_values=pixel_values,
  260. vision_feature_layer=vision_feature_layer,
  261. image_sizes=image_sizes,
  262. return_dict=True,
  263. ).pooler_output
  264. image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  265. special_image_mask = self.get_placeholder_mask(
  266. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  267. )
  268. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  269. outputs = self.language_model(
  270. attention_mask=attention_mask,
  271. position_ids=position_ids,
  272. past_key_values=past_key_values,
  273. inputs_embeds=inputs_embeds,
  274. use_cache=use_cache,
  275. **kwargs,
  276. )
  277. return Mistral3ModelOutputWithPast(
  278. last_hidden_state=outputs.last_hidden_state,
  279. past_key_values=outputs.past_key_values,
  280. hidden_states=outputs.hidden_states,
  281. attentions=outputs.attentions,
  282. image_hidden_states=image_features if pixel_values is not None else None,
  283. )
  284. @auto_docstring(
  285. custom_intro="""
  286. The MISTRAL3 model which consists of a vision backbone and a language model.
  287. """
  288. )
  289. class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
  290. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  291. def __init__(self, config: Mistral3Config):
  292. super().__init__(config)
  293. self.model = Mistral3Model(config)
  294. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  295. self.post_init()
  296. def get_input_embeddings(self):
  297. return self.model.get_input_embeddings()
  298. def set_input_embeddings(self, value):
  299. self.model.set_input_embeddings(value)
  300. def get_output_embeddings(self) -> nn.Module:
  301. return self.lm_head
  302. @merge_with_config_defaults
  303. @can_return_tuple
  304. @auto_docstring
  305. def get_image_features(
  306. self,
  307. pixel_values: torch.FloatTensor,
  308. image_sizes: torch.Tensor,
  309. vision_feature_layer: int | list[int] | list[int] | None = None,
  310. **kwargs: Unpack[TransformersKwargs],
  311. ) -> tuple | BaseModelOutputWithPooling:
  312. return self.model.get_image_features(
  313. pixel_values=pixel_values,
  314. image_sizes=image_sizes,
  315. vision_feature_layer=vision_feature_layer,
  316. **kwargs,
  317. )
  318. @merge_with_config_defaults
  319. @can_return_tuple
  320. @auto_docstring
  321. def forward(
  322. self,
  323. input_ids: torch.LongTensor | None = None,
  324. pixel_values: torch.FloatTensor | None = None,
  325. attention_mask: torch.Tensor | None = None,
  326. position_ids: torch.LongTensor | None = None,
  327. past_key_values: Cache | None = None,
  328. inputs_embeds: torch.FloatTensor | None = None,
  329. labels: torch.LongTensor | None = None,
  330. use_cache: bool | None = None,
  331. logits_to_keep: int | torch.Tensor = 0,
  332. image_sizes: torch.Tensor | None = None,
  333. **kwargs: Unpack[TransformersKwargs],
  334. ) -> tuple | Mistral3CausalLMOutputWithPast:
  335. r"""
  336. Example:
  337. ```python
  338. >>> from PIL import Image
  339. >>> import httpx
  340. >>> from io import BytesIO
  341. >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
  342. >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
  343. >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
  344. >>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
  345. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  346. >>> with httpx.stream("GET", url) as response:
  347. ... image = Image.open(BytesIO(response.read()))
  348. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  349. >>> # Generate
  350. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  351. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  352. "What is the image?The image depicts two cats lying on a pink blanket."
  353. ```"""
  354. outputs = self.model(
  355. input_ids=input_ids,
  356. pixel_values=pixel_values,
  357. attention_mask=attention_mask,
  358. position_ids=position_ids,
  359. past_key_values=past_key_values,
  360. inputs_embeds=inputs_embeds,
  361. use_cache=use_cache,
  362. image_sizes=image_sizes,
  363. **kwargs,
  364. )
  365. hidden_states = outputs[0]
  366. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  367. logits = self.lm_head(hidden_states[:, slice_indices, :])
  368. loss = None
  369. if labels is not None:
  370. loss = self.loss_function(
  371. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  372. )
  373. return Mistral3CausalLMOutputWithPast(
  374. loss=loss,
  375. logits=logits,
  376. past_key_values=outputs.past_key_values,
  377. hidden_states=outputs.hidden_states,
  378. attentions=outputs.attentions,
  379. image_hidden_states=outputs.image_hidden_states,
  380. )
  381. def prepare_inputs_for_generation(
  382. self,
  383. input_ids,
  384. past_key_values=None,
  385. inputs_embeds=None,
  386. pixel_values=None,
  387. attention_mask=None,
  388. logits_to_keep=None,
  389. is_first_iteration=False,
  390. **kwargs,
  391. ):
  392. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  393. model_inputs = super().prepare_inputs_for_generation(
  394. input_ids,
  395. past_key_values=past_key_values,
  396. inputs_embeds=inputs_embeds,
  397. attention_mask=attention_mask,
  398. logits_to_keep=logits_to_keep,
  399. is_first_iteration=is_first_iteration,
  400. **kwargs,
  401. )
  402. if is_first_iteration or not kwargs.get("use_cache", True):
  403. # Pixel values are used only in the first iteration if available
  404. # In subsequent iterations, they are already merged with text and cached
  405. # NOTE: first iteration doesn't have to be prefill, it can be the first
  406. # iteration with a question and cached system prompt (continue generate from cache)
  407. model_inputs["pixel_values"] = pixel_values
  408. return model_inputs
  409. __all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]