modeling_audioflamingo3.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/audioflamingo3/modular_audioflamingo3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_audioflamingo3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
  8. # reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from collections.abc import Callable
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, EncoderDecoderCache
  27. from ...generation import GenerationMixin
  28. from ...masking_utils import create_bidirectional_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  35. from ...utils.generic import merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from ..auto import AutoModel, AutoModelForCausalLM
  38. from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig
  39. logger = logging.get_logger(__name__)
  40. def eager_attention_forward(
  41. module: nn.Module,
  42. query: torch.Tensor,
  43. key: torch.Tensor,
  44. value: torch.Tensor,
  45. attention_mask: torch.Tensor | None,
  46. scaling: float | None = None,
  47. dropout: float = 0.0,
  48. **kwargs,
  49. ):
  50. if scaling is None:
  51. scaling = query.size(-1) ** -0.5
  52. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  53. if attention_mask is not None:
  54. attn_weights = attn_weights + attention_mask
  55. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  56. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  57. attn_output = torch.matmul(attn_weights, value)
  58. attn_output = attn_output.transpose(1, 2).contiguous()
  59. return attn_output, attn_weights
  60. class AudioFlamingo3Attention(nn.Module):
  61. """Multi-headed attention from 'Attention Is All You Need' paper"""
  62. def __init__(
  63. self,
  64. embed_dim: int,
  65. num_heads: int,
  66. dropout: float = 0.0,
  67. is_decoder: bool = False,
  68. bias: bool = True,
  69. is_causal: bool = False,
  70. layer_idx: int | None = None,
  71. config: AudioFlamingo3Config | None = None,
  72. ):
  73. super().__init__()
  74. self.embed_dim = embed_dim
  75. self.num_heads = num_heads
  76. self.dropout = dropout
  77. self.head_dim = embed_dim // num_heads
  78. self.config = config
  79. if (self.head_dim * num_heads) != self.embed_dim:
  80. raise ValueError(
  81. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  82. f" and `num_heads`: {num_heads})."
  83. )
  84. self.scaling = self.head_dim**-0.5
  85. self.is_decoder = is_decoder
  86. self.is_causal = is_causal
  87. if layer_idx is None and is_decoder:
  88. logger.warning_once(
  89. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  90. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  91. "when creating this class."
  92. )
  93. self.layer_idx = layer_idx
  94. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
  95. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  96. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  97. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  98. def forward(
  99. self,
  100. hidden_states: torch.Tensor,
  101. key_value_states: torch.Tensor | None = None,
  102. past_key_values: Cache | None = None,
  103. attention_mask: torch.Tensor | None = None,
  104. output_attentions: bool = False,
  105. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  106. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  107. **kwargs: Unpack[FlashAttentionKwargs],
  108. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  109. """Input shape: Batch x Time x Channel"""
  110. # if key_value_states are provided this layer is used as a cross-attention layer
  111. # for the decoder
  112. is_cross_attention = key_value_states is not None
  113. input_shape = hidden_states.shape[:-1]
  114. hidden_shape = (*input_shape, -1, self.head_dim)
  115. # Scaling is susceptible to floating point arithmetics' inprecisions
  116. # which can lead to different results (this is dependent from model
  117. # to model, e.g. audioflamingo3 is one such case). We therefore keep the
  118. # original order of scaling to follow the original implementation
  119. # and enforce no scaling (1.0) in the attention call below.
  120. query_states = (self.q_proj(hidden_states) * self.scaling).view(hidden_shape).transpose(1, 2).contiguous()
  121. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  122. if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
  123. is_updated = past_key_values.is_updated.get(self.layer_idx)
  124. if is_cross_attention:
  125. # after the first generated id, we can subsequently re-use all key/value_states from cache
  126. past_key_values.is_updated[self.layer_idx] = True
  127. past_key_values = past_key_values.cross_attention_cache
  128. else:
  129. past_key_values = past_key_values.self_attention_cache
  130. # use key_value_states if cross attention
  131. current_states = key_value_states if key_value_states is not None else hidden_states
  132. if is_cross_attention and past_key_values and is_updated:
  133. # reuse k,v, cross_attentions
  134. key_states = past_key_values.layers[self.layer_idx].keys
  135. value_states = past_key_values.layers[self.layer_idx].values
  136. else:
  137. # Use the query's batch dimension for kv view so that a different-batch
  138. # encoder output (e.g. in tests) gets absorbed into the sequence axis,
  139. # preserving backward-compatible behaviour.
  140. kv_shape = (input_shape[0], -1, self.num_heads, self.head_dim)
  141. key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
  142. value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
  143. if past_key_values is not None:
  144. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  145. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  146. self.config._attn_implementation, eager_attention_forward
  147. )
  148. attn_output, attn_weights = attention_interface(
  149. self,
  150. query_states,
  151. key_states,
  152. value_states,
  153. attention_mask,
  154. dropout=0.0 if not self.training else self.dropout,
  155. scaling=1.0,
  156. output_attentions=output_attentions,
  157. **kwargs,
  158. )
  159. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  160. attn_output = self.out_proj(attn_output)
  161. return attn_output, attn_weights
  162. class AudioFlamingo3EncoderLayer(GradientCheckpointingLayer):
  163. def __init__(self, config: AudioFlamingo3Config):
  164. super().__init__()
  165. self.embed_dim = config.d_model
  166. self.self_attn = AudioFlamingo3Attention(
  167. embed_dim=self.embed_dim,
  168. num_heads=config.encoder_attention_heads,
  169. dropout=config.attention_dropout,
  170. config=config,
  171. )
  172. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  173. self.dropout = config.dropout
  174. self.activation_fn = ACT2FN[config.activation_function]
  175. self.activation_dropout = config.activation_dropout
  176. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  177. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  178. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  179. def forward(
  180. self,
  181. hidden_states: torch.Tensor,
  182. attention_mask: torch.Tensor,
  183. **kwargs: Unpack[TransformersKwargs],
  184. ) -> torch.Tensor:
  185. """
  186. Args:
  187. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  188. attention_mask (`torch.FloatTensor`): attention mask of size
  189. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  190. """
  191. residual = hidden_states
  192. hidden_states = self.self_attn_layer_norm(hidden_states)
  193. hidden_states, _ = self.self_attn(
  194. hidden_states=hidden_states,
  195. attention_mask=attention_mask,
  196. **kwargs,
  197. )
  198. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  199. hidden_states = residual + hidden_states
  200. residual = hidden_states
  201. hidden_states = self.final_layer_norm(hidden_states)
  202. hidden_states = self.activation_fn(self.fc1(hidden_states))
  203. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  204. hidden_states = self.fc2(hidden_states)
  205. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  206. hidden_states = residual + hidden_states
  207. if hidden_states.dtype == torch.float16:
  208. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  209. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  210. return hidden_states
  211. @auto_docstring
  212. class AudioFlamingo3PreTrainedModel(PreTrainedModel):
  213. config: AudioFlamingo3Config
  214. base_model_prefix = "model"
  215. input_modalities = ("audio", "text")
  216. supports_gradient_checkpointing = True
  217. _no_split_modules = ["AudioFlamingo3Attention"]
  218. _skip_keys_device_placement = "past_key_values"
  219. _supports_flash_attn = True
  220. _supports_sdpa = True
  221. @auto_docstring(
  222. custom_intro="""
  223. The audio model from AudioFlamingo3 without any head or projection on top.
  224. """
  225. )
  226. class AudioFlamingo3Encoder(AudioFlamingo3PreTrainedModel):
  227. """
  228. AudioFlamingo3 encoder: Whisper encoder, average pool (time/2), then LayerNorm.
  229. """
  230. # Ignore copy
  231. config: AudioFlamingo3EncoderConfig
  232. main_input_name = "input_features"
  233. input_modalities = "audio"
  234. _no_split_modules = ["AudioFlamingo3EncoderLayer"]
  235. _can_record_outputs = {
  236. "hidden_states": AudioFlamingo3EncoderLayer,
  237. "attentions": AudioFlamingo3Attention,
  238. }
  239. def __init__(self, config: AudioFlamingo3EncoderConfig):
  240. super().__init__(config)
  241. self.dropout = config.dropout
  242. self.layerdrop = config.encoder_layerdrop
  243. embed_dim = config.d_model
  244. self.num_mel_bins = config.num_mel_bins
  245. self.max_source_positions = config.max_source_positions
  246. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  247. self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
  248. self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
  249. self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
  250. self.embed_positions.requires_grad_(False)
  251. self.layers = nn.ModuleList([AudioFlamingo3EncoderLayer(config) for _ in range(config.encoder_layers)])
  252. self.layer_norm = nn.LayerNorm(config.d_model)
  253. # Ignore copy
  254. self.avg_pooler = nn.AvgPool1d(2, stride=2)
  255. self.gradient_checkpointing = False
  256. # Initialize weights and apply final processing
  257. self.post_init()
  258. def _freeze_parameters(self):
  259. for param in self.parameters():
  260. param.requires_grad = False
  261. self._requires_grad = False
  262. def get_input_embeddings(self) -> nn.Module:
  263. return self.conv1
  264. def set_input_embeddings(self, value: nn.Module):
  265. self.conv1 = value
  266. @merge_with_config_defaults
  267. @capture_outputs
  268. def forward(
  269. self,
  270. input_features: torch.Tensor,
  271. input_features_mask: torch.Tensor | None = None,
  272. **kwargs,
  273. ) -> tuple | BaseModelOutputWithPooling:
  274. r"""
  275. Args:
  276. input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
  277. Log-Mel features extracted from raw audio. Use the processor/feature extractor to compute and pad
  278. these features from waveform input.
  279. input_features_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  280. Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
  281. - 1 for tokens that are **not masked**,
  282. - 0 for tokens that are **masked**.
  283. """
  284. seq_len = (input_features.shape[-1] - 1) // 2 + 1 # After conv2 downsampling
  285. input_features_lengths = input_features_mask.sum(-1)
  286. input_features_lengths = (input_features_lengths - 1) // 2 + 1 # conv2 downsampling
  287. input_features_mask = torch.arange(seq_len, device=input_features.device) < input_features_lengths[:, None]
  288. # Conv front-end
  289. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  290. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  291. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  292. # Add positions, dropout
  293. hidden_states = inputs_embeds + self.embed_positions.weight
  294. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  295. attention_mask = create_bidirectional_mask(
  296. config=self.config,
  297. inputs_embeds=hidden_states,
  298. attention_mask=input_features_mask,
  299. )
  300. # Transformer stack
  301. for layer in self.layers:
  302. drop = self.training and torch.rand([]) < self.layerdrop
  303. if not drop:
  304. hidden_states = layer(hidden_states, attention_mask)
  305. # AvgPool (time/2) + LayerNorm
  306. hidden_states = hidden_states.permute(0, 2, 1)
  307. hidden_states = self.avg_pooler(hidden_states).permute(0, 2, 1)
  308. hidden_states = self.layer_norm(hidden_states)
  309. return BaseModelOutputWithPooling(
  310. last_hidden_state=hidden_states,
  311. )
  312. # Ignore copy
  313. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
  314. """
  315. Computes the output length of the convolutional layers and the output length of the audio encoder
  316. """
  317. input_lengths = (input_lengths - 1) // 2 + 1
  318. output_lengths = (input_lengths - 2) // 2 + 1
  319. return input_lengths, output_lengths
  320. class AudioFlamingo3MultiModalProjector(nn.Module):
  321. """
  322. Audio adaptor (small MLP) that projects AudioFlamingo3Encoder features
  323. to the LLM embedding space so they can replace `<sound>` tokens.
  324. """
  325. def __init__(self, config: AudioFlamingo3Config):
  326. super().__init__()
  327. self.linear_1 = nn.Linear(
  328. config.audio_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
  329. )
  330. self.act = ACT2FN[config.projector_hidden_act]
  331. self.linear_2 = nn.Linear(
  332. config.text_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
  333. )
  334. def forward(self, audio_features):
  335. hidden_states = self.linear_1(audio_features)
  336. hidden_states = self.act(hidden_states)
  337. hidden_states = self.linear_2(hidden_states)
  338. return hidden_states
  339. @auto_docstring(
  340. custom_intro="""
  341. The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
  342. """
  343. )
  344. class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
  345. _keep_in_fp32_modules_strict = None
  346. _tp_plan = None
  347. _pp_plan = None
  348. def __init__(self, config):
  349. super().__init__(config)
  350. self.vocab_size = config.text_config.vocab_size
  351. self.audio_tower = AutoModel.from_config(config.audio_config)
  352. self.language_model = AutoModelForCausalLM.from_config(config.text_config)
  353. self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
  354. # Initialize weights and apply final processing
  355. self.post_init()
  356. def get_input_embeddings(self):
  357. return self.language_model.get_input_embeddings()
  358. def set_input_embeddings(self, value):
  359. self.language_model.set_input_embeddings(value)
  360. def get_output_embeddings(self):
  361. return self.language_model.get_output_embeddings()
  362. def set_output_embeddings(self, new_embeddings):
  363. self.language_model.set_output_embeddings(new_embeddings)
  364. def set_decoder(self, decoder):
  365. self.language_model.set_decoder(decoder)
  366. def get_decoder(self):
  367. return self.language_model.get_decoder()
  368. @can_return_tuple
  369. @auto_docstring(
  370. custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
  371. )
  372. def get_audio_features(
  373. self,
  374. input_features: torch.FloatTensor,
  375. input_features_mask: torch.Tensor,
  376. **kwargs: Unpack[TransformersKwargs],
  377. ) -> tuple | BaseModelOutputWithPooling:
  378. r"""
  379. input_features (`torch.FloatTensor`):
  380. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  381. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
  382. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  383. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  384. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  385. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
  386. Mask to avoid performing attention on padded feature indices.
  387. """
  388. audio_output = self.audio_tower(
  389. input_features, input_features_mask=input_features_mask, return_dict=True, **kwargs
  390. )
  391. audio_embeds = self.multi_modal_projector(audio_output.last_hidden_state)
  392. # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling
  393. input_lengths = input_features_mask.sum(-1).to(torch.long)
  394. _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_lengths)
  395. valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
  396. audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)]
  397. return audio_output
  398. @can_return_tuple
  399. @auto_docstring
  400. def forward(
  401. self,
  402. input_ids: torch.LongTensor | None = None,
  403. input_features: torch.FloatTensor | None = None,
  404. input_features_mask: torch.Tensor | None = None,
  405. attention_mask: torch.Tensor | None = None,
  406. position_ids: torch.LongTensor | None = None,
  407. past_key_values: Cache | None = None,
  408. inputs_embeds: torch.FloatTensor | None = None,
  409. labels: torch.LongTensor | None = None,
  410. use_cache: bool | None = None,
  411. logits_to_keep: int | torch.Tensor = 0,
  412. **kwargs: Unpack[TransformersKwargs],
  413. ) -> CausalLMOutputWithPast:
  414. r"""
  415. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
  416. Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
  417. - 1 for tokens that are **not masked**,
  418. - 0 for tokens that are **masked**.
  419. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  420. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  421. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  422. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  423. Example:
  424. ```python
  425. >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
  426. >>> model_id = "nvidia/audio-flamingo-3-hf"
  427. >>> processor = AutoProcessor.from_pretrained(model_id)
  428. >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
  429. >>> conversations = [
  430. >>> [
  431. >>> {
  432. >>> "role": "user",
  433. >>> "content": [
  434. >>> {"type": "text", "text": "Transcribe the input speech."},
  435. >>> {
  436. >>> "type": "audio",
  437. >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
  438. >>> },
  439. >>> ],
  440. >>> }
  441. >>> ],
  442. >>> [
  443. >>> {
  444. >>> "role": "user",
  445. >>> "content": [
  446. >>> {
  447. >>> "type": "text",
  448. >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
  449. >>> },
  450. >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
  451. >>> ],
  452. >>> }
  453. >>> ],
  454. >>> ]
  455. >>> inputs = processor.apply_chat_template(
  456. >>> conversations,
  457. >>> tokenize=True,
  458. >>> add_generation_prompt=True,
  459. >>> return_dict=True,
  460. >>> ).to(model.device)
  461. >>> outputs = model.generate(**inputs, max_new_tokens=500)
  462. >>> decoded_outputs = processor.batch_decode(
  463. >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
  464. >>> )
  465. >>> print(decoded_outputs)
  466. ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."]
  467. ```"""
  468. if inputs_embeds is None:
  469. inputs_embeds = self.get_input_embeddings()(input_ids)
  470. if input_features is not None and input_ids is not None:
  471. audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output
  472. # replace text-audio token placeholders with audio embeddings
  473. audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
  474. inputs_embeds = inputs_embeds.masked_scatter(
  475. audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
  476. )
  477. outputs: CausalLMOutputWithPast = self.language_model(
  478. inputs_embeds=inputs_embeds,
  479. attention_mask=attention_mask,
  480. position_ids=position_ids,
  481. past_key_values=past_key_values,
  482. labels=labels,
  483. use_cache=use_cache,
  484. logits_to_keep=logits_to_keep,
  485. **kwargs,
  486. )
  487. return outputs
  488. def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
  489. input_features = kwargs.pop("input_features", None)
  490. input_features_mask = kwargs.pop("input_features_mask", None)
  491. model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
  492. if is_first_iteration or not model_inputs.get("use_cache", False):
  493. if input_features is not None:
  494. model_inputs["input_features"] = input_features
  495. if input_features_mask is not None:
  496. model_inputs["input_features_mask"] = input_features_mask
  497. return model_inputs
  498. __all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"]