modeling_voxtral.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/voxtral/modular_voxtral.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_voxtral.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache
  26. from ...generation import GenerationMixin
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  32. from ...utils.generic import merge_with_config_defaults
  33. from ...utils.output_capturing import capture_outputs
  34. from ..auto import AutoModel, AutoModelForCausalLM
  35. from .configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig
  36. logger = logging.get_logger(__name__)
  37. def eager_attention_forward(
  38. module: nn.Module,
  39. query: torch.Tensor,
  40. key: torch.Tensor,
  41. value: torch.Tensor,
  42. attention_mask: torch.Tensor | None,
  43. scaling: float | None = None,
  44. dropout: float = 0.0,
  45. **kwargs,
  46. ):
  47. if scaling is None:
  48. scaling = query.size(-1) ** -0.5
  49. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  50. if attention_mask is not None:
  51. attn_weights = attn_weights + attention_mask
  52. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  53. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  54. attn_output = torch.matmul(attn_weights, value)
  55. attn_output = attn_output.transpose(1, 2).contiguous()
  56. return attn_output, attn_weights
  57. class VoxtralAttention(nn.Module):
  58. """Multi-headed attention from 'Attention Is All You Need' paper"""
  59. def __init__(
  60. self,
  61. embed_dim: int,
  62. num_heads: int,
  63. dropout: float = 0.0,
  64. is_decoder: bool = False,
  65. bias: bool = True,
  66. is_causal: bool = False,
  67. layer_idx: int | None = None,
  68. config: VoxtralConfig | None = None,
  69. ):
  70. super().__init__()
  71. self.embed_dim = embed_dim
  72. self.num_heads = num_heads
  73. self.dropout = dropout
  74. self.head_dim = embed_dim // num_heads
  75. self.config = config
  76. if (self.head_dim * num_heads) != self.embed_dim:
  77. raise ValueError(
  78. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  79. f" and `num_heads`: {num_heads})."
  80. )
  81. self.scaling = self.head_dim**-0.5
  82. self.is_decoder = is_decoder
  83. self.is_causal = is_causal
  84. if layer_idx is None and is_decoder:
  85. logger.warning_once(
  86. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  87. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  88. "when creating this class."
  89. )
  90. self.layer_idx = layer_idx
  91. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
  92. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  93. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  94. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  95. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  96. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  97. def forward(
  98. self,
  99. hidden_states: torch.Tensor,
  100. attention_mask: torch.Tensor | None = None,
  101. output_attentions: bool = False,
  102. **kwargs,
  103. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  104. """Input shape: Batch x Time x Channel"""
  105. bsz, tgt_len, _ = hidden_states.size()
  106. # Scaling is susceptible to floating point arithmetics' inprecisions
  107. # which can lead to different results (this is dependent from model
  108. # to model, e.g. whisper is one such case). We therefore keep the
  109. # original order of scaling to follow the original implementation
  110. # and enforce no scaling (1.0) in the attention call below.
  111. query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
  112. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  113. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  114. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  115. self.config._attn_implementation, eager_attention_forward
  116. )
  117. attn_output, attn_weights = attention_interface(
  118. self,
  119. query_states,
  120. key_states,
  121. value_states,
  122. attention_mask,
  123. dropout=0.0 if not self.training else self.dropout,
  124. scaling=1.0,
  125. output_attentions=output_attentions,
  126. **kwargs,
  127. )
  128. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  129. attn_output = self.out_proj(attn_output)
  130. return attn_output, attn_weights
  131. class VoxtralEncoderLayer(GradientCheckpointingLayer):
  132. def __init__(self, config: VoxtralConfig):
  133. super().__init__()
  134. self.embed_dim = config.d_model
  135. self.self_attn = VoxtralAttention(
  136. embed_dim=self.embed_dim,
  137. num_heads=config.encoder_attention_heads,
  138. dropout=config.attention_dropout,
  139. config=config,
  140. )
  141. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  142. self.dropout = config.dropout
  143. self.activation_fn = ACT2FN[config.activation_function]
  144. self.activation_dropout = config.activation_dropout
  145. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  146. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  147. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  148. def forward(
  149. self,
  150. hidden_states: torch.Tensor,
  151. attention_mask: torch.Tensor,
  152. **kwargs: Unpack[TransformersKwargs],
  153. ) -> torch.Tensor:
  154. """
  155. Args:
  156. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  157. attention_mask (`torch.FloatTensor`): attention mask of size
  158. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  159. """
  160. residual = hidden_states
  161. hidden_states = self.self_attn_layer_norm(hidden_states)
  162. hidden_states, _ = self.self_attn(
  163. hidden_states=hidden_states,
  164. attention_mask=attention_mask,
  165. **kwargs,
  166. )
  167. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  168. hidden_states = residual + hidden_states
  169. residual = hidden_states
  170. hidden_states = self.final_layer_norm(hidden_states)
  171. hidden_states = self.activation_fn(self.fc1(hidden_states))
  172. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  173. hidden_states = self.fc2(hidden_states)
  174. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  175. hidden_states = residual + hidden_states
  176. if hidden_states.dtype == torch.float16:
  177. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  178. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  179. return hidden_states
  180. @auto_docstring
  181. class VoxtralPreTrainedModel(PreTrainedModel):
  182. config: VoxtralConfig
  183. base_model_prefix = "model"
  184. input_modalities = ("audio", "text")
  185. supports_gradient_checkpointing = True
  186. _no_split_modules = None
  187. _skip_keys_device_placement = "past_key_values"
  188. _supports_flash_attn = True
  189. _supports_sdpa = True
  190. _supports_flex_attn = True
  191. _supports_cache_class = True
  192. _supports_attention_backend = True
  193. _can_compile_fullgraph = True
  194. @auto_docstring(
  195. custom_intro="""
  196. The Voxtral encoder, which is a Whisper encoder.
  197. """
  198. )
  199. class VoxtralEncoder(VoxtralPreTrainedModel):
  200. """
  201. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  202. [`VoxtralEncoderLayer`].
  203. Args:
  204. config: VoxtralEncoderConfig
  205. """
  206. # Ignore copy
  207. config: VoxtralEncoderConfig
  208. main_input_name = "input_features"
  209. input_modalities = "audio"
  210. _no_split_modules = ["VoxtralEncoderLayer"]
  211. _can_record_outputs = {
  212. "attentions": VoxtralAttention,
  213. "hidden_states": VoxtralEncoderLayer,
  214. }
  215. def __init__(self, config: VoxtralEncoderConfig):
  216. super().__init__(config)
  217. self.dropout = config.dropout
  218. self.layerdrop = config.encoder_layerdrop
  219. embed_dim = config.d_model
  220. self.num_mel_bins = config.num_mel_bins
  221. self.max_source_positions = config.max_source_positions
  222. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  223. self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
  224. self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
  225. self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
  226. self.embed_positions.requires_grad_(False)
  227. self.layers = nn.ModuleList([VoxtralEncoderLayer(config) for _ in range(config.encoder_layers)])
  228. self.layer_norm = nn.LayerNorm(config.d_model)
  229. # Ignore copy
  230. self.avg_pooler = nn.AvgPool1d(2, stride=2)
  231. self.gradient_checkpointing = False
  232. # Initialize weights and apply final processing
  233. self.post_init()
  234. def _freeze_parameters(self):
  235. for param in self.parameters():
  236. param.requires_grad = False
  237. self._requires_grad = False
  238. def get_input_embeddings(self) -> nn.Module:
  239. return self.conv1
  240. def set_input_embeddings(self, value: nn.Module):
  241. self.conv1 = value
  242. @merge_with_config_defaults
  243. @capture_outputs
  244. def forward(
  245. self,
  246. input_features,
  247. attention_mask=None,
  248. **kwargs: Unpack[TransformersKwargs],
  249. ) -> tuple | BaseModelOutputWithPooling:
  250. r"""
  251. Args:
  252. input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
  253. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  254. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
  255. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  256. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  257. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  258. attention_mask (`torch.Tensor`)`, *optional*):
  259. Voxtral does not support masking of the `input_features`, this argument is preserved for compatibility,
  260. but it is not used. By default the silence in the input log mel spectrogram are ignored.
  261. """
  262. expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
  263. if input_features.shape[-1] != expected_seq_length:
  264. raise ValueError(
  265. f"Voxtral expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
  266. )
  267. input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
  268. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  269. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  270. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  271. embed_pos = self.embed_positions.weight
  272. hidden_states = (inputs_embeds + embed_pos).to(inputs_embeds.dtype)
  273. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  274. for idx, encoder_layer in enumerate(self.layers):
  275. hidden_states = encoder_layer(
  276. hidden_states,
  277. attention_mask=attention_mask,
  278. )
  279. hidden_states = self.layer_norm(hidden_states)
  280. return BaseModelOutputWithPooling(
  281. last_hidden_state=hidden_states,
  282. )
  283. # Ignore copy
  284. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
  285. """
  286. Computes the output length of the convolutional layers and the output length of the audio encoder
  287. """
  288. input_lengths = (input_lengths - 1) // 2 + 1
  289. output_lengths = (input_lengths - 2) // 2 + 1
  290. return input_lengths, output_lengths
  291. class VoxtralMultiModalProjector(nn.Module):
  292. def __init__(self, config: VoxtralConfig):
  293. super().__init__()
  294. self.linear_1 = nn.Linear(config.audio_config.intermediate_size, config.text_config.hidden_size, bias=False)
  295. self.act = ACT2FN[config.projector_hidden_act]
  296. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=False)
  297. def forward(self, audio_features):
  298. hidden_states = self.linear_1(audio_features)
  299. hidden_states = self.act(hidden_states)
  300. hidden_states = self.linear_2(hidden_states)
  301. return hidden_states
  302. @auto_docstring(
  303. custom_intro="""
  304. The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model.
  305. """
  306. )
  307. class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
  308. _keep_in_fp32_modules_strict = ["embed_positions"]
  309. def __init__(self, config):
  310. super().__init__(config)
  311. self.vocab_size = config.text_config.vocab_size
  312. self.audio_tower = AutoModel.from_config(config.audio_config)
  313. self.language_model = AutoModelForCausalLM.from_config(config.text_config)
  314. self.multi_modal_projector = VoxtralMultiModalProjector(config)
  315. # Initialize weights and apply final processing
  316. self.post_init()
  317. def get_input_embeddings(self):
  318. return self.language_model.get_input_embeddings()
  319. def set_input_embeddings(self, value):
  320. self.language_model.set_input_embeddings(value)
  321. def get_output_embeddings(self):
  322. return self.language_model.get_output_embeddings()
  323. def set_output_embeddings(self, new_embeddings):
  324. self.language_model.set_output_embeddings(new_embeddings)
  325. def set_decoder(self, decoder):
  326. self.language_model.set_decoder(decoder)
  327. def get_decoder(self):
  328. return self.language_model.get_decoder()
  329. @can_return_tuple
  330. @auto_docstring(
  331. custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
  332. )
  333. def get_audio_features(
  334. self, input_features: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  335. ) -> tuple | BaseModelOutputWithPooling:
  336. r"""
  337. input_features (`torch.FloatTensor`):
  338. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  339. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
  340. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  341. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  342. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  343. """
  344. audio_outputs = self.audio_tower(input_features, return_dict=True, **kwargs)
  345. audio_hidden_states = audio_outputs.last_hidden_state
  346. audio_hidden_states = audio_hidden_states.reshape(-1, self.config.audio_config.intermediate_size)
  347. audio_embeds = self.multi_modal_projector(audio_hidden_states)
  348. audio_outputs.pooler_output = audio_embeds
  349. return audio_outputs
  350. @can_return_tuple
  351. @auto_docstring
  352. def forward(
  353. self,
  354. input_ids: torch.LongTensor | None = None,
  355. input_features: torch.FloatTensor | None = None,
  356. attention_mask: torch.Tensor | None = None,
  357. position_ids: torch.LongTensor | None = None,
  358. past_key_values: Cache | None = None,
  359. inputs_embeds: torch.FloatTensor | None = None,
  360. labels: torch.LongTensor | None = None,
  361. use_cache: bool | None = None,
  362. logits_to_keep: int | torch.Tensor = 0,
  363. **kwargs: Unpack[TransformersKwargs],
  364. ) -> CausalLMOutputWithPast:
  365. r"""
  366. Example:
  367. ```python
  368. >>> from transformers import VoxtralForConditionalGeneration, AutoProcessor
  369. >>> import torch
  370. >>> device = "cuda" if torch.cuda.is_available() else "cpu"
  371. >>> repo_id = "mistralai/Voxtral-Mini-3B-2507"
  372. >>> processor = AutoProcessor.from_pretrained(repo_id)
  373. >>> model = VoxtralForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map=device)
  374. >>> conversation = [
  375. {
  376. "role": "user",
  377. "content": [
  378. {
  379. "type": "audio",
  380. "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/dude_where_is_my_car.wav",
  381. },
  382. {"type": "text", "text": "What can you tell me about this audio?"},
  383. ],
  384. }
  385. ]
  386. >>> inputs = processor.apply_chat_template(conversation)
  387. >>> inputs = inputs.to(device, dtype=torch.bfloat16)
  388. >>> outputs = model.generate(**inputs, max_new_tokens=30)
  389. >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
  390. ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."]
  391. ```"""
  392. if inputs_embeds is None:
  393. inputs_embeds = self.get_input_embeddings()(input_ids)
  394. if input_features is not None and input_ids is not None:
  395. audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output
  396. # replace text-audio token placeholders with audio embeddings
  397. audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
  398. inputs_embeds = inputs_embeds.masked_scatter(
  399. audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
  400. )
  401. outputs: BaseModelOutputWithPast = self.language_model(
  402. attention_mask=attention_mask,
  403. position_ids=position_ids,
  404. past_key_values=past_key_values,
  405. inputs_embeds=inputs_embeds,
  406. labels=labels,
  407. use_cache=use_cache,
  408. logits_to_keep=logits_to_keep,
  409. **kwargs,
  410. )
  411. return outputs
  412. def prepare_inputs_for_generation(self, *args, **kwargs):
  413. # Overwritten -- we should not pass input_features when we are in cached decoding stage
  414. input_features = kwargs.pop("input_features", None)
  415. is_first_iteration = kwargs.get("is_first_iteration", False)
  416. model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
  417. if is_first_iteration or not kwargs.get("use_cache", True):
  418. # input_features should only be passed when we are not in cached decoding stage
  419. model_inputs["input_features"] = input_features
  420. return model_inputs
  421. __all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"]