processing_audioflamingo3.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. # Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
  2. # reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import re
  16. import numpy as np
  17. from ...audio_utils import AudioInput, make_list_of_audio
  18. from ...feature_extraction_utils import BatchFeature
  19. from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
  20. from ...tokenization_utils_base import TextInput
  21. from ...utils import is_torch_available, logging
  22. if is_torch_available():
  23. import torch
  24. logger = logging.get_logger(__name__)
  class AudioFlamingo3ProcessorKwargs(ProcessingKwargs, total=False):
      """Default call-time keyword arguments for `AudioFlamingo3Processor`, grouped by modality."""

      _defaults = {
          # Forwarded to the tokenizer.
          "text_kwargs": {"padding": True},
          # Forwarded to the feature extractor. The attention mask is requested because
          # the processor re-exposes it as `input_features_mask`.
          "audio_kwargs": {
              "sampling_rate": 16000,
              "return_attention_mask": True,
              "padding": "max_length",
          },
          # Shared by both modalities. The processor rejects any `return_tensors` other than "pt".
          "common_kwargs": {
              "return_tensors": "pt",
              "padding_side": "left",
          },
      }
  40. class AudioFlamingo3Processor(ProcessorMixin):
  41. r"""
  42. Constructs an AudioFlamingo3 processor which wraps an AudioFlamingo3 feature extractor and an AudioFlamingo3
  43. tokenizer into a single processor.
  44. [`AudioFlamingo3Processor`] offers all the functionalities of [`WhisperFeatureExtractor`] and
  45. [`Qwen2TokenizerFast`]. See the [`~AudioFlamingo3Processor.__call__`] for more information.
  46. Args:
  47. feature_extractor ([`WhisperFeatureExtractor`]):
  48. The feature extractor is a required input.
  49. tokenizer ([`Qwen2TokenizerFast`]):
  50. The tokenizer is a required input.
  51. chat_template (`Optional[str]`, *optional*):
  52. The Jinja template to use for formatting the conversation. If not provided, the tokenizer's default chat
  53. template will be used.
  54. audio_token (`Optional[str]`, *optional*, defaults to `"<sound>"`):
  55. Special token used to represent audio inputs in the chat template.
  56. default_transcription_prompt (`str`, *optional*, defaults to `"Transcribe the input speech."`):
  57. Default prompt to use for transcription tasks when applying transcription requests.
  58. max_audio_len (`int`, *optional*, defaults to 600):
  59. Maximum length of audio sequences in seconds. Audio longer than this will be truncated.
  60. """
  def __init__(
      self,
      feature_extractor,
      tokenizer,
      chat_template=None,
      audio_token="<sound>",
      default_transcription_prompt="Transcribe the input speech.",
      max_audio_len=600,
  ):
      """Record the audio-placeholder settings, then initialise the wrapped feature extractor and tokenizer."""
      # Processor-level settings used by `__call__` and `apply_transcription_request`.
      self.max_audio_len = max_audio_len
      self.default_transcription_prompt = default_transcription_prompt

      # Keep the placeholder both as text (expanded inside prompts) and as a
      # vocabulary id (used to mask placeholder positions in labels).
      self.audio_token = audio_token
      self.audio_token_id = tokenizer.convert_tokens_to_ids(audio_token)

      super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
  75. def _get_audio_token_length(self, audio_lengths):
  76. conv_output_lengths = (audio_lengths - 1) // 2 + 1 # After conv2 downsampling
  77. audio_tokens_lengths = (conv_output_lengths - 2) // 2 + 1 # After avg pooling
  78. return audio_tokens_lengths
  79. def _expand_audio_tokens(self, text, padding_mask, per_sample_windows):
  80. audio_lengths = torch.stack([s.sum() for s in torch.split(padding_mask.sum(-1), per_sample_windows)])
  81. audio_tokens_lengths = self._get_audio_token_length(audio_lengths)
  82. audio_token_pattern = re.compile(re.escape(self.audio_token))
  83. for i, audio_length in enumerate(audio_tokens_lengths):
  84. text[i] = audio_token_pattern.sub(self.audio_token * audio_length, text[i])
  85. return text
  86. def _get_audio_tokens_mask(self, input_ids):
  87. return input_ids == self.audio_token_id
  def __call__(
      self,
      text: TextInput | list[TextInput],
      audio: AudioInput | None = None,
      output_labels: bool | None = False,
      **kwargs: Unpack[AudioFlamingo3ProcessorKwargs],
  ) -> BatchFeature:
      r"""
      Main method to prepare one or several text sequence(s) and audio waveform(s) for the model. This
      method expands `<sound>` placeholders in the text based on the post-pool frame counts of the
      audio windows, then tokenizes the provided strings as-is, and extracts log-mel features
      with [`WhisperFeatureExtractor`]. If `audio` is `None`, no audio processing is performed and
      the text is tokenized as-is (LM-only behavior).

      Args:
          text (`str` or `list[str]`):
              Input sequence or batch of sequences. The passed object is never modified.
          audio (`np.ndarray` or `list[np.ndarray]`, *optional*):
              Input audio or batch of audios as NumPy arrays. If provided, there must be as many `text` inputs as
              `audio` inputs.
          output_labels (`bool`, *optional*, defaults to `False`):
              Whether to return labels for training. Labels are a copy of `input_ids` in which audio placeholder
              positions and positions equal to the tokenizer's pad token id are set to `-100`.

      Returns:
          [`BatchFeature`]: A dictionary with tokenized text (`input_ids`, `attention_mask`) and
          audio features (`input_features`, `input_features_mask`).

      Raises:
          ValueError: If `return_tensors` is not `"pt"`, if `text` is neither a string nor a list/tuple of
              strings, or if the number of texts differs from the number of audios.
      """
      # Merge defaults with user kwargs
      call_kwargs = self._merge_kwargs(
          AudioFlamingo3ProcessorKwargs,
          tokenizer_init_kwargs=self.tokenizer.init_kwargs,
          **kwargs,
      )
      text_kwargs = call_kwargs["text_kwargs"]
      audio_kwargs = call_kwargs["audio_kwargs"]

      return_tensors = text_kwargs.get("return_tensors")
      if return_tensors != "pt":
          raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")

      if isinstance(text, str):
          text = [text]
      elif isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text):
          # Work on a private list: placeholder expansion must neither alter the caller's
          # sequence nor fail on tuples, which this validation accepts.
          text = list(text)
      else:
          raise ValueError("Invalid input text. Please provide a string, or a list of strings")

      audio_inputs = {}
      if audio is not None:
          audio = make_list_of_audio(audio)
          if len(text) != len(audio):
              raise ValueError(f"Got {len(text)} text but {len(audio)} audios; they must match 1:1.")

          # Determine number of chunks per sample, and flatten
          flat_chunks, per_sample_windows = self._split_audio_into_windows(audio, audio_kwargs["sampling_rate"])

          # Feature extraction
          audio_inputs = self.feature_extractor(flat_chunks, **audio_kwargs)
          # Re-expose the extractor's mask under a name that cannot collide with the
          # tokenizer's own `attention_mask` when both outputs are merged below.
          padding_mask = audio_inputs.pop("attention_mask")
          audio_inputs["input_features_mask"] = padding_mask

          # Expand audio tokens in text
          text = self._expand_audio_tokens(text, padding_mask, per_sample_windows)

      # Tokenize
      text_inputs = self.tokenizer(text, **text_kwargs)

      data = {**text_inputs, **audio_inputs}
      if output_labels:
          labels = data["input_ids"].clone()
          labels[self._get_audio_tokens_mask(labels)] = -100
          labels[labels == self.tokenizer.pad_token_id] = -100
          data["labels"] = labels

      return BatchFeature(data=data, tensor_type=return_tensors)

  def _split_audio_into_windows(self, audio, sampling_rate):
      """
      Cut every waveform into consecutive windows of `feature_extractor.chunk_length` seconds.

      Each sample yields at least one window. Samples needing more windows than fit into `self.max_audio_len`
      seconds are truncated to the leading windows and a warning is logged.

      Args:
          audio: Iterable of one-dimensional waveforms exposing `shape[0]` and slicing.
          sampling_rate: Number of samples per second used to size a window.

      Returns:
          `tuple[list, list[int]]`: The windows of all samples in order, and the number of windows per sample.
      """
      chunk_length = self.feature_extractor.chunk_length
      window_size = int(sampling_rate * chunk_length)
      max_windows = int(self.max_audio_len // chunk_length)

      per_sample_windows = []
      flat_chunks = []
      for audio_el in audio:
          n_samples = int(audio_el.shape[0])
          # Ceil division; an empty waveform still occupies one (empty) window.
          n_win = max(1, (n_samples + window_size - 1) // window_size)
          if n_win > max_windows:
              logger.warning(
                  f"Audio duration ({n_samples / sampling_rate:.1f}s) exceeds {self.max_audio_len}s; "
                  f"truncating to first {self.max_audio_len}s."
              )
              n_win = max_windows
          per_sample_windows.append(n_win)

          # Last sample index that may be used, so the final window stops at the
          # end of the (possibly truncated) waveform.
          time_cap = min(n_samples, n_win * window_size)
          for i in range(n_win):
              start = i * window_size
              end = min(start + window_size, time_cap)
              flat_chunks.append(audio_el[start:end])
      return flat_chunks, per_sample_windows
  @property
  def model_input_names(self) -> list[str]:
      """Input names of the tokenizer and feature extractor plus `input_features_mask`, without duplicates and in first-seen order."""
      combined = self.tokenizer.model_input_names + self.feature_extractor.model_input_names
      combined = combined + ["input_features_mask"]

      unique_names = []
      seen = set()
      for name in combined:
          if name not in seen:
              seen.add(name)
              unique_names.append(name)
      return unique_names
  def apply_transcription_request(
      self,
      audio: str | list[str] | AudioInput,
      prompt: str | list[str] | None = None,
      **kwargs: Unpack[AudioFlamingo3ProcessorKwargs],
  ) -> BatchFeature:
      """
      Build single-turn transcription conversations and run them through the chat template.

      Args:
          audio (`str`, `list[str]`, `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
              Audio to transcribe. Strings are placed in the conversation under the `"path"` key; any other input
              is normalised with `make_list_of_audio`, PyTorch tensors are converted to NumPy arrays, and the
              result is placed under the `"audio"` key.
          prompt (`str` or `list[str]`, *optional*):
              Prompt(s) for the user turn. A single string is reused for every sample; a list must be the same
              length as the batch and may contain `None` entries. `None` selects
              `self.default_transcription_prompt`.
          **kwargs:
              Additional keyword arguments forwarded to [`~AudioFlamingo3Processor.apply_chat_template`].

      Returns:
          [`BatchFeature`]: The output of `apply_chat_template` called with `tokenize=True`,
          `add_generation_prompt=True` and `return_dict=True`.

      Raises:
          ValueError: If the batch is empty or the number of prompts differs from the number of audio samples.
          TypeError: If `prompt`, or one of its entries, has an unsupported type.
      """
      # --- normalise `audio` into a flat list of paths/URLs or arrays ---
      if isinstance(audio, str):
          audio_items: list[str | np.ndarray] = [audio]
      elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio):
          audio_items = list(audio)
      else:
          audio_items = list(make_list_of_audio(audio))
          if is_torch_available():
              audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items]

      batch_size = len(audio_items)
      if not batch_size:
          raise ValueError("`audio` must contain at least one sample.")

      # --- normalise `prompt` into exactly one prompt string per sample ---
      if prompt is None or isinstance(prompt, str):
          shared_prompt = self.default_transcription_prompt if prompt is None else prompt
          prompts = [shared_prompt] * batch_size
      elif isinstance(prompt, (list, tuple)):
          if len(prompt) != batch_size:
              raise ValueError(
                  f"Received {len(prompt)} prompt(s) for {batch_size} audio sample(s); counts must match."
              )
          prompts = []
          for item in prompt:
              if item is not None and not isinstance(item, str):
                  raise TypeError("Each prompt must be a string or `None`.")
              prompts.append(self.default_transcription_prompt if item is None else item)
      else:
          raise TypeError("`prompt` must be a string, a sequence of strings, or `None`.")

      # --- one single-turn conversation per (prompt, audio) pair ---
      conversations = []
      for prompt_text, audio_item in zip(prompts, audio_items):
          if isinstance(audio_item, str):
              audio_part = {"type": "audio", "path": audio_item}
          else:
              audio_part = {"type": "audio", "audio": audio_item}
          user_turn = {
              "role": "user",
              "content": [{"type": "text", "text": prompt_text}, audio_part],
          }
          conversations.append([user_turn])

      return self.apply_chat_template(
          conversations,
          tokenize=True,
          add_generation_prompt=True,
          return_dict=True,
          **kwargs,
      )
  def decode(self, *args, strip_prefix=False, **kwargs):
      """
      Forward arguments to [`~PreTrainedTokenizer.decode`] and optionally remove the assistant framing the model
      was trained to produce.

      AF3 transcription requests respond with sentences such as `The spoken content of the audio is "...".`
      Setting `strip_prefix=True` trims the fixed prefix for just the transcription text.

      Args:
          *args, **kwargs: Passed unchanged to `self.tokenizer.decode`.
          strip_prefix (`bool`, *optional*, defaults to `False`):
              Whether to post-process the decoded text with `_strip_assistant_prefix_and_quotes`.

      Returns:
          Whatever the tokenizer returned. With `strip_prefix=True`, a single decoded string is returned as a
          cleaned string and any other iterable of strings as a list of cleaned strings.
      """
      decoded = self.tokenizer.decode(*args, **kwargs)
      if not strip_prefix:
          return decoded
      if isinstance(decoded, str):
          # A single sequence decodes to one string. Iterating over it would clean it
          # character by character and return a list of characters.
          return self._strip_assistant_prefix_and_quotes(decoded)
      return [self._strip_assistant_prefix_and_quotes(text) for text in decoded]
  def batch_decode(self, *args, **kwargs):
      """Backward-compatible alias of `decode`, kept because earlier examples called `batch_decode`."""
      decoded = self.decode(*args, **kwargs)
      return decoded
  258. def _strip_assistant_prefix_and_quotes(self, text: str) -> str:
  259. """
  260. Remove the assistant prefix and surrounding quotes from a decoded transcription string.
  261. """
  262. stripped = text.strip()
  263. for prefix in (
  264. "The spoken content of the audio is",
  265. "The transcription of the audio is",
  266. "The content of the input audio is",
  267. ):
  268. if stripped.startswith(prefix):
  269. stripped = stripped[len(prefix) :].strip()
  270. break
  271. if stripped.endswith("."):
  272. stripped = stripped[:-1].strip()
  273. if len(stripped) >= 2 and stripped[0] == stripped[-1] and stripped[0] in {"'", '"'}:
  274. stripped = stripped[1:-1].strip()
  275. return stripped
  276. __all__ = ["AudioFlamingo3Processor"]