modular_musicflamingo.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
  2. # reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import re
  16. from math import pi
  17. from huggingface_hub.dataclasses import strict
  18. from torch import Tensor, broadcast_tensors
  19. from ... import initialization as init
  20. from ...cache_utils import Cache
  21. from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
  22. from ...modeling_utils import PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
  25. from ..audioflamingo3.configuration_audioflamingo3 import AudioFlamingo3Config
  26. from ..audioflamingo3.modeling_audioflamingo3 import (
  27. AudioFlamingo3ForConditionalGeneration,
  28. AudioFlamingo3PreTrainedModel,
  29. )
  30. from ..audioflamingo3.processing_audioflamingo3 import AudioFlamingo3Processor
  31. from ..moonshine.modeling_moonshine import MoonshineRotaryEmbedding
  32. if is_torch_available():
  33. import torch
  @auto_docstring(checkpoint="nvidia/music-flamingo-2601-hf")
  @strict
  class MusicFlamingoConfig(AudioFlamingo3Config):
      r"""
      audio_bos_token_id (`int`, *optional*, defaults to 151670):
          The beginning-of-audio token index used to mark the start of audio spans.
      audio_eos_token_id (`int`, *optional*, defaults to 151671):
          The end-of-audio token index used to mark the end of audio spans.
      audio_frame_step (`float`, *optional*, defaults to 0.01):
          Duration in seconds of one input mel frame (trained with hop_length 160 at sampling_rate 16000).

      Example:

      ```python
      >>> from transformers import MusicFlamingoForConditionalGeneration, MusicFlamingoConfig, AudioFlamingo3EncoderConfig, Qwen2Config

      >>> # Initializing an AudioFlamingo3 audio-encoder config
      >>> audio_config = AudioFlamingo3EncoderConfig()

      >>> # Initializing a Qwen2 text config
      >>> text_config = Qwen2Config()

      >>> # Initializing a MusicFlamingo configuration
      >>> configuration = MusicFlamingoConfig(audio_config, text_config)

      >>> # Initializing a model from the MusicFlamingo-style configuration
      >>> model = MusicFlamingoForConditionalGeneration(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      audio_bos_token_id: int = 151670
      audio_eos_token_id: int = 151671
      audio_frame_step: float = 0.01
      rope_parameters: dict | None = None

      def __post_init__(self, **kwargs):
          """Finish parent initialization, then derive the rotary-time-embedding settings."""
          super().__post_init__(**kwargs)
          if self.rope_parameters is None:
              # Default rotary time embedding settings for this model.
              self.rope_parameters = dict(rope_type="default", rope_theta=1200, partial_rotary_factor=0.2)
          # `max_position_embeddings` is tied to `rope_theta`. NOTE(review): presumably the rotary embedding parent
          # class reads it as `max_seq_len_cached`, which normalizes positions — confirm against that class.
          self.max_position_embeddings = self.rope_parameters["rope_theta"]
          # The rotary time embedding rotates audio-encoder outputs, so its head size is the audio encoder width.
          self.head_dim = self.audio_config.hidden_size
  class MusicFlamingoProcessor(AudioFlamingo3Processor):
      r"""
      Constructs a MusicFlamingo processor which wraps a MusicFlamingo feature extractor and a MusicFlamingo
      tokenizer into a single processor.

      [`MusicFlamingoProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and
      [`Qwen2TokenizerFast`]. See the [`~MusicFlamingoProcessor.__call__`] for more information.

      Args:
          feature_extractor ([`WhisperFeatureExtractor`]):
              The feature extractor is a required input.
          tokenizer ([`Qwen2TokenizerFast`]):
              The tokenizer is a required input.
          chat_template (`str`, *optional*):
              The Jinja template to use for formatting the conversation. If not provided, the tokenizer's default chat
              template will be used.
          audio_token (`str`, *optional*, defaults to `"<sound>"`):
              Special token used to represent audio inputs in the chat template.
          audio_bos_token (`str`, *optional*, defaults to `"<|sound_bos|>"`):
              Special token inserted before each expanded run of audio tokens to mark the beginning of audio.
          audio_eos_token (`str`, *optional*, defaults to `"<|sound_eos|>"`):
              Special token inserted after each expanded run of audio tokens to mark the end of audio.
          max_audio_len (`int`, *optional*, defaults to 1200):
              Maximum length of audio sequences in seconds. Audio longer than this will be truncated.
      """

      def __init__(
          self,
          feature_extractor,
          tokenizer,
          chat_template=None,
          audio_token="<sound>",
          audio_bos_token="<|sound_bos|>",
          audio_eos_token="<|sound_eos|>",
          max_audio_len=1200,
      ):
          super().__init__(
              feature_extractor,
              tokenizer,
              chat_template=chat_template,
              audio_token=audio_token,
              max_audio_len=max_audio_len,
          )
          # MusicFlamingo has no transcription mode, so the inherited default prompt is dropped.
          del self.default_transcription_prompt
          self.audio_bos_token = audio_bos_token
          self.audio_eos_token = audio_eos_token
          self.audio_bos_token_id = tokenizer.convert_tokens_to_ids(audio_bos_token)
          self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(audio_eos_token)

      def _expand_audio_tokens(self, text, padding_mask, per_sample_windows):
          """Replace the audio placeholder in each prompt with a BOS/EOS-delimited run of audio tokens.

          Args:
              text (`list[str]`): Prompts, modified in place and returned. Entry `i` is paired with audio sample `i`.
              padding_mask (`torch.Tensor`): Feature padding mask with one row per audio window (split along dim 0
                  by `per_sample_windows`); summing the last axis gives valid frames per window.
              per_sample_windows (`list[int]`): Number of windows belonging to each audio sample, in order.

          Returns:
              `list[str]`: The same `text` list with every placeholder expanded.
          """
          # Valid frames per audio sample = sum of the valid frames of that sample's windows.
          window_lengths = padding_mask.sum(-1)
          audio_lengths = torch.stack([chunk.sum() for chunk in torch.split(window_lengths, per_sample_windows)])
          audio_tokens_lengths = self._get_audio_token_length(audio_lengths)
          audio_token_pattern = re.compile(re.escape(self.audio_token))
          for i, audio_length in enumerate(audio_tokens_lengths):
              expanded = self.audio_bos_token + self.audio_token * int(audio_length) + self.audio_eos_token
              # A callable replacement is inserted literally; a string replacement would have any backslashes or
              # group references (e.g. `\1`, `\g<0>`) inside custom special tokens interpreted by `re`.
              text[i] = audio_token_pattern.sub(lambda _match: expanded, text[i])
          return text

      def _get_audio_tokens_mask(self, input_ids):
          """Return a boolean mask of positions holding the audio token or its BOS/EOS delimiters."""
          return (
              (input_ids == self.audio_token_id)
              | (input_ids == self.audio_bos_token_id)
              | (input_ids == self.audio_eos_token_id)
          )

      def apply_transcription_request(self, *args, **kwargs):
          raise NotImplementedError("This method is not supported for MusicFlamingo.")

      # NOTE(review): the message says these overrides are not needed, i.e. the generated processor should fall back
      # to the generic decode methods. The model's `forward` docstring example calls `processor.batch_decode`.
      # Confirm the modular converter drops these stubs rather than copying them; modular files conventionally use
      # `raise AttributeError(...)` to remove inherited methods (verify). If copied, they raise at runtime.
      def decode(self, *args, **kwargs):
          raise NotImplementedError("MusicFlamingo does not need to overwrite this method.")

      def batch_decode(self, *args, **kwargs):
          raise NotImplementedError("MusicFlamingo does not need to overwrite this method.")

      def _strip_assistant_prefix_and_quotes(self, *args, **kwargs):
          raise NotImplementedError("This method is not supported for MusicFlamingo.")
  def rotate_half(x):
      """Rotate interleaved feature pairs by 90 degrees.

      The last dimension is read as consecutive pairs `(x[2k], x[2k + 1])`, and each pair becomes
      `(-x[2k + 1], x[2k])`. The output has the same shape as `x`. The last dimension must have even size.
      """
      paired = x.reshape(*x.shape[:-1], -1, 2)
      first, second = paired[..., 0], paired[..., 1]
      return torch.stack((-second, first), dim=-1).flatten(-2)
  def apply_rotary_time_emb(hidden_states, cos, sin):
      """Rotate the leading `cos.shape[-1]` features of `hidden_states` by the angles encoded in `cos`/`sin`.

      The rotation is computed in float64 and the result is cast back to the input dtype. Features past the rotary
      width are passed through unchanged and re-appended after the rotated slice.

      Args:
          hidden_states: Tensor whose last dimension is at least `cos.shape[-1]` wide.
          cos, sin: Cosine/sine tables; assumed broadcastable against `hidden_states[..., :cos.shape[-1]]`.

      Returns:
          Tensor with the shape and dtype of `hidden_states`.
      """
      input_dtype = hidden_states.dtype
      states64 = hidden_states.to(torch.float64)
      # `.to(tensor)` matches both dtype and device of the float64 states.
      cos, sin = cos.to(states64), sin.to(states64)
      width = cos.shape[-1]
      head, tail = states64[..., :width], states64[..., width:]
      head = head * cos + rotate_half(head) * sin
      return torch.cat([head, tail], dim=-1).to(input_dtype)
  class MusicFlamingoRotaryEmbedding(MoonshineRotaryEmbedding):
      """Rotary time embedding module used by MusicFlamingo checkpoints.

      This is a checkpoint-faithful integration, not a direct implementation of the RoTE formulation described in
      (Goel et al., 2024): https://arxiv.org/abs/2410.12109. It applies axial rotary embeddings over the window index
      within each audio sample and the encoder time index within each window, then modulates both axes with absolute
      timestamps in seconds.

      In the returned `cos`/`sin` tables, the window-axis angles fill the first half of the rotary width and the
      time-axis angles the second half. Each angle is duplicated to match the interleaved pairs of `rotate_half`.
      """

      def __init__(self, config: MusicFlamingoConfig, device=None):
          super().__init__(config, device=device)
          # Time-axis angles depend only on `inv_freq` and `max_seq_len_cached`, so they are precomputed once.
          # `persistent=False` keeps the buffer out of the state dict; `MusicFlamingoPreTrainedModel._init_weights`
          # rebuilds it from `inv_freq`.
          position_angles = self._compute_position_angles(self.inv_freq)
          self.register_buffer("position_angles", position_angles, persistent=False)

      def _compute_position_angles(self, inv_freq: Tensor) -> Tensor:
          """Return time-axis angles for positions `0 .. int(self.max_seq_len_cached) - 1`.

          Position `p` is mapped to `p / max_seq_len_cached * 2π`, so `max_seq_len_cached` corresponds to one full
          turn. That value is multiplied by every inverse frequency, and each resulting angle is repeated twice along
          the last axis.

          Args:
              inv_freq: Inverse frequencies; assumed 1-D of length `F` (TODO confirm against the parent class).

          Returns:
              Tensor of shape `(int(max_seq_len_cached), 2 * F)` with the dtype and device of `inv_freq`.
          """
          positions = torch.arange(int(self.max_seq_len_cached), device=inv_freq.device, dtype=inv_freq.dtype)
          positions = positions / self.max_seq_len_cached * (2 * pi)
          position_angles = positions.unsqueeze(-1) * inv_freq
          # Duplicate each angle for the interleaved (even, odd) feature pairs rotated by `rotate_half`.
          position_angles = torch.repeat_interleave(position_angles, 2, dim=-1)
          return position_angles.to(dtype=inv_freq.dtype)

      @torch.no_grad()
      def forward(self, timestamps: Tensor, seq_len: int) -> tuple[Tensor, Tensor]:
          """Compute 2D axial rotary embeddings for window and time dimensions.

          Args:
              timestamps: Absolute time in seconds of each encoder output frame. Column 0 of each row is read as that
                  window's start time. Presumably `(num_windows, seq_len)`, as produced by
                  `MusicFlamingoForConditionalGeneration._build_audio_timestamps` — confirm for other callers.
              seq_len: Number of encoder output frames per window. `position_angles` has only
                  `int(max_seq_len_cached)` rows, so a larger `seq_len` leads to a shape mismatch below.

          Returns:
              `(cos, sin)`, each of shape `(num_windows, seq_len, 4 * F)` with `F = len(self.inv_freq)`, assuming the
              shapes above.
          """
          # Compute frequencies for the window axis, accounting for x4 due to the downsampling in the audio encoder (conv2 and avg pooling)
          window_starts = timestamps[:, 0].to(device=self.inv_freq.device, dtype=self.inv_freq.dtype)
          # Duration in seconds covered by one window of `seq_len` encoder frames.
          window_duration = self.config.audio_frame_step * 4 * seq_len
          # Window index (rounded start / duration), normalized by `max_seq_len_cached`.
          # NOTE(review): unlike the time axis, no 2π factor is applied here. This is presumably intentional for
          # checkpoint parity — confirm before changing.
          window_positions = torch.round(window_starts / window_duration) / self.max_seq_len_cached
          window_freqs = window_positions.unsqueeze(-1) * self.inv_freq
          window_freqs = torch.repeat_interleave(window_freqs, 2, dim=-1)
          # Broadcasting and apply time-based angle modulation
          window_freqs = window_freqs[:, None, :]
          time_freqs = self.position_angles[:seq_len][None, :, :]
          window_freqs, time_freqs = broadcast_tensors(window_freqs, time_freqs)
          # Window-axis angles first, time-axis angles second, along the feature dimension.
          freqs = torch.cat((window_freqs, time_freqs), dim=-1)
          # Scale every angle by the frame's absolute timestamp (negated, times 2π).
          angle = (-timestamps * 2 * pi).to(freqs)
          freqs = freqs * angle.unsqueeze(-1)
          return freqs.cos(), freqs.sin()
  class MusicFlamingoPreTrainedModel(AudioFlamingo3PreTrainedModel):
      """Pretrained-model base for MusicFlamingo.

      Weight initialization uses the generic `PreTrainedModel` routine. It also rebuilds the non-persistent
      `position_angles` buffer of `MusicFlamingoRotaryEmbedding` from that module's `inv_freq`.
      """

      _no_split_modules = None

      @torch.no_grad()
      def _init_weights(self, module):
          # `PreTrainedModel` is called explicitly (not via `super()`), bypassing any `_init_weights` override in
          # the AudioFlamingo3 parent.
          PreTrainedModel._init_weights(self, module)
          if not isinstance(module, MusicFlamingoRotaryEmbedding):
              return
          # `position_angles` is derived from `inv_freq`, so it is refreshed after the generic init has run.
          init.copy_(module.position_angles, module._compute_position_angles(module.inv_freq))
  @auto_docstring(
      custom_intro="""
      The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model.
      """
  )
  class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
      def __init__(self, config: MusicFlamingoConfig):
          super().__init__(config)
          # Rotary time embedding applied to audio-encoder outputs before the multi-modal projector.
          self.pos_emb = MusicFlamingoRotaryEmbedding(config)

      def _build_audio_timestamps(
          self,
          input_ids: torch.LongTensor,
          post_lengths: torch.LongTensor,
          max_post_length: int,
      ) -> torch.FloatTensor:
          """Reconstruct the absolute timestamp (in seconds) of every audio-encoder output frame.

          Each contiguous run of `config.audio_token_id` in `input_ids` is treated as one audio sample. Its length is
          its number of placeholder tokens. Encoder output rows (one per window) are assigned to samples in order by
          comparing cumulative valid lengths, then numbered 0, 1, 2, ... within their sample.

          Args:
              input_ids: Token ids containing the audio placeholders.
              post_lengths: Number of valid encoder output frames in each row of the encoder output.
              max_post_length: Padded encoder output length, i.e. frames per row.

          Returns:
              `torch.FloatTensor` of shape `(num_rows, max_post_length)`. Entry `[r, f]` is
              `window_index(r) * max_post_length * step + f * step`, where `step = 4 * config.audio_frame_step`.
              The tensor is on the device of `post_lengths`.
          """
          # Keep every tensor on one device: `torch.searchsorted` and indexing reject mixed devices, and `input_ids`
          # need not share the device of the feature mask that `post_lengths` was derived from.
          device = post_lengths.device
          audio_token_mask = (input_ids == self.config.audio_token_id).to(device)
          # +1 marks the start of a run of audio tokens, -1 the position just past its end. `torch.where` returns
          # indices in row-major order, so starts and ends pair up run by run.
          diff = torch.diff(torch.nn.functional.pad(audio_token_mask.int(), (1, 1), value=0), dim=1)
          _, starts = torch.where(diff == 1)
          _, ends = torch.where(diff == -1)
          sample_lengths = (ends - starts).to(torch.long)

          # Account for 4x downsampling in audio encoder (conv2 and avg pooling)
          audio_embed_frame_step = self.config.audio_frame_step * 4
          frame_offsets = torch.arange(max_post_length, device=device, dtype=torch.float32) * audio_embed_frame_step

          # Map each encoder output row to its audio sample. A row whose exclusive running total of valid frames has
          # reached the cumulative length of sample k belongs to a later sample.
          row_offsets = torch.cumsum(post_lengths, dim=0) - post_lengths
          sample_ends = torch.cumsum(sample_lengths, dim=0)
          sample_indices = torch.searchsorted(sample_ends, row_offsets, right=True)

          # Window index within each sample (0, 1, 2, ... restarting at every new sample)
          sample_start_rows = torch.searchsorted(sample_indices, torch.arange(sample_lengths.shape[0], device=device))
          window_indices = torch.arange(post_lengths.shape[0], device=device) - sample_start_rows[sample_indices]

          return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets

      @can_return_tuple
      @auto_docstring(
          custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
      )
      def get_audio_features(
          self,
          input_features: torch.FloatTensor,
          input_features_mask: torch.Tensor,
          input_ids: torch.LongTensor,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
              Mask to avoid performing attention on padded feature indices.
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Token ids containing the audio token ID placeholders, for reconstructing rotary time embedding timestamps.
          """
          audio_output = self.audio_tower(
              input_features,
              input_features_mask=input_features_mask,
              return_dict=True,
              **kwargs,
          )
          hidden_states = audio_output.last_hidden_state
          # Valid encoder output frames per row, after the encoder's downsampling.
          _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_features_mask.sum(-1).to(torch.long))
          audio_timestamps = self._build_audio_timestamps(input_ids, post_lengths, hidden_states.shape[-2])
          cos, sin = self.pos_emb(audio_timestamps.to(hidden_states.device), seq_len=hidden_states.shape[-2])
          hidden_states = apply_rotary_time_emb(hidden_states, cos, sin)
          audio_embeds = self.multi_modal_projector(hidden_states)

          # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling
          valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
          # Flattened to (num_valid_frames, dim): one embedding per audio placeholder token.
          audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)]
          return audio_output

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          input_features: torch.FloatTensor | None = None,
          input_features_mask: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          r"""
          input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
              Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:

              - 1 for tokens that are **not masked**,
              - 0 for tokens that are **masked**.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor

          >>> model_id = "nvidia/music-flamingo-2601-hf"
          >>> processor = AutoProcessor.from_pretrained(model_id)
          >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto")

          >>> conversation = [
          ...     {
          ...         "role": "user",
          ...         "content": [
          ...             {
          ...                 "type": "text",
          ...                 "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
          ...             },
          ...             {
          ...                 "type": "audio",
          ...                 "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
          ...             },
          ...         ],
          ...     }
          ... ]

          >>> inputs = processor.apply_chat_template(
          ...     conversation,
          ...     tokenize=True,
          ...     add_generation_prompt=True,
          ...     return_dict=True,
          ... ).to(model.device, model.dtype)

          >>> outputs = model.generate(**inputs, max_new_tokens=100)

          >>> decoded_outputs = processor.batch_decode(
          ...     outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
          ... )
          >>> print(decoded_outputs)
          ["This track is an uplifting Eurodance-style Trance-Pop anthem..."]
          ```"""
          if inputs_embeds is None:
              inputs_embeds = self.get_input_embeddings()(input_ids)

          # Audio is merged only when `input_ids` is available: placeholder positions and the rotary time
          # timestamps are both derived from it.
          if input_features is not None and input_ids is not None:
              audio_embeds = self.get_audio_features(
                  input_features, input_features_mask, input_ids=input_ids, return_dict=True
              ).pooler_output

              # replace text-audio token placeholders with audio embeddings
              audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
              inputs_embeds = inputs_embeds.masked_scatter(
                  audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
              )

          outputs: CausalLMOutputWithPast = self.language_model(
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              labels=labels,
              use_cache=use_cache,
              logits_to_keep=logits_to_keep,
              **kwargs,
          )
          return outputs
  # Public names exported by this modular model file.
  __all__ = [
      "MusicFlamingoConfig",
      "MusicFlamingoProcessor",
      "MusicFlamingoForConditionalGeneration",
      "MusicFlamingoPreTrainedModel",
  ]