| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/voxtral/modular_voxtral.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_voxtral.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable
- import torch
- from torch import nn
- from ...activations import ACT2FN
- from ...cache_utils import Cache
- from ...generation import GenerationMixin
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from ..auto import AutoModel, AutoModelForCausalLM
- from .configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig
# Module-level logger obtained from the library's logging helper, keyed on this module's name.
logger = logging.get_logger(name=__name__)
- def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float | None = None,
- dropout: float = 0.0,
- **kwargs,
- ):
- if scaling is None:
- scaling = query.size(-1) ** -0.5
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
class VoxtralAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    The key projection is created without a bias; the query, value and output projections use `bias`.
    The attention kernel is selected from `config._attn_implementation`, falling back to
    `eager_attention_forward` when no config was supplied.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        layer_idx: int | None = None,
        config: VoxtralConfig | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        if layer_idx is None and is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.layer_idx = layer_idx

        # Creation order is kept as-is: it fixes parameter registration order and random-init order.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the last dim into heads: `(bsz, seq_len, embed_dim)` -> `(bsz, num_heads, seq_len, head_dim)`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel

        Returns:
            `(attn_output, attn_weights)`: `attn_output` has shape `(batch, tgt_len, embed_dim)`;
            `attn_weights` is whatever the selected attention backend returns (it may be `None`).
        """
        bsz, tgt_len, _ = hidden_states.size()

        # Scaling is susceptible to floating point arithmetics' imprecisions
        # which can lead to different results (this is dependent from model
        # to model, e.g. whisper is one such case). We therefore keep the
        # original order of scaling to follow the original implementation
        # and enforce no scaling (1.0) in the attention call below.
        query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # `config` is optional in __init__; without it there is no `_attn_implementation` to dispatch on,
        # so use the eager kernel instead of failing with an AttributeError on `None`.
        if self.config is None:
            attention_interface: Callable = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
                self.config._attn_implementation, eager_attention_forward
            )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=1.0,
            output_attentions=output_attentions,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
class VoxtralEncoderLayer(GradientCheckpointingLayer):
    """A single pre-LayerNorm encoder block: self-attention, then a two-layer feed-forward network.

    Each sub-block normalises its input, applies dropout to its output and adds it back onto the
    residual stream. In float16 the result is clamped just below the dtype's maximum magnitude.
    """

    def __init__(self, config: VoxtralConfig):
        super().__init__()
        self.embed_dim = dim = config.d_model
        # Submodules are created in a fixed order (state-dict keys and init order depend on it).
        self.self_attn = VoxtralAttention(
            embed_dim=dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, dim)
        self.final_layer_norm = nn.LayerNorm(dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): mask forwarded unchanged to self-attention, of size
                `(batch, 1, tgt_len, src_len)`, where padded positions hold very large negative values.
        """
        drop = nn.functional.dropout

        # Self-attention sub-block (normalise -> attend -> dropout -> residual add).
        attn_out, _ = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + drop(attn_out, p=self.dropout, training=self.training)

        # Feed-forward sub-block (normalise -> fc1 -> activation -> dropout -> fc2 -> dropout -> residual add).
        ffn_out = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn_out = drop(ffn_out, p=self.activation_dropout, training=self.training)
        ffn_out = drop(self.fc2(ffn_out), p=self.dropout, training=self.training)
        hidden_states = hidden_states + ffn_out

        # Keep float16 activations finite by clamping 1000 below the representable maximum.
        if hidden_states.dtype == torch.float16:
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = hidden_states.clamp(min=-limit, max=limit)
        return hidden_states
@auto_docstring
class VoxtralPreTrainedModel(PreTrainedModel):
    # Config type and base-model attribute prefix.
    config: VoxtralConfig
    base_model_prefix = "model"
    # Input modalities accepted by the model.
    input_modalities = ("audio", "text")

    # Training features.
    supports_gradient_checkpointing = True

    # Device placement hints.
    _no_split_modules = None
    _skip_keys_device_placement = "past_key_values"

    # Supported attention backends.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # Cache and compilation capabilities.
    _supports_cache_class = True
    _can_compile_fullgraph = True
@auto_docstring(
    custom_intro="""
    The Voxtral encoder, which is a Whisper encoder.
    """
)
class VoxtralEncoder(VoxtralPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`VoxtralEncoderLayer`].

    Args:
        config: VoxtralEncoderConfig
    """

    # Ignore copy
    config: VoxtralEncoderConfig
    main_input_name = "input_features"
    input_modalities = "audio"
    _no_split_modules = ["VoxtralEncoderLayer"]
    _can_record_outputs = {
        "attentions": VoxtralAttention,
        "hidden_states": VoxtralEncoderLayer,
    }

    def __init__(self, config: VoxtralEncoderConfig):
        super().__init__(config)
        self.dropout = config.dropout
        # Stored from the config but not applied anywhere in `forward` (every layer always runs).
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # Convolutional front-end: conv1 keeps the time length, conv2 halves it (stride 2).
        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

        # Positional table is frozen (never updated by the optimizer).
        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
        self.embed_positions.requires_grad_(False)

        self.layers = nn.ModuleList([VoxtralEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)
        # Ignore copy
        self.avg_pooler = nn.AvgPool1d(2, stride=2)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _freeze_parameters(self):
        """Disable gradients for every parameter of the encoder."""
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def get_input_embeddings(self) -> nn.Module:
        return self.conv1

    def set_input_embeddings(self, value: nn.Module):
        self.conv1 = value

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_features,
        attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`].
                `sequence_length` must equal `config.max_source_positions * conv1.stride * conv2.stride`, otherwise a
                `ValueError` is raised.
            attention_mask (`torch.Tensor`, *optional*):
                Voxtral does not support masking of the `input_features`; this argument is preserved for
                compatibility. If given, it is forwarded unchanged to every encoder layer's self-attention, so it
                must already be in the format the configured attention implementation expects (e.g. an additive
                mask for eager attention). Leave it as `None` in normal use.
        """
        expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
        if input_features.shape[-1] != expected_seq_length:
            raise ValueError(
                f"Voxtral expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        # Run the convolutional front-end in the encoder's own dtype and on its device.
        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # Conv1d output is (batch, channels, time); the transformer layers work on (batch, time, channels).
        inputs_embeds = inputs_embeds.permute(0, 2, 1)

        # The length check above guarantees exactly `max_source_positions` frames after conv2,
        # so the full positional table lines up with the sequence without slicing.
        embed_pos = self.embed_positions.weight
        hidden_states = (inputs_embeds + embed_pos).to(inputs_embeds.dtype)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
            )

        hidden_states = self.layer_norm(hidden_states)

        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
        )

    # Ignore copy
    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers and the output length of the audio encoder

        Returns:
            `(input_lengths, output_lengths)`: the lengths after the stride-2 `conv2` (kernel 3, padding 1), and
            the lengths after a further kernel-2 / stride-2 average pooling.
        """
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        return input_lengths, output_lengths
class VoxtralMultiModalProjector(nn.Module):
    """Bias-free two-layer MLP that maps audio features into the text model's hidden size.

    `linear_1`: `audio_config.intermediate_size` -> `text_config.hidden_size`, then the
    `projector_hidden_act` activation, then a square `linear_2` in the text hidden size.
    """

    def __init__(self, config: VoxtralConfig):
        super().__init__()
        text_hidden = config.text_config.hidden_size
        self.linear_1 = nn.Linear(config.audio_config.intermediate_size, text_hidden, bias=False)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_hidden, text_hidden, bias=False)

    def forward(self, audio_features):
        return self.linear_2(self.act(self.linear_1(audio_features)))
@auto_docstring(
    custom_intro="""
    The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model.
    """
)
class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
    _keep_in_fp32_modules_strict = ["embed_positions"]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.audio_tower = AutoModel.from_config(config.audio_config)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
        self.multi_modal_projector = VoxtralMultiModalProjector(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Embedding / decoder accessors all delegate to the wrapped language model.
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    @can_return_tuple
    @auto_docstring(
        custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
    )
    def get_audio_features(
        self, input_features: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        input_features (`torch.FloatTensor`):
            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
            `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
            and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
        """
        audio_outputs = self.audio_tower(input_features, return_dict=True, **kwargs)
        audio_hidden_states = audio_outputs.last_hidden_state
        # Flatten the whole batch of encoder states into rows of `audio_config.intermediate_size`
        # (the projector's input width) before projecting into the text hidden size.
        audio_hidden_states = audio_hidden_states.reshape(-1, self.config.audio_config.intermediate_size)
        audio_embeds = self.multi_modal_projector(audio_hidden_states)
        # The projected audio embeddings are returned through `pooler_output`.
        audio_outputs.pooler_output = audio_embeds
        return audio_outputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        input_features: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import VoxtralForConditionalGeneration, AutoProcessor
        >>> import torch

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> repo_id = "mistralai/Voxtral-Mini-3B-2507"

        >>> processor = AutoProcessor.from_pretrained(repo_id)
        >>> model = VoxtralForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map=device)

        >>> conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "audio",
                        "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/dude_where_is_my_car.wav",
                    },
                    {"type": "text", "text": "What can you tell me about this audio?"},
                ],
            }
        ]

        >>> inputs = processor.apply_chat_template(conversation)
        >>> inputs = inputs.to(device, dtype=torch.bfloat16)

        >>> outputs = model.generate(**inputs, max_new_tokens=30)
        >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
        ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."]
        ```"""
        if input_ids is None and inputs_embeds is None:
            # Without either there is nothing to embed; fail with a clear message instead of an
            # opaque error from calling the embedding layer on `None`.
            raise ValueError("You must specify either `input_ids` or `inputs_embeds`.")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Audio placeholders are located by matching `config.audio_token_id` in `input_ids`, so audio is only
        # merged when `input_ids` is available; with `inputs_embeds` alone, `input_features` is not used here.
        if input_features is not None and input_ids is not None:
            audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output

            # replace text-audio token placeholders with audio embeddings
            # masked_scatter fills the masked positions in order from `audio_embeds`.
            # NOTE(review): the placeholder count is assumed to match the number of audio embedding rows;
            # this is not validated here.
            audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
            inputs_embeds = inputs_embeds.masked_scatter(
                audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
            )

        outputs: CausalLMOutputWithPast = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return outputs

    def prepare_inputs_for_generation(self, *args, **kwargs):
        # Overwritten -- we should not pass input_features when we are in cached decoding stage

        input_features = kwargs.pop("input_features", None)
        is_first_iteration = kwargs.get("is_first_iteration", False)

        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

        if is_first_iteration or not kwargs.get("use_cache", True):
            # input_features should only be passed when we are not in cached decoding stage
            model_inputs["input_features"] = input_features

        return model_inputs
# Public API of this module.
__all__ = [
    "VoxtralPreTrainedModel",
    "VoxtralEncoder",
    "VoxtralForConditionalGeneration",
]
|