processing_musicflamingo.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/musicflamingo/modular_musicflamingo.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_musicflamingo.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
  8. # reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import re
  22. import numpy as np
  23. from ...audio_utils import AudioInput, make_list_of_audio
  24. from ...feature_extraction_utils import BatchFeature
  25. from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
  26. from ...tokenization_utils_base import TextInput
  27. from ...utils import is_torch_available, logging
  28. if is_torch_available():
  29. import torch
  30. logger = logging.get_logger(__name__)
  31. class MusicFlamingoProcessorKwargs(ProcessingKwargs, total=False):
  32. _defaults = {
  33. "text_kwargs": {
  34. "padding": True,
  35. },
  36. "audio_kwargs": {
  37. "sampling_rate": 16000,
  38. "return_attention_mask": True,
  39. "padding": "max_length",
  40. },
  41. "common_kwargs": {
  42. "return_tensors": "pt",
  43. "padding_side": "left",
  44. },
  45. }
  46. class MusicFlamingoProcessor(ProcessorMixin):
  47. r"""
  48. Constructs an MusicFlamingo processor which wraps an MusicFlamingo feature extractor and an MusicFlamingo
  49. tokenizer into a single processor.
  50. [`MusicFlamingoProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and
  51. [`Qwen2TokenizerFast`]. See the [`~MusicFlamingoProcessor.__call__`] for more information.
  52. Args:
  53. feature_extractor ([`WhisperFeatureExtractor`]):
  54. The feature extractor is a required input.
  55. tokenizer ([`Qwen2TokenizerFast`]):
  56. The tokenizer is a required input.
  57. chat_template (`Optional[str]`, *optional*):
  58. The Jinja template to use for formatting the conversation. If not provided, the tokenizer's default chat
  59. template will be used.
  60. audio_token (`Optional[str]`, *optional*, defaults to `"<sound>"`):
  61. Special token used to represent audio inputs in the chat template.
  62. audio_bos_token (`Optional[str]`, *optional*, defaults to `"<|sound_bos|>"`):
  63. Special token used to represent the beginning of audio.
  64. audio_eos_token (`Optional[str]`, *optional*, defaults to `"<|sound_eos|>"`):
  65. Special token used to represent the end of audio.
  66. max_audio_len (`int`, *optional*, defaults to 1200):
  67. Maximum length of audio sequences in seconds. Audio longer than this will be truncated.
  68. """
  69. def __init__(
  70. self,
  71. feature_extractor,
  72. tokenizer,
  73. chat_template=None,
  74. audio_token="<sound>",
  75. audio_bos_token="<|sound_bos|>",
  76. audio_eos_token="<|sound_eos|>",
  77. max_audio_len=1200,
  78. ):
  79. self.audio_token = audio_token
  80. self.audio_token_id = tokenizer.convert_tokens_to_ids(audio_token)
  81. self.max_audio_len = max_audio_len
  82. super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
  83. self.audio_bos_token = audio_bos_token
  84. self.audio_eos_token = audio_eos_token
  85. self.audio_bos_token_id = tokenizer.convert_tokens_to_ids(audio_bos_token)
  86. self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(audio_eos_token)
  87. def _get_audio_token_length(self, audio_lengths):
  88. conv_output_lengths = (audio_lengths - 1) // 2 + 1 # After conv2 downsampling
  89. audio_tokens_lengths = (conv_output_lengths - 2) // 2 + 1 # After avg pooling
  90. return audio_tokens_lengths
  91. def _expand_audio_tokens(self, text, padding_mask, per_sample_windows):
  92. audio_lengths = torch.stack([s.sum() for s in torch.split(padding_mask.sum(-1), per_sample_windows)])
  93. audio_tokens_lengths = self._get_audio_token_length(audio_lengths)
  94. audio_token_pattern = re.compile(re.escape(self.audio_token))
  95. for i, audio_length in enumerate(audio_tokens_lengths):
  96. text[i] = audio_token_pattern.sub(
  97. self.audio_bos_token + self.audio_token * audio_length + self.audio_eos_token,
  98. text[i],
  99. )
  100. return text
  101. def _get_audio_tokens_mask(self, input_ids):
  102. return (
  103. (input_ids == self.audio_token_id)
  104. | (input_ids == self.audio_bos_token_id)
  105. | (input_ids == self.audio_eos_token_id)
  106. )
  107. def __call__(
  108. self,
  109. text: TextInput | list[TextInput],
  110. audio: AudioInput | None = None,
  111. output_labels: bool | None = False,
  112. **kwargs: Unpack[MusicFlamingoProcessorKwargs],
  113. ) -> BatchFeature:
  114. r"""
  115. Main method to prepare one or several text sequence(s) and audio waveform(s) for the model. This
  116. method expands `<sound>` placeholders in the text based on the post-pool frame counts of the
  117. audio windows, then tokenizes the provided strings as-is, and extracts log-mel features
  118. with [`WhisperFeatureExtractor`]. If `audio` is `None`, no audio processing is performed and
  119. the text is tokenized as-is (LM-only behavior).
  120. Args:
  121. text (`str` or `list[str]`):
  122. Input sequence or batch of sequences.
  123. audio (`np.ndarray` or `list[np.ndarray]`):
  124. Input audio or batch of audios as NumPy arrays. If provided, there must be as many `text` inputs as
  125. `audio` inputs.
  126. output_labels (bool, *optional*, default=False):
  127. Whether to return labels for training.
  128. Returns:
  129. [`BatchFeature`]: A dictionary with tokenized text (`input_ids`, `attention_mask`) and
  130. audio features (`input_features`, `input_features_mask`).
  131. """
  132. # Merge defaults with user kwargs
  133. call_kwargs = self._merge_kwargs(
  134. MusicFlamingoProcessorKwargs,
  135. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  136. **kwargs,
  137. )
  138. text_kwargs = call_kwargs["text_kwargs"]
  139. audio_kwargs = call_kwargs["audio_kwargs"]
  140. return_tensors = text_kwargs.get("return_tensors")
  141. if return_tensors != "pt":
  142. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  143. if isinstance(text, str):
  144. text = [text]
  145. elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
  146. raise ValueError("Invalid input text. Please provide a string, or a list of strings")
  147. audio_inputs = {}
  148. if audio is not None:
  149. audio = make_list_of_audio(audio)
  150. if len(text) != len(audio):
  151. raise ValueError(f"Got {len(text)} text but {len(audio)} audios; they must match 1:1.")
  152. # Determine number of chunks per sample, and flatten
  153. window_size = int(audio_kwargs["sampling_rate"] * self.feature_extractor.chunk_length)
  154. max_windows = int(self.max_audio_len // self.feature_extractor.chunk_length)
  155. per_sample_windows: list[int] = []
  156. flat_chunks: list[np.ndarray] = []
  157. for audio_el in audio:
  158. n_samples = int(audio_el.shape[0])
  159. n_win = max(1, (n_samples + window_size - 1) // window_size)
  160. if n_win > max_windows:
  161. logger.warning(
  162. f"Audio duration ({n_samples / audio_kwargs['sampling_rate']:.1f}s) exceeds {self.max_audio_len}s; truncating to first {self.max_audio_len}s."
  163. )
  164. n_win = max_windows
  165. per_sample_windows.append(n_win)
  166. time_cap = min(n_samples, n_win * window_size)
  167. for i in range(n_win):
  168. start = i * window_size
  169. end = min((i + 1) * window_size, time_cap)
  170. flat_chunks.append(audio_el[start:end])
  171. # Feature extraction
  172. audio_inputs = self.feature_extractor(flat_chunks, **audio_kwargs)
  173. padding_mask = audio_inputs.pop("attention_mask")
  174. audio_inputs["input_features_mask"] = padding_mask
  175. # Expand audio tokens in text
  176. text = self._expand_audio_tokens(text, padding_mask, per_sample_windows)
  177. # Tokenize
  178. text_inputs = self.tokenizer(text, **text_kwargs)
  179. data = {**text_inputs, **audio_inputs}
  180. if output_labels:
  181. labels = data["input_ids"].clone()
  182. labels[self._get_audio_tokens_mask(labels)] = -100
  183. labels[labels == self.tokenizer.pad_token_id] = -100
  184. data["labels"] = labels
  185. return BatchFeature(data=data, tensor_type=return_tensors)
  186. @property
  187. def model_input_names(self) -> list[str]:
  188. tok_names = self.tokenizer.model_input_names
  189. fea_names = self.feature_extractor.model_input_names
  190. return list(dict.fromkeys(tok_names + fea_names + ["input_features_mask"]))
  191. __all__ = ["MusicFlamingoProcessor"]