| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/musicflamingo/modular_musicflamingo.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_musicflamingo.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
- # reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Callable
- from math import pi
- from typing import Optional
- from torch import Tensor, broadcast_tensors, nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache
- from ...generation import GenerationMixin
- from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
- from ..auto import AutoModel, AutoModelForCausalLM
- from .configuration_musicflamingo import MusicFlamingoConfig
- if is_torch_available():
- import torch
class MusicFlamingoRotaryEmbedding(nn.Module):
    """Axial rotary time embedding applied to MusicFlamingo audio encoder outputs.

    The module reproduces what the released checkpoints expect; it is not a direct implementation of the RoTE
    formulation described in (Goel et al., 2024): https://arxiv.org/abs/2410.12109. Two axes are embedded and
    concatenated: the index of a window within its audio sample, and the encoder time index inside that window.
    Both axes are then modulated by the absolute timestamps (in seconds) passed to `forward`.
    """

    inv_freq: torch.Tensor  # declared for linters; the value is installed through `register_buffer`

    def __init__(self, config: MusicFlamingoConfig, device=None):
        super().__init__()
        self.config = config

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        # Pick the inverse-frequency initializer: the local default, or a registered variant.
        self.rope_type = config.rope_parameters["rope_type"]
        if self.rope_type == "default":
            rope_init_fn = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # All buffers are non-persistent: they are recomputed rather than stored in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        self.register_buffer("position_angles", self._compute_position_angles(self.inv_freq), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: MusicFlamingoConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Build the inverse frequencies of the original (unscaled) RoPE formulation.

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. `rope_parameters["rope_theta"]` is the base; the rotary dimension is
                `head_dim` (or `hidden_size // num_attention_heads` when `head_dim` is missing or falsy) scaled
                by the optional `rope_parameters["partial_rotary_factor"]`.
            device (`torch.device`):
                The device on which the inverse frequencies are created.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.

        Returns:
            Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the post-processing scaling factor
            for cos/sin, which is always `1.0` here.
        """
        rope_parameters = config.rope_parameters
        theta = rope_parameters["rope_theta"]
        rotary_fraction = rope_parameters.get("partial_rotary_factor", 1.0)

        head_dim = getattr(config, "head_dim", None)
        if not head_dim:
            head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = int(head_dim * rotary_fraction)

        # One frequency per pair of channels: theta ** (-2i / rotary_dim).
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / rotary_dim
        inverse_frequencies = 1.0 / (theta**exponents)
        return inverse_frequencies, 1.0

    @torch.no_grad()
    def forward(self, timestamps: Tensor, seq_len: int) -> tuple[Tensor, Tensor]:
        """Return the `(cos, sin)` tables of the 2D axial (window, time) rotary embedding."""
        inv_freq = self.inv_freq

        # Window axis. One encoder output frame spans `audio_frame_step * 4`; the factor 4 accounts for the
        # downsampling in the audio encoder (conv2 and avg pooling).
        first_frame_times = timestamps[:, 0].to(device=inv_freq.device, dtype=inv_freq.dtype)
        seconds_per_window = self.config.audio_frame_step * 4 * seq_len
        window_slots = torch.round(first_frame_times / seconds_per_window) / self.max_seq_len_cached
        window_axis = torch.repeat_interleave(window_slots.unsqueeze(-1) * inv_freq, 2, dim=-1).unsqueeze(1)

        # Time axis: precomputed per-position angles, truncated to the current number of frames.
        time_axis = self.position_angles[:seq_len].unsqueeze(0)

        window_axis, time_axis = broadcast_tensors(window_axis, time_axis)
        axial_freqs = torch.cat((window_axis, time_axis), dim=-1)

        # Modulate both axes by the absolute timestamps.
        modulation = (-timestamps * 2 * pi).to(axial_freqs)
        axial_freqs = axial_freqs * modulation.unsqueeze(-1)
        return axial_freqs.cos(), axial_freqs.sin()

    def _compute_position_angles(self, inv_freq):
        """Angles for every cached position: `2*pi * p / max_seq_len_cached * inv_freq`, each repeated twice."""
        max_positions = self.max_seq_len_cached
        index = torch.arange(int(max_positions), device=inv_freq.device, dtype=inv_freq.dtype)
        phase = index / max_positions * (2 * pi)
        angles = torch.repeat_interleave(phase.unsqueeze(-1) * inv_freq, 2, dim=-1)
        return angles.to(dtype=inv_freq.dtype)
@auto_docstring
class MusicFlamingoPreTrainedModel(PreTrainedModel):
    # Shared configuration and capability flags for every MusicFlamingo model class.
    config: MusicFlamingoConfig
    base_model_prefix = "model"
    input_modalities = ("audio", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = None
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    @torch.no_grad()
    def _init_weights(self, module):
        # Run the parent initialization first, then handle the one module type that needs extra work.
        super()._init_weights(module)
        if not isinstance(module, MusicFlamingoRotaryEmbedding):
            return

        # `position_angles` is a non-persistent buffer, so it is rebuilt from `inv_freq` and copied in place.
        recomputed_angles = module._compute_position_angles(module.inv_freq)
        init.copy_(module.position_angles, recomputed_angles)
class MusicFlamingoMultiModalProjector(nn.Module):
    """
    Two-layer MLP adaptor that maps MusicFlamingoEncoder features into the LLM embedding space,
    so the projected features can take the place of `<sound>` tokens.
    """

    def __init__(self, config: MusicFlamingoConfig):
        super().__init__()
        audio_dim = config.audio_config.hidden_size
        text_dim = config.text_config.hidden_size
        use_bias = config.projector_bias

        # audio_dim -> text_dim -> text_dim, with the configured activation in between.
        self.linear_1 = nn.Linear(audio_dim, text_dim, bias=use_bias)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_dim, text_dim, bias=use_bias)

    def forward(self, audio_features):
        """Apply `linear_1`, the activation, then `linear_2` to `audio_features`."""
        return self.linear_2(self.act(self.linear_1(audio_features)))
def rotate_half(x):
    """Rotate adjacent channel pairs of the last dimension: `(a, b)` becomes `(-b, a)`.

    The last dimension is viewed as consecutive pairs, so it must have an even size. The output has the
    same shape as the input.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    return torch.stack((-second, first), dim=-1).flatten(-2)
def apply_rotary_time_emb(hidden_states, cos, sin):
    """Rotate the leading `cos.shape[-1]` channels of `hidden_states`; leave the remaining channels as they are.

    The computation is carried out in float64 and the result is cast back to the dtype of the input.
    """
    input_dtype = hidden_states.dtype
    states = hidden_states.to(torch.float64)
    cos, sin = cos.to(states), sin.to(states)

    # Only the first `rot_dim` channels take part in the rotation.
    rot_dim = cos.shape[-1]
    rotary_part, untouched_part = states[..., :rot_dim], states[..., rot_dim:]
    rotary_part = (rotary_part * cos) + (rotate_half(rotary_part) * sin)

    return torch.cat((rotary_part, untouched_part), dim=-1).to(input_dtype)
@auto_docstring(
    custom_intro="""
    The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model.
    """
)
class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin):
    _keep_in_fp32_modules_strict = None
    _tp_plan = None
    _pp_plan = None

    def __init__(self, config: MusicFlamingoConfig):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size

        # Sub-models are built from their own configs; the projector and rotary embedding use the top-level one.
        self.audio_tower = AutoModel.from_config(config.audio_config)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
        self.multi_modal_projector = MusicFlamingoMultiModalProjector(config)
        self.pos_emb = MusicFlamingoRotaryEmbedding(config)

        # Initialize weights and apply final processing
        self.post_init()

    # --- Embedding / decoder accessors: all of them delegate to the wrapped language model. ---

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    @can_return_tuple
    @auto_docstring(
        custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
    )
    def get_audio_features(
        self,
        input_features: torch.FloatTensor,
        input_features_mask: torch.Tensor,
        input_ids: torch.LongTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
            Mask to avoid performing attention on padded feature indices.
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Token ids containing the audio token ID placeholders, for reconstructing rotary time embedding timestamps.
        """
        encoder_outputs = self.audio_tower(
            input_features,
            input_features_mask=input_features_mask,
            return_dict=True,
            **kwargs,
        )
        features = encoder_outputs.last_hidden_state
        num_frames = features.shape[-2]

        # Number of valid encoder output frames per window, derived from the feature mask.
        feature_lengths = input_features_mask.sum(-1).to(torch.long)
        _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(feature_lengths)

        # Rotary time embedding, then projection into the language model embedding space.
        timestamps = self._build_audio_timestamps(input_ids, post_lengths, num_frames)
        cos, sin = self.pos_emb(timestamps.to(features.device), seq_len=num_frames)
        features = apply_rotary_time_emb(features, cos, sin)
        projected = self.multi_modal_projector(features)

        # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling
        frame_index = torch.arange(projected.shape[1], device=post_lengths.device)
        valid_mask = frame_index[None, :] < post_lengths[:, None]
        encoder_outputs.pooler_output = projected[valid_mask.to(projected.device)]

        return encoder_outputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        input_features: torch.FloatTensor | None = None,
        input_features_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
            Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor

        >>> model_id = "nvidia/music-flamingo-2601-hf"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto")

        >>> conversation = [
        >>>     {
        >>>         "role": "user",
        >>>         "content": [
        >>>             {
        >>>                 "type": "text",
        >>>                 "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
        >>>             },
        >>>             {
        >>>                 "type": "audio",
        >>>                 "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
        >>>             },
        >>>         ],
        >>>     }
        >>> ]

        >>> inputs = processor.apply_chat_template(
        >>>     conversation,
        >>>     tokenize=True,
        >>>     add_generation_prompt=True,
        >>>     return_dict=True,
        >>> ).to(model.device, model.dtype)

        >>> outputs = model.generate(**inputs, max_new_tokens=100)

        >>> decoded_outputs = processor.batch_decode(
        >>>     outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        >>> )
        >>> print(decoded_outputs)
        ["This track is an uplifting Eurodance-style Trance-Pop anthem..."]
        ```"""
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        has_audio = input_features is not None and input_ids is not None
        if has_audio:
            audio_outputs = self.get_audio_features(
                input_features, input_features_mask, input_ids=input_ids, return_dict=True
            )
            audio_embeds = audio_outputs.pooler_output

            # replace text-audio token placeholders with audio embeddings
            placeholder_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
            target_device = inputs_embeds.device
            inputs_embeds = inputs_embeds.masked_scatter(
                placeholder_mask.to(target_device), audio_embeds.to(target_device)
            )

        lm_outputs: CausalLMOutputWithPast = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return lm_outputs

    def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
        # The audio inputs are withheld from the parent implementation and re-attached below only when needed.
        audio_inputs = {
            "input_features": kwargs.pop("input_features", None),
            "input_features_mask": kwargs.pop("input_features_mask", None),
        }

        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

        # Audio is forwarded on the first iteration, or on every iteration when `use_cache` is not enabled.
        needs_audio = is_first_iteration or not model_inputs.get("use_cache", False)
        if needs_audio:
            for key, value in audio_inputs.items():
                if value is not None:
                    model_inputs[key] = value

        return model_inputs

    def _build_audio_timestamps(
        self,
        input_ids: torch.LongTensor,
        post_lengths: torch.LongTensor,
        max_post_length: int,
    ) -> torch.FloatTensor:
        # Locate each contiguous run of audio placeholder tokens: after zero-padding both ends, a +1 step in
        # the mask marks the first token of a run and a -1 step marks the position just past its last token.
        is_audio_token = input_ids == self.config.audio_token_id
        padded_mask = torch.nn.functional.pad(is_audio_token.int(), (1, 1), value=0)
        edges = torch.diff(padded_mask, dim=1)
        _, run_starts = torch.where(edges == 1)
        _, run_ends = torch.where(edges == -1)
        tokens_per_sample = (run_ends - run_starts).to(torch.long)

        device = post_lengths.device

        # Account for 4x downsampling in audio encoder (conv2 and avg pooling)
        embed_frame_step = self.config.audio_frame_step * 4
        frame_offsets = torch.arange(max_post_length, device=device, dtype=torch.float32) * embed_frame_step

        # Map each encoder output row to its audio sample using token counts
        tokens_before_window = torch.cat([torch.zeros(1, device=device), torch.cumsum(post_lengths, dim=0)[:-1]])
        tokens_through_sample = torch.cumsum(tokens_per_sample, dim=0)
        sample_indices = torch.searchsorted(tokens_through_sample, tokens_before_window, right=True)

        # Compute window index within each sample (0, 1, 2, ... then reset for next sample)
        sample_ids = torch.arange(tokens_per_sample.shape[0], device=device)
        first_row_of_sample = torch.searchsorted(sample_indices, sample_ids)
        row_ids = torch.arange(post_lengths.shape[0], device=device)
        window_indices = row_ids - first_row_of_sample[sample_indices]

        # Compute timestamps
        return window_indices.unsqueeze(1) * max_post_length * embed_frame_step + frame_offsets
# Names exported by `from ... import *`.
__all__ = [
    "MusicFlamingoForConditionalGeneration",
    "MusicFlamingoPreTrainedModel",
]
|