modular_audioflamingo3.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
  2. # reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from torch import nn
  17. from ...activations import ACT2FN
  18. from ...cache_utils import Cache
  19. from ...masking_utils import create_bidirectional_mask
  20. from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
  21. from ...processing_utils import Unpack
  22. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  23. from ...utils.generic import merge_with_config_defaults
  24. from ...utils.output_capturing import capture_outputs
  25. from ..qwen2_audio.modeling_qwen2_audio import (
  26. Qwen2AudioEncoder,
  27. Qwen2AudioPreTrainedModel,
  28. )
  29. from ..voxtral.modeling_voxtral import VoxtralForConditionalGeneration, VoxtralMultiModalProjector
  30. from ..whisper.modeling_whisper import WhisperAttention, WhisperEncoderLayer
  31. from .configuration_audioflamingo3 import AudioFlamingo3Config
  32. logger = logging.get_logger(__name__)
  33. class AudioFlamingo3Attention(WhisperAttention):
  34. pass
  35. class AudioFlamingo3EncoderLayer(WhisperEncoderLayer):
  36. pass
  37. class AudioFlamingo3PreTrainedModel(Qwen2AudioPreTrainedModel):
  38. pass
  39. @auto_docstring(
  40. custom_intro="""
  41. The audio model from AudioFlamingo3 without any head or projection on top.
  42. """
  43. )
  44. class AudioFlamingo3Encoder(Qwen2AudioEncoder):
  45. """
  46. AudioFlamingo3 encoder: Whisper encoder, average pool (time/2), then LayerNorm.
  47. """
  48. _can_record_outputs = {
  49. "hidden_states": AudioFlamingo3EncoderLayer,
  50. "attentions": AudioFlamingo3Attention,
  51. }
  52. @merge_with_config_defaults
  53. @capture_outputs
  54. def forward(
  55. self,
  56. input_features: torch.Tensor,
  57. input_features_mask: torch.Tensor | None = None,
  58. **kwargs,
  59. ) -> tuple | BaseModelOutputWithPooling:
  60. r"""
  61. Args:
  62. input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
  63. Log-Mel features extracted from raw audio. Use the processor/feature extractor to compute and pad
  64. these features from waveform input.
  65. input_features_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  66. Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
  67. - 1 for tokens that are **not masked**,
  68. - 0 for tokens that are **masked**.
  69. """
  70. seq_len = (input_features.shape[-1] - 1) // 2 + 1 # After conv2 downsampling
  71. input_features_lengths = input_features_mask.sum(-1)
  72. input_features_lengths = (input_features_lengths - 1) // 2 + 1 # conv2 downsampling
  73. input_features_mask = torch.arange(seq_len, device=input_features.device) < input_features_lengths[:, None]
  74. # Conv front-end
  75. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  76. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  77. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  78. # Add positions, dropout
  79. hidden_states = inputs_embeds + self.embed_positions.weight
  80. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  81. attention_mask = create_bidirectional_mask(
  82. config=self.config,
  83. inputs_embeds=hidden_states,
  84. attention_mask=input_features_mask,
  85. )
  86. # Transformer stack
  87. for layer in self.layers:
  88. drop = self.training and torch.rand([]) < self.layerdrop
  89. if not drop:
  90. hidden_states = layer(hidden_states, attention_mask)
  91. # AvgPool (time/2) + LayerNorm
  92. hidden_states = hidden_states.permute(0, 2, 1)
  93. hidden_states = self.avg_pooler(hidden_states).permute(0, 2, 1)
  94. hidden_states = self.layer_norm(hidden_states)
  95. return BaseModelOutputWithPooling(
  96. last_hidden_state=hidden_states,
  97. )
  98. class AudioFlamingo3MultiModalProjector(VoxtralMultiModalProjector):
  99. """
  100. Audio adaptor (small MLP) that projects AudioFlamingo3Encoder features
  101. to the LLM embedding space so they can replace `<sound>` tokens.
  102. """
  103. def __init__(self, config: AudioFlamingo3Config):
  104. super().__init__()
  105. self.linear_1 = nn.Linear(
  106. config.audio_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
  107. )
  108. self.act = ACT2FN[config.projector_hidden_act]
  109. self.linear_2 = nn.Linear(
  110. config.text_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
  111. )
  112. @auto_docstring(
  113. custom_intro="""
  114. The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
  115. """
  116. )
  117. class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration):
  118. _tp_plan = None
  119. _pp_plan = None
  120. _keep_in_fp32_modules_strict = None
  121. def __init__(self, config):
  122. super().__init__(config)
  123. @can_return_tuple
  124. @auto_docstring(
  125. custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
  126. )
  127. def get_audio_features(
  128. self,
  129. input_features: torch.FloatTensor,
  130. input_features_mask: torch.Tensor,
  131. **kwargs: Unpack[TransformersKwargs],
  132. ) -> tuple | BaseModelOutputWithPooling:
  133. r"""
  134. input_features (`torch.FloatTensor`):
  135. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  136. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
  137. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  138. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  139. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  140. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
  141. Mask to avoid performing attention on padded feature indices.
  142. """
  143. audio_output = self.audio_tower(
  144. input_features, input_features_mask=input_features_mask, return_dict=True, **kwargs
  145. )
  146. audio_embeds = self.multi_modal_projector(audio_output.last_hidden_state)
  147. # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling
  148. input_lengths = input_features_mask.sum(-1).to(torch.long)
  149. _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_lengths)
  150. valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
  151. audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)]
  152. return audio_output
  153. @can_return_tuple
  154. @auto_docstring
  155. def forward(
  156. self,
  157. input_ids: torch.LongTensor | None = None,
  158. input_features: torch.FloatTensor | None = None,
  159. input_features_mask: torch.Tensor | None = None,
  160. attention_mask: torch.Tensor | None = None,
  161. position_ids: torch.LongTensor | None = None,
  162. past_key_values: Cache | None = None,
  163. inputs_embeds: torch.FloatTensor | None = None,
  164. labels: torch.LongTensor | None = None,
  165. use_cache: bool | None = None,
  166. logits_to_keep: int | torch.Tensor = 0,
  167. **kwargs: Unpack[TransformersKwargs],
  168. ) -> CausalLMOutputWithPast:
  169. r"""
  170. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
  171. Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
  172. - 1 for tokens that are **not masked**,
  173. - 0 for tokens that are **masked**.
  174. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  175. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  176. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  177. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  178. Example:
  179. ```python
  180. >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
  181. >>> model_id = "nvidia/audio-flamingo-3-hf"
  182. >>> processor = AutoProcessor.from_pretrained(model_id)
  183. >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
  184. >>> conversations = [
  185. >>> [
  186. >>> {
  187. >>> "role": "user",
  188. >>> "content": [
  189. >>> {"type": "text", "text": "Transcribe the input speech."},
  190. >>> {
  191. >>> "type": "audio",
  192. >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
  193. >>> },
  194. >>> ],
  195. >>> }
  196. >>> ],
  197. >>> [
  198. >>> {
  199. >>> "role": "user",
  200. >>> "content": [
  201. >>> {
  202. >>> "type": "text",
  203. >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
  204. >>> },
  205. >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
  206. >>> ],
  207. >>> }
  208. >>> ],
  209. >>> ]
  210. >>> inputs = processor.apply_chat_template(
  211. >>> conversations,
  212. >>> tokenize=True,
  213. >>> add_generation_prompt=True,
  214. >>> return_dict=True,
  215. >>> ).to(model.device)
  216. >>> outputs = model.generate(**inputs, max_new_tokens=500)
  217. >>> decoded_outputs = processor.batch_decode(
  218. >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
  219. >>> )
  220. >>> print(decoded_outputs)
  221. ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."]
  222. ```"""
  223. if inputs_embeds is None:
  224. inputs_embeds = self.get_input_embeddings()(input_ids)
  225. if input_features is not None and input_ids is not None:
  226. audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output
  227. # replace text-audio token placeholders with audio embeddings
  228. audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
  229. inputs_embeds = inputs_embeds.masked_scatter(
  230. audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
  231. )
  232. outputs: CausalLMOutputWithPast = self.language_model(
  233. inputs_embeds=inputs_embeds,
  234. attention_mask=attention_mask,
  235. position_ids=position_ids,
  236. past_key_values=past_key_values,
  237. labels=labels,
  238. use_cache=use_cache,
  239. logits_to_keep=logits_to_keep,
  240. **kwargs,
  241. )
  242. return outputs
  243. def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
  244. input_features = kwargs.pop("input_features", None)
  245. input_features_mask = kwargs.pop("input_features_mask", None)
  246. model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
  247. if is_first_iteration or not model_inputs.get("use_cache", False):
  248. if input_features is not None:
  249. model_inputs["input_features"] = input_features
  250. if input_features_mask is not None:
  251. model_inputs["input_features_mask"] = input_features_mask
  252. return model_inputs
  253. __all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"]