modeling_musicflamingo.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/musicflamingo/modular_musicflamingo.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_musicflamingo.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
  8. # reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from math import pi
  23. from typing import Optional
  24. from torch import Tensor, broadcast_tensors, nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import GenerationMixin
  29. from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
  30. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  31. from ...modeling_utils import PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
  34. from ..auto import AutoModel, AutoModelForCausalLM
  35. from .configuration_musicflamingo import MusicFlamingoConfig
  36. if is_torch_available():
  37. import torch
  38. class MusicFlamingoRotaryEmbedding(nn.Module):
  39. """Rotary time embedding module used by MusicFlamingo checkpoints.
  40. This is a checkpoint-faithful integration, not a direct implementation of the RoTE formulation described in
  41. (Goel et al., 2024): https://arxiv.org/abs/2410.12109. It applies axial rotary embeddings over the window index
  42. within each audio sample and the encoder time index within each window, then modulates both axes with absolute
  43. timestamps in seconds.
  44. """
  45. inv_freq: torch.Tensor # fix linting for `register_buffer`
  46. def __init__(self, config: MusicFlamingoConfig, device=None):
  47. super().__init__()
  48. self.max_seq_len_cached = config.max_position_embeddings
  49. self.original_max_seq_len = config.max_position_embeddings
  50. self.config = config
  51. self.rope_type = self.config.rope_parameters["rope_type"]
  52. rope_init_fn: Callable = self.compute_default_rope_parameters
  53. if self.rope_type != "default":
  54. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  55. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  56. self.register_buffer("inv_freq", inv_freq, persistent=False)
  57. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  58. position_angles = self._compute_position_angles(self.inv_freq)
  59. self.register_buffer("position_angles", position_angles, persistent=False)
  60. @staticmethod
  61. def compute_default_rope_parameters(
  62. config: MusicFlamingoConfig | None = None,
  63. device: Optional["torch.device"] = None,
  64. seq_len: int | None = None,
  65. ) -> tuple["torch.Tensor", float]:
  66. """
  67. Computes the inverse frequencies according to the original RoPE implementation
  68. Args:
  69. config ([`~transformers.PreTrainedConfig`]):
  70. The model configuration.
  71. device (`torch.device`):
  72. The device to use for initialization of the inverse frequencies.
  73. seq_len (`int`, *optional*):
  74. The current sequence length. Unused for this type of RoPE.
  75. Returns:
  76. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  77. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  78. """
  79. base = config.rope_parameters["rope_theta"]
  80. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  81. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  82. dim = int(head_dim * partial_rotary_factor)
  83. attention_factor = 1.0 # Unused in this type of RoPE
  84. # Compute the inverse frequencies
  85. inv_freq = 1.0 / (
  86. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  87. )
  88. return inv_freq, attention_factor
  89. @torch.no_grad()
  90. def forward(self, timestamps: Tensor, seq_len: int) -> tuple[Tensor, Tensor]:
  91. """Compute 2D axial rotary embeddings for window and time dimensions."""
  92. # Compute frequencies for the window axis, accounting for x4 due to the downsampling in the audio encoder (conv2 and avg pooling)
  93. window_starts = timestamps[:, 0].to(device=self.inv_freq.device, dtype=self.inv_freq.dtype)
  94. window_duration = self.config.audio_frame_step * 4 * seq_len
  95. window_positions = torch.round(window_starts / window_duration) / self.max_seq_len_cached
  96. window_freqs = window_positions.unsqueeze(-1) * self.inv_freq
  97. window_freqs = torch.repeat_interleave(window_freqs, 2, dim=-1)
  98. # Broadcasting and apply time-based angle modulation
  99. window_freqs = window_freqs[:, None, :]
  100. time_freqs = self.position_angles[:seq_len][None, :, :]
  101. window_freqs, time_freqs = broadcast_tensors(window_freqs, time_freqs)
  102. freqs = torch.cat((window_freqs, time_freqs), dim=-1)
  103. angle = (-timestamps * 2 * pi).to(freqs)
  104. freqs = freqs * angle.unsqueeze(-1)
  105. return freqs.cos(), freqs.sin()
  106. def _compute_position_angles(self, inv_freq):
  107. positions = torch.arange(int(self.max_seq_len_cached), device=inv_freq.device, dtype=inv_freq.dtype)
  108. positions = positions / self.max_seq_len_cached * (2 * pi)
  109. position_angles = positions.unsqueeze(-1) * inv_freq
  110. position_angles = torch.repeat_interleave(position_angles, 2, dim=-1)
  111. return position_angles.to(dtype=inv_freq.dtype)
  112. @auto_docstring
  113. class MusicFlamingoPreTrainedModel(PreTrainedModel):
  114. config: MusicFlamingoConfig
  115. base_model_prefix = "model"
  116. input_modalities = ("audio", "text")
  117. supports_gradient_checkpointing = True
  118. _no_split_modules = None
  119. _skip_keys_device_placement = "past_key_values"
  120. _supports_flash_attn = True
  121. _supports_sdpa = True
  122. @torch.no_grad()
  123. def _init_weights(self, module):
  124. super()._init_weights(module)
  125. if isinstance(module, MusicFlamingoRotaryEmbedding):
  126. buffer_value = module._compute_position_angles(module.inv_freq)
  127. init.copy_(module.position_angles, buffer_value)
  128. class MusicFlamingoMultiModalProjector(nn.Module):
  129. """
  130. Audio adaptor (small MLP) that projects MusicFlamingoEncoder features
  131. to the LLM embedding space so they can replace `<sound>` tokens.
  132. """
  133. def __init__(self, config: MusicFlamingoConfig):
  134. super().__init__()
  135. self.linear_1 = nn.Linear(
  136. config.audio_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
  137. )
  138. self.act = ACT2FN[config.projector_hidden_act]
  139. self.linear_2 = nn.Linear(
  140. config.text_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
  141. )
  142. def forward(self, audio_features):
  143. hidden_states = self.linear_1(audio_features)
  144. hidden_states = self.act(hidden_states)
  145. hidden_states = self.linear_2(hidden_states)
  146. return hidden_states
  147. def rotate_half(x):
  148. x = x.reshape(*x.shape[:-1], -1, 2)
  149. x1, x2 = x.unbind(dim=-1)
  150. x = torch.stack((-x2, x1), dim=-1)
  151. return x.flatten(-2)
  152. def apply_rotary_time_emb(hidden_states, cos, sin):
  153. original_dtype = hidden_states.dtype
  154. hidden_states = hidden_states.to(torch.float64)
  155. cos = cos.to(hidden_states)
  156. sin = sin.to(hidden_states)
  157. rot_dim = cos.shape[-1]
  158. passthrough = hidden_states[..., rot_dim:]
  159. rotated = hidden_states[..., :rot_dim]
  160. rotated = (rotated * cos) + (rotate_half(rotated) * sin)
  161. return torch.cat((rotated, passthrough), dim=-1).to(original_dtype)
  162. @auto_docstring(
  163. custom_intro="""
  164. The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model.
  165. """
  166. )
  167. class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin):
  168. _keep_in_fp32_modules_strict = None
  169. _tp_plan = None
  170. _pp_plan = None
  171. def __init__(self, config: MusicFlamingoConfig):
  172. super().__init__(config)
  173. self.vocab_size = config.text_config.vocab_size
  174. self.audio_tower = AutoModel.from_config(config.audio_config)
  175. self.language_model = AutoModelForCausalLM.from_config(config.text_config)
  176. self.multi_modal_projector = MusicFlamingoMultiModalProjector(config)
  177. self.pos_emb = MusicFlamingoRotaryEmbedding(config)
  178. # Initialize weights and apply final processing
  179. self.post_init()
  180. def get_input_embeddings(self):
  181. return self.language_model.get_input_embeddings()
  182. def set_input_embeddings(self, value):
  183. self.language_model.set_input_embeddings(value)
  184. def get_output_embeddings(self):
  185. return self.language_model.get_output_embeddings()
  186. def set_output_embeddings(self, new_embeddings):
  187. self.language_model.set_output_embeddings(new_embeddings)
  188. def set_decoder(self, decoder):
  189. self.language_model.set_decoder(decoder)
  190. def get_decoder(self):
  191. return self.language_model.get_decoder()
  192. @can_return_tuple
  193. @auto_docstring(
  194. custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
  195. )
  196. def get_audio_features(
  197. self,
  198. input_features: torch.FloatTensor,
  199. input_features_mask: torch.Tensor,
  200. input_ids: torch.LongTensor,
  201. **kwargs: Unpack[TransformersKwargs],
  202. ) -> tuple | BaseModelOutputWithPooling:
  203. r"""
  204. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
  205. Mask to avoid performing attention on padded feature indices.
  206. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  207. Token ids containing the audio token ID placeholders, for reconstructing rotary time embedding timestamps.
  208. """
  209. audio_output = self.audio_tower(
  210. input_features,
  211. input_features_mask=input_features_mask,
  212. return_dict=True,
  213. **kwargs,
  214. )
  215. hidden_states = audio_output.last_hidden_state
  216. _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_features_mask.sum(-1).to(torch.long))
  217. audio_timestamps = self._build_audio_timestamps(input_ids, post_lengths, hidden_states.shape[-2])
  218. cos, sin = self.pos_emb(audio_timestamps.to(hidden_states.device), seq_len=hidden_states.shape[-2])
  219. hidden_states = apply_rotary_time_emb(hidden_states, cos, sin)
  220. audio_embeds = self.multi_modal_projector(hidden_states)
  221. # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling
  222. valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
  223. audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)]
  224. return audio_output
  225. @can_return_tuple
  226. @auto_docstring
  227. def forward(
  228. self,
  229. input_ids: torch.LongTensor | None = None,
  230. input_features: torch.FloatTensor | None = None,
  231. input_features_mask: torch.Tensor | None = None,
  232. attention_mask: torch.Tensor | None = None,
  233. position_ids: torch.LongTensor | None = None,
  234. past_key_values: Cache | None = None,
  235. inputs_embeds: torch.FloatTensor | None = None,
  236. labels: torch.LongTensor | None = None,
  237. use_cache: bool | None = None,
  238. logits_to_keep: int | torch.Tensor = 0,
  239. **kwargs: Unpack[TransformersKwargs],
  240. ) -> CausalLMOutputWithPast:
  241. r"""
  242. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
  243. Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
  244. - 1 for tokens that are **not masked**,
  245. - 0 for tokens that are **masked**.
  246. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  247. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  248. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  249. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  250. Example:
  251. ```python
  252. >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor
  253. >>> model_id = "nvidia/music-flamingo-2601-hf"
  254. >>> processor = AutoProcessor.from_pretrained(model_id)
  255. >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto")
  256. >>> conversation = [
  257. >>> {
  258. >>> "role": "user",
  259. >>> "content": [
  260. >>> {
  261. >>> "type": "text",
  262. >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
  263. >>> },
  264. >>> {
  265. >>> "type": "audio",
  266. >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
  267. >>> },
  268. >>> ],
  269. >>> }
  270. >>> ]
  271. >>> inputs = processor.apply_chat_template(
  272. >>> conversation,
  273. >>> tokenize=True,
  274. >>> add_generation_prompt=True,
  275. >>> return_dict=True,
  276. >>> ).to(model.device, model.dtype)
  277. >>> outputs = model.generate(**inputs, max_new_tokens=100)
  278. >>> decoded_outputs = processor.batch_decode(
  279. >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
  280. >>> )
  281. >>> print(decoded_outputs)
  282. ["This track is an uplifting Eurodance-style Trance-Pop anthem..."]
  283. ```"""
  284. if inputs_embeds is None:
  285. inputs_embeds = self.get_input_embeddings()(input_ids)
  286. if input_features is not None and input_ids is not None:
  287. audio_embeds = self.get_audio_features(
  288. input_features, input_features_mask, input_ids=input_ids, return_dict=True
  289. ).pooler_output
  290. # replace text-audio token placeholders with audio embeddings
  291. audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
  292. inputs_embeds = inputs_embeds.masked_scatter(
  293. audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
  294. )
  295. outputs: CausalLMOutputWithPast = self.language_model(
  296. inputs_embeds=inputs_embeds,
  297. attention_mask=attention_mask,
  298. position_ids=position_ids,
  299. past_key_values=past_key_values,
  300. labels=labels,
  301. use_cache=use_cache,
  302. logits_to_keep=logits_to_keep,
  303. **kwargs,
  304. )
  305. return outputs
  306. def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
  307. input_features = kwargs.pop("input_features", None)
  308. input_features_mask = kwargs.pop("input_features_mask", None)
  309. model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
  310. if is_first_iteration or not model_inputs.get("use_cache", False):
  311. if input_features is not None:
  312. model_inputs["input_features"] = input_features
  313. if input_features_mask is not None:
  314. model_inputs["input_features_mask"] = input_features_mask
  315. return model_inputs
  316. def _build_audio_timestamps(
  317. self,
  318. input_ids: torch.LongTensor,
  319. post_lengths: torch.LongTensor,
  320. max_post_length: int,
  321. ) -> torch.FloatTensor:
  322. audio_token_mask = input_ids == self.config.audio_token_id
  323. diff = torch.diff(torch.nn.functional.pad(audio_token_mask.int(), (1, 1), value=0), dim=1)
  324. _, starts = torch.where(diff == 1)
  325. _, ends = torch.where(diff == -1)
  326. sample_lengths = (ends - starts).to(torch.long)
  327. # Account for 4x downsampling in audio encoder (conv2 and avg pooling)
  328. audio_embed_frame_step = self.config.audio_frame_step * 4
  329. frame_offsets = (
  330. torch.arange(max_post_length, device=post_lengths.device, dtype=torch.float32) * audio_embed_frame_step
  331. )
  332. # Map each encoder output row to its audio sample using token counts
  333. cumsum_post = torch.cat([torch.zeros(1, device=post_lengths.device), torch.cumsum(post_lengths, dim=0)[:-1]])
  334. cumsum_samples = torch.cumsum(sample_lengths, dim=0)
  335. sample_indices = torch.searchsorted(cumsum_samples, cumsum_post, right=True)
  336. # Compute window index within each sample (0, 1, 2, ... then reset for next sample)
  337. sample_start_rows = torch.searchsorted(
  338. sample_indices, torch.arange(sample_lengths.shape[0], device=post_lengths.device)
  339. )
  340. window_indices = (
  341. torch.arange(post_lengths.shape[0], device=post_lengths.device) - sample_start_rows[sample_indices]
  342. )
  343. # Compute timestamps
  344. return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets
# Names exported by a star-import of this module.
__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoPreTrainedModel"]