modular_voxtral.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from torch import nn
  16. from ...activations import ACT2FN
  17. from ...cache_utils import Cache
  18. from ...generation import GenerationMixin
  19. from ...modeling_outputs import (
  20. BaseModelOutputWithPast,
  21. BaseModelOutputWithPooling,
  22. CausalLMOutputWithPast,
  23. )
  24. from ...processing_utils import Unpack
  25. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  26. from ...utils.generic import merge_with_config_defaults
  27. from ...utils.output_capturing import capture_outputs
  28. from ..auto import AutoModel, AutoModelForCausalLM
  29. from ..qwen2_audio.modeling_qwen2_audio import (
  30. Qwen2AudioAttention,
  31. Qwen2AudioEncoder,
  32. Qwen2AudioEncoderLayer,
  33. Qwen2AudioPreTrainedModel,
  34. )
  35. from .configuration_voxtral import VoxtralConfig
  class VoxtralAttention(Qwen2AudioAttention):
      # No overrides: the Qwen2Audio attention implementation is inherited unchanged,
      # this subclass only gives it a Voxtral-specific class name.
      pass
  class VoxtralEncoderLayer(Qwen2AudioEncoderLayer):
      # No overrides: the Qwen2Audio encoder layer is inherited unchanged,
      # this subclass only gives it a Voxtral-specific class name.
      pass
  class VoxtralPreTrainedModel(Qwen2AudioPreTrainedModel):
      # Class-level settings overriding whatever Qwen2AudioPreTrainedModel defines.
      # NOTE(review): these are presumably consumed by the shared PreTrainedModel
      # machinery (device placement, attention backend selection, compilation) --
      # confirm the exact semantics against the base class documentation.
      _no_split_modules = None
      _can_compile_fullgraph = True
      _supports_attention_backend = True
      _supports_cache_class = True
      _supports_flex_attn = True
  46. # TODO: @eustlb, I would really prefer to use WhisperEncoder but it's messing with modular
  @auto_docstring(
      custom_intro="""
      The Voxtral encoder, which is a Whisper encoder.
      """
  )
  class VoxtralEncoder(Qwen2AudioEncoder):
      # Maps each recordable output name to the module class whose outputs are captured for it.
      _can_record_outputs = {
          "attentions": VoxtralAttention,
          "hidden_states": VoxtralEncoderLayer,
      }

      @merge_with_config_defaults
      @capture_outputs
      def forward(
          self,
          input_features,
          attention_mask=None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          Args:
              input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
                  Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                  obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
                  `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                  `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                  and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
              attention_mask (`torch.Tensor`, *optional*):
                  Voxtral does not support masking of the `input_features`, this argument is preserved for compatibility,
                  but it is not used. By default the silence in the input log mel spectrogram is ignored.
          """
          # The two convolutions downsample the time axis by their strides, so the mel input
          # must be exactly `max_source_positions` frames long once both strides are undone.
          expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
          if input_features.shape[-1] != expected_seq_length:
              raise ValueError(
                  f"Voxtral expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
              )

          # Align the input with the dtype/device of the first convolution's weights.
          input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)

          # Convolutional stem with GELU activations, then move the channel axis last.
          inputs_embeds = nn.functional.gelu(self.conv1(input_features))
          inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
          inputs_embeds = inputs_embeds.permute(0, 2, 1)

          # Add the full positional table and cast the sum back to the conv output dtype: the
          # table may be stored in a different dtype (VoxtralForConditionalGeneration lists
          # `embed_positions` in `_keep_in_fp32_modules_strict`).
          embed_pos = self.embed_positions.weight
          hidden_states = (inputs_embeds + embed_pos).to(inputs_embeds.dtype)
          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

          # NOTE(review): `attention_mask` is forwarded to every layer although the docstring
          # states it is unused -- confirm how the inherited encoder layer treats it.
          for encoder_layer in self.layers:
              hidden_states = encoder_layer(
                  hidden_states,
                  attention_mask=attention_mask,
              )

          hidden_states = self.layer_norm(hidden_states)

          # Only `last_hidden_state` is populated here.
          return BaseModelOutputWithPooling(
              last_hidden_state=hidden_states,
          )
  class VoxtralMultiModalProjector(nn.Module):
      """Bias-free two-layer MLP projecting audio features to the text model's hidden size."""

      def __init__(self, config: VoxtralConfig):
          super().__init__()
          audio_dim = config.audio_config.intermediate_size
          text_dim = config.text_config.hidden_size

          # Keep the creation order linear_1 -> act -> linear_2 so module registration is unchanged.
          self.linear_1 = nn.Linear(audio_dim, text_dim, bias=False)
          self.act = ACT2FN[config.projector_hidden_act]
          self.linear_2 = nn.Linear(text_dim, text_dim, bias=False)

      def forward(self, audio_features):
          """Apply linear_1, the configured activation, then linear_2 to `audio_features`."""
          projected = self.linear_1(audio_features)
          return self.linear_2(self.act(projected))
  @auto_docstring(
      custom_intro="""
      The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model.
      """
  )
  class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
      # Name of the submodule(s) that must stay in fp32 regardless of the requested dtype.
      _keep_in_fp32_modules_strict = ["embed_positions"]

      def __init__(self, config):
          super().__init__(config)
          self.vocab_size = config.text_config.vocab_size
          # Audio encoder and text decoder are both built from their sub-configs via the Auto classes.
          self.audio_tower = AutoModel.from_config(config.audio_config)
          self.language_model = AutoModelForCausalLM.from_config(config.text_config)
          self.multi_modal_projector = VoxtralMultiModalProjector(config)

          # Initialize weights and apply final processing
          self.post_init()

      # The embedding/decoder accessors below simply delegate to the wrapped language model.
      def get_input_embeddings(self):
          return self.language_model.get_input_embeddings()

      def set_input_embeddings(self, value):
          self.language_model.set_input_embeddings(value)

      def get_output_embeddings(self):
          return self.language_model.get_output_embeddings()

      def set_output_embeddings(self, new_embeddings):
          self.language_model.set_output_embeddings(new_embeddings)

      def set_decoder(self, decoder):
          self.language_model.set_decoder(decoder)

      def get_decoder(self):
          return self.language_model.get_decoder()

      @can_return_tuple
      @auto_docstring(
          custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
      )
      def get_audio_features(
          self, input_features: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          input_features (`torch.FloatTensor`):
              Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
              obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
              `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
              `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
              and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
          """
          audio_outputs = self.audio_tower(input_features, return_dict=True, **kwargs)

          # Flatten the encoder output into rows of `audio_config.intermediate_size` features,
          # which is the input width of the projector's first linear layer.
          audio_hidden_states = audio_outputs.last_hidden_state
          audio_hidden_states = audio_hidden_states.reshape(-1, self.config.audio_config.intermediate_size)
          audio_embeds = self.multi_modal_projector(audio_hidden_states)

          # The projected embeddings are exposed through the `pooler_output` field of the encoder output.
          audio_outputs.pooler_output = audio_embeds
          return audio_outputs

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          input_features: torch.FloatTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          r"""
          Example:

          ```python
          >>> from transformers import VoxtralForConditionalGeneration, AutoProcessor
          >>> import torch

          >>> device = "cuda" if torch.cuda.is_available() else "cpu"
          >>> repo_id = "mistralai/Voxtral-Mini-3B-2507"

          >>> processor = AutoProcessor.from_pretrained(repo_id)
          >>> model = VoxtralForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map=device)

          >>> conversation = [
              {
                  "role": "user",
                  "content": [
                      {
                          "type": "audio",
                          "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/dude_where_is_my_car.wav",
                      },
                      {"type": "text", "text": "What can you tell me about this audio?"},
                  ],
              }
          ]

          >>> inputs = processor.apply_chat_template(conversation)
          >>> inputs = inputs.to(device, dtype=torch.bfloat16)

          >>> outputs = model.generate(**inputs, max_new_tokens=30)
          >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
          ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."]
          ```"""
          if inputs_embeds is None:
              # Fail with an explicit message instead of calling the embedding layer on `None`.
              if input_ids is None:
                  raise ValueError("You must specify at least one of `input_ids` or `inputs_embeds`.")
              inputs_embeds = self.get_input_embeddings()(input_ids)

          # Audio can only be merged when `input_ids` is available, because the placeholder
          # positions are located by comparing `input_ids` with `config.audio_token_id`.
          if input_features is not None and input_ids is not None:
              audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output

              # replace text-audio token placeholders with audio embeddings
              audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
              inputs_embeds = inputs_embeds.masked_scatter(
                  audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
              )

          # Only `inputs_embeds` is handed to the language model (never `input_ids`), so the
          # merged audio/text embeddings are what the decoder sees.
          outputs: CausalLMOutputWithPast = self.language_model(
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              labels=labels,
              use_cache=use_cache,
              logits_to_keep=logits_to_keep,
              **kwargs,
          )
          return outputs

      def prepare_inputs_for_generation(self, *args, **kwargs):
          # Overwritten -- we should not pass input_features when we are in cached decoding stage

          # Pull `input_features` out so the parent implementation never sees it.
          input_features = kwargs.pop("input_features", None)
          is_first_iteration = kwargs.get("is_first_iteration", False)

          model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

          if is_first_iteration or not kwargs.get("use_cache", True):
              # input_features should only be passed when we are not in cached decoding stage
              model_inputs["input_features"] = input_features

          return model_inputs
  # Public API of this module.
  __all__ = [
      "VoxtralPreTrainedModel",
      "VoxtralEncoder",
      "VoxtralForConditionalGeneration",
  ]