# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/musicflamingo/modular_musicflamingo.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_musicflamingo.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable
from math import pi
from typing import Optional

from torch import Tensor, broadcast_tensors, nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_musicflamingo import MusicFlamingoConfig


if is_torch_available():
    import torch


class MusicFlamingoRotaryEmbedding(nn.Module):
    """Rotary time embedding module used by MusicFlamingo checkpoints.

    This is a checkpoint-faithful integration, not a direct implementation of the RoTE formulation described in
    (Goel et al., 2024): https://arxiv.org/abs/2410.12109. It applies axial rotary embeddings over the window index
    within each audio sample and the encoder time index within each window, then modulates both axes with absolute
    timestamps in seconds.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: MusicFlamingoConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Pick the inverse-frequency initializer: the local default, or a registered RoPE variant.
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: they are rebuilt from the config rather than loaded from checkpoints
        # (`position_angles` is recomputed in `MusicFlamingoPreTrainedModel._init_weights`).
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        position_angles = self._compute_position_angles(self.inv_freq)
        self.register_buffer("position_angles", position_angles, persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: MusicFlamingoConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    def forward(self, timestamps: Tensor, seq_len: int) -> tuple[Tensor, Tensor]:
        """Compute 2D axial rotary embeddings for window and time dimensions.

        Args:
            timestamps (`Tensor`):
                Per-frame timestamps, one row per encoder window; as built by
                `MusicFlamingoForConditionalGeneration._build_audio_timestamps` this has shape
                `(num_windows, seq_len)`. Column 0 holds the start time of each window.
            seq_len (`int`):
                Number of encoder output frames per window.

        Returns:
            `(cos, sin)`, each of shape `(num_windows, seq_len, 4 * len(inv_freq))`: the window-axis angles
            concatenated with the time-axis angles, both modulated by the timestamps.
        """
        # Window-axis frequencies. The x4 accounts for the downsampling in the audio encoder
        # (conv2 and avg pooling), so `window_duration` is the span covered by one window.
        window_starts = timestamps[:, 0].to(device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        window_duration = self.config.audio_frame_step * 4 * seq_len
        window_positions = torch.round(window_starts / window_duration) / self.max_seq_len_cached
        window_freqs = window_positions.unsqueeze(-1) * self.inv_freq
        window_freqs = torch.repeat_interleave(window_freqs, 2, dim=-1)

        # Broadcast both axes to (num_windows, seq_len, ...) and apply the time-based angle modulation.
        window_freqs = window_freqs[:, None, :]
        time_freqs = self.position_angles[:seq_len][None, :, :]
        window_freqs, time_freqs = broadcast_tensors(window_freqs, time_freqs)
        freqs = torch.cat((window_freqs, time_freqs), dim=-1)
        angle = (-timestamps * 2 * pi).to(freqs)
        freqs = freqs * angle.unsqueeze(-1)

        return freqs.cos(), freqs.sin()

    def _compute_position_angles(self, inv_freq):
        """Return the time-axis angle table of shape `(max_seq_len_cached, 2 * len(inv_freq))`.

        Positions `0..max_seq_len_cached-1` are mapped onto `[0, 2*pi)` and multiplied by each inverse
        frequency; every column is duplicated to match the interleaved layout used by `rotate_half`.
        """
        positions = torch.arange(int(self.max_seq_len_cached), device=inv_freq.device, dtype=inv_freq.dtype)
        positions = positions / self.max_seq_len_cached * (2 * pi)
        position_angles = positions.unsqueeze(-1) * inv_freq
        position_angles = torch.repeat_interleave(position_angles, 2, dim=-1)
        return position_angles.to(dtype=inv_freq.dtype)


@auto_docstring
class MusicFlamingoPreTrainedModel(PreTrainedModel):
    config: MusicFlamingoConfig
    base_model_prefix = "model"
    input_modalities = ("audio", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = None
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, MusicFlamingoRotaryEmbedding):
            # `position_angles` is a non-persistent buffer, so it is never loaded from a checkpoint and
            # must be rebuilt from `inv_freq` here.
            buffer_value = module._compute_position_angles(module.inv_freq)
            init.copy_(module.position_angles, buffer_value)


class MusicFlamingoMultiModalProjector(nn.Module):
    """
    Audio adaptor (small MLP) that projects MusicFlamingoEncoder features to the LLM embedding space so they can
    replace the audio placeholder tokens (`config.audio_token_id`) in the text embeddings.
    """

    def __init__(self, config: MusicFlamingoConfig):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.audio_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
        )
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
        )

    def forward(self, audio_features):
        """Map audio features (last dim `audio_config.hidden_size`) to `text_config.hidden_size`."""
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


def rotate_half(x):
    """Rotate adjacent channel pairs: `(x[2i], x[2i+1]) -> (-x[2i+1], x[2i])` (interleaved RoPE layout)."""
    x = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


def apply_rotary_time_emb(hidden_states, cos, sin):
    """Apply the rotary time embedding to the first `cos.shape[-1]` channels of `hidden_states`.

    The remaining channels pass through unchanged. The rotation is computed in float64 and the result is
    cast back to the input dtype.
    """
    original_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float64)
    cos = cos.to(hidden_states)
    sin = sin.to(hidden_states)
    rot_dim = cos.shape[-1]
    passthrough = hidden_states[..., rot_dim:]
    rotated = hidden_states[..., :rot_dim]
    rotated = (rotated * cos) + (rotate_half(rotated) * sin)
    return torch.cat((rotated, passthrough), dim=-1).to(original_dtype)


@auto_docstring(
    custom_intro="""
    The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model.
    """
)
class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin):
    _keep_in_fp32_modules_strict = None
    _tp_plan = None
    _pp_plan = None

    def __init__(self, config: MusicFlamingoConfig):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.audio_tower = AutoModel.from_config(config.audio_config)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
        self.multi_modal_projector = MusicFlamingoMultiModalProjector(config)
        self.pos_emb = MusicFlamingoRotaryEmbedding(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    @can_return_tuple
    @auto_docstring(
        custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
    )
    def get_audio_features(
        self,
        input_features: torch.FloatTensor,
        input_features_mask: torch.Tensor,
        input_ids: torch.LongTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
            Mask to avoid performing attention on padded feature indices.
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Token ids containing the audio token ID placeholders, for reconstructing rotary time embedding timestamps.
        """
        # NOTE(review): `input_features_mask` is required here (`.sum` is called on it) even though `forward`
        # documents it as optional — confirm callers always provide it.
        audio_output = self.audio_tower(
            input_features,
            input_features_mask=input_features_mask,
            return_dict=True,
            **kwargs,
        )
        hidden_states = audio_output.last_hidden_state

        # Valid encoder output length per window, then per-frame timestamps for the rotary time embedding.
        _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_features_mask.sum(-1).to(torch.long))
        audio_timestamps = self._build_audio_timestamps(input_ids, post_lengths, hidden_states.shape[-2])
        cos, sin = self.pos_emb(audio_timestamps.to(hidden_states.device), seq_len=hidden_states.shape[-2])
        hidden_states = apply_rotary_time_emb(hidden_states, cos, sin)
        audio_embeds = self.multi_modal_projector(hidden_states)

        # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling
        valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
        # Flattened (num_valid_frames, hidden) embeddings, in the order they are scattered into the prompt.
        audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)]

        return audio_output

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        input_features: torch.FloatTensor | None = None,
        input_features_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
            Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor

        >>> model_id = "nvidia/music-flamingo-2601-hf"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto")

        >>> conversation = [
        >>>     {
        >>>         "role": "user",
        >>>         "content": [
        >>>             {
        >>>                 "type": "text",
        >>>                 "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
        >>>             },
        >>>             {
        >>>                 "type": "audio",
        >>>                 "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
        >>>             },
        >>>         ],
        >>>     }
        >>> ]

        >>> inputs = processor.apply_chat_template(
        >>>     conversation,
        >>>     tokenize=True,
        >>>     add_generation_prompt=True,
        >>>     return_dict=True,
        >>> ).to(model.device, model.dtype)

        >>> outputs = model.generate(**inputs, max_new_tokens=100)

        >>> decoded_outputs = processor.batch_decode(
        >>>     outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        >>> )
        >>> print(decoded_outputs)
        ["This track is an uplifting Eurodance-style Trance-Pop anthem..."]
        ```"""

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Audio is only merged when the placeholder positions are known (they are located via `input_ids`).
        if input_features is not None and input_ids is not None:
            audio_embeds = self.get_audio_features(
                input_features, input_features_mask, input_ids=input_ids, return_dict=True
            ).pooler_output

            # replace text-audio token placeholders with audio embeddings
            audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
            inputs_embeds = inputs_embeds.masked_scatter(
                audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
            )

        outputs: CausalLMOutputWithPast = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return outputs

    def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
        """Prepare generation inputs, passing audio features only on the first (prompt) iteration or when
        caching is disabled."""
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)

        # Forward `is_first_iteration` instead of swallowing it, so the base implementation's own first-step
        # handling sees the real value (presumably e.g. `inputs_embeds` usage — confirm against GenerationMixin).
        model_inputs = super().prepare_inputs_for_generation(*args, is_first_iteration=is_first_iteration, **kwargs)

        # With caching, the audio has already been consumed on the first step; re-sending the features on later
        # steps would re-run the audio encoder for tokens that presumably no longer contain audio placeholders.
        if is_first_iteration or not model_inputs.get("use_cache", False):
            if input_features is not None:
                model_inputs["input_features"] = input_features
            if input_features_mask is not None:
                model_inputs["input_features_mask"] = input_features_mask

        return model_inputs

    def _build_audio_timestamps(
        self,
        input_ids: torch.LongTensor,
        post_lengths: torch.LongTensor,
        max_post_length: int,
    ) -> torch.FloatTensor:
        """Rebuild per-frame timestamps for every encoder window, restarting at zero for each audio sample.

        Each contiguous run of `config.audio_token_id` in `input_ids` is treated as one audio sample whose length
        is its number of audio embeddings. Encoder windows (rows of `post_lengths`) are assigned to samples by
        comparing cumulative window lengths against cumulative run lengths.

        Args:
            input_ids: Token ids containing the audio placeholder runs.
            post_lengths: Valid encoder output length of each window, shape `(num_windows,)`.
            max_post_length: Padded number of encoder output frames per window.

        Returns:
            Float tensor of shape `(num_windows, max_post_length)`:
            `window_index_within_sample * max_post_length * step + frame_index * step`, where
            `step = config.audio_frame_step * 4` (presumably seconds, per the rotary embedding docstring).
        """
        # Locate contiguous runs of audio placeholder tokens (one run per audio sample).
        audio_token_mask = input_ids == self.config.audio_token_id
        diff = torch.diff(torch.nn.functional.pad(audio_token_mask.int(), (1, 1), value=0), dim=1)
        _, starts = torch.where(diff == 1)
        _, ends = torch.where(diff == -1)
        # Move the run lengths next to `post_lengths` so the `searchsorted` calls below never mix devices
        # (no-op when `input_ids` and the feature mask already share a device).
        sample_lengths = (ends - starts).to(device=post_lengths.device, dtype=torch.long)

        # Account for 4x downsampling in audio encoder (conv2 and avg pooling)
        audio_embed_frame_step = self.config.audio_frame_step * 4
        frame_offsets = (
            torch.arange(max_post_length, device=post_lengths.device, dtype=torch.float32) * audio_embed_frame_step
        )

        # Map each encoder output row to its audio sample using token counts
        cumsum_post = torch.cat([torch.zeros(1, device=post_lengths.device), torch.cumsum(post_lengths, dim=0)[:-1]])
        cumsum_samples = torch.cumsum(sample_lengths, dim=0)
        sample_indices = torch.searchsorted(cumsum_samples, cumsum_post, right=True)

        # Compute window index within each sample (0, 1, 2, ... then reset for next sample)
        sample_start_rows = torch.searchsorted(
            sample_indices, torch.arange(sample_lengths.shape[0], device=post_lengths.device)
        )
        window_indices = (
            torch.arange(post_lengths.shape[0], device=post_lengths.device) - sample_start_rows[sample_indices]
        )

        # Compute timestamps
        return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets


__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoPreTrainedModel"]