processing_glmasr.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glmasr/modular_glmasr.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glmasr.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 the HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import re
  21. import numpy as np
  22. from ...audio_utils import AudioInput, make_list_of_audio
  23. from ...feature_extraction_utils import BatchFeature
  24. from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
  25. from ...tokenization_utils_base import TextInput
  26. from ...utils import is_torch_available, logging
  27. if is_torch_available():
  28. import torch
# Module-level logger; `GlmAsrProcessor.__call__` uses it to warn when an audio input is truncated.
logger = logging.get_logger(__name__)
class GlmAsrProcessorKwargs(ProcessingKwargs, total=False):
    # Default keyword arguments merged with the user's kwargs by `GlmAsrProcessor.__call__`
    # (through `self._merge_kwargs`). The keys group the defaults by destination.
    _defaults = {
        # Forwarded to the tokenizer call.
        "text_kwargs": {
            "padding": True,
        },
        # Forwarded to the feature extractor call.
        "audio_kwargs": {
            # Also used by `__call__` to convert the chunk length (seconds) into a window size (samples).
            "sampling_rate": 16000,
            # `__call__` pops "attention_mask" from the feature extractor output and re-exposes it as
            # "input_features_mask", so the mask must be requested.
            "return_attention_mask": True,
            "padding": "max_length",
        },
        "common_kwargs": {
            # `__call__` raises a `ValueError` for any value other than "pt".
            "return_tensors": "pt",
            "padding_side": "left",
        },
    }
  45. class GlmAsrProcessor(ProcessorMixin):
  46. r"""
  47. Constructs an GlmAsr processor which wraps an GlmAsr feature extractor and an GlmAsr
  48. tokenizer into a single processor.
  49. [`GlmAsrProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and
  50. [`Qwen2TokenizerFast`]. See the [`~GlmAsrProcessor.__call__`] for more information.
  51. Args:
  52. feature_extractor ([`WhisperFeatureExtractor`]):
  53. The feature extractor is a required input.
  54. tokenizer ([`Qwen2TokenizerFast`]):
  55. The tokenizer is a required input.
  56. chat_template (`Optional[str]`, *optional*):
  57. The Jinja template to use for formatting the conversation. If not provided, the tokenizer's default chat
  58. template will be used.
  59. audio_token (`Optional[str]`, *optional*, defaults to `"<|pad|>`"):
  60. Special token used to represent audio inputs in the chat template.
  61. default_transcription_prompt (`str`, *optional*, defaults to `"Please transcribe this audio into text"`):
  62. Default prompt to use for transcription tasks when applying transcription requests.
  63. max_audio_len (`int`, *optional*, defaults to 655):
  64. Maximum length of audio sequences in seconds. Audio longer than this will be truncated.
  65. 655 gives approximately 8192 tokens, corresponding to the maximum sequence length of the text model.
  66. """
  67. def __init__(
  68. self,
  69. feature_extractor,
  70. tokenizer,
  71. chat_template=None,
  72. audio_token="<|pad|>",
  73. default_transcription_prompt="Please transcribe this audio into text",
  74. max_audio_len=655,
  75. ):
  76. self.audio_token = audio_token
  77. self.audio_token_id = tokenizer.convert_tokens_to_ids(audio_token)
  78. self.default_transcription_prompt = default_transcription_prompt
  79. self.max_audio_len = max_audio_len
  80. super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
  81. def _get_audio_token_length(self, audio_lengths: "torch.Tensor") -> "torch.Tensor":
  82. merge_factor = 4
  83. for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
  84. audio_lengths = (audio_lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
  85. num_tokens = (audio_lengths - merge_factor) // merge_factor + 1
  86. return num_tokens
  87. def _expand_audio_tokens(self, text, padding_mask, per_sample_windows):
  88. audio_lengths = torch.stack([s.sum() for s in torch.split(padding_mask.sum(-1), per_sample_windows)])
  89. audio_tokens_lengths = self._get_audio_token_length(audio_lengths)
  90. audio_token_pattern = re.compile(re.escape(self.audio_token))
  91. for i, audio_length in enumerate(audio_tokens_lengths):
  92. text[i] = audio_token_pattern.sub(self.audio_token * audio_length, text[i])
  93. return text
  94. def _get_audio_tokens_mask(self, input_ids):
  95. return input_ids == self.audio_token_id
  96. def __call__(
  97. self,
  98. text: TextInput | list[TextInput],
  99. audio: AudioInput | None = None,
  100. output_labels: bool | None = False,
  101. **kwargs: Unpack[GlmAsrProcessorKwargs],
  102. ) -> BatchFeature:
  103. r"""
  104. Main method to prepare one or several text sequence(s) and audio waveform(s) for the model. This
  105. method expands `<sound>` placeholders in the text based on the post-pool frame counts of the
  106. audio windows, then tokenizes the provided strings as-is, and extracts log-mel features
  107. with [`WhisperFeatureExtractor`]. If `audio` is `None`, no audio processing is performed and
  108. the text is tokenized as-is (LM-only behavior).
  109. Args:
  110. text (`str` or `list[str]`):
  111. Input sequence or batch of sequences.
  112. audio (`np.ndarray` or `list[np.ndarray]`):
  113. Input audio or batch of audios as NumPy arrays. If provided, there must be as many `text` inputs as
  114. `audio` inputs.
  115. output_labels (bool, *optional*, default=False):
  116. Whether to return labels for training.
  117. Returns:
  118. [`BatchFeature`]: A dictionary with tokenized text (`input_ids`, `attention_mask`) and
  119. audio features (`input_features`, `input_features_mask`).
  120. """
  121. # Merge defaults with user kwargs
  122. call_kwargs = self._merge_kwargs(
  123. GlmAsrProcessorKwargs,
  124. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  125. **kwargs,
  126. )
  127. text_kwargs = call_kwargs["text_kwargs"]
  128. audio_kwargs = call_kwargs["audio_kwargs"]
  129. return_tensors = text_kwargs.get("return_tensors")
  130. if return_tensors != "pt":
  131. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  132. if isinstance(text, str):
  133. text = [text]
  134. elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
  135. raise ValueError("Invalid input text. Please provide a string, or a list of strings")
  136. audio_inputs = {}
  137. if audio is not None:
  138. audio = make_list_of_audio(audio)
  139. if len(text) != len(audio):
  140. raise ValueError(f"Got {len(text)} text but {len(audio)} audios; they must match 1:1.")
  141. # Determine number of chunks per sample, and flatten
  142. window_size = int(audio_kwargs["sampling_rate"] * self.feature_extractor.chunk_length)
  143. max_windows = int(self.max_audio_len // self.feature_extractor.chunk_length)
  144. per_sample_windows: list[int] = []
  145. flat_chunks: list[np.ndarray] = []
  146. for audio_el in audio:
  147. n_samples = int(audio_el.shape[0])
  148. n_win = max(1, (n_samples + window_size - 1) // window_size)
  149. if n_win > max_windows:
  150. logger.warning(
  151. f"Audio duration ({n_samples / audio_kwargs['sampling_rate']:.1f}s) exceeds {self.max_audio_len}s; truncating to first {self.max_audio_len}s."
  152. )
  153. n_win = max_windows
  154. per_sample_windows.append(n_win)
  155. time_cap = min(n_samples, n_win * window_size)
  156. for i in range(n_win):
  157. start = i * window_size
  158. end = min((i + 1) * window_size, time_cap)
  159. flat_chunks.append(audio_el[start:end])
  160. # Feature extraction
  161. audio_inputs = self.feature_extractor(flat_chunks, **audio_kwargs)
  162. padding_mask = audio_inputs.pop("attention_mask")
  163. audio_inputs["input_features_mask"] = padding_mask
  164. # Expand audio tokens in text
  165. text = self._expand_audio_tokens(text, padding_mask, per_sample_windows)
  166. # Tokenize
  167. text_inputs = self.tokenizer(text, **text_kwargs)
  168. data = {**text_inputs, **audio_inputs}
  169. if output_labels:
  170. labels = data["input_ids"].clone()
  171. labels[self._get_audio_tokens_mask(labels)] = -100
  172. labels[labels == self.tokenizer.pad_token_id] = -100
  173. data["labels"] = labels
  174. return BatchFeature(data=data, tensor_type=return_tensors)
  175. @property
  176. def model_input_names(self) -> list[str]:
  177. tok_names = self.tokenizer.model_input_names
  178. fea_names = self.feature_extractor.model_input_names
  179. return list(dict.fromkeys(tok_names + fea_names + ["input_features_mask"]))
  180. def apply_transcription_request(
  181. self,
  182. audio: str | list[str] | AudioInput,
  183. prompt: str | list[str] | None = None,
  184. **kwargs: Unpack[GlmAsrProcessorKwargs],
  185. ) -> BatchFeature:
  186. """
  187. Prepare inputs for automatic speech recognition without manually writing the default transcription prompt.
  188. Args:
  189. audio (`str`, `list[str]`, `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
  190. Audio to transcribe. Strings are interpreted as local paths or URLs and will be loaded automatically by
  191. the chat template loader; NumPy arrays and PyTorch tensors are forwarded directly.
  192. prompt (`str` or `list[str]`, *optional*):
  193. Custom prompt(s) to include in the user turn. A list must be the same length as the batch. When `None`,
  194. each sample uses `"Transcribe the input speech."`.
  195. **kwargs:
  196. Additional keyword arguments forwarded to [`~GlmAsrProcessor.apply_chat_template`] (for example
  197. `text_kwargs`, `audio_kwargs`, ...).
  198. Returns:
  199. [`BatchFeature`]: Processor outputs ready to be passed to [`GlmAsrForConditionalGeneration.generate`].
  200. """
  201. if isinstance(audio, str):
  202. audio_items: list[str | np.ndarray] = [audio]
  203. elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio):
  204. audio_items = list(audio)
  205. else:
  206. audio_items = list(make_list_of_audio(audio))
  207. if is_torch_available():
  208. audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items]
  209. batch_size = len(audio_items)
  210. if batch_size == 0:
  211. raise ValueError("`audio` must contain at least one sample.")
  212. if prompt is None:
  213. prompts = [self.default_transcription_prompt] * batch_size
  214. elif isinstance(prompt, str):
  215. prompts = [prompt] * batch_size
  216. elif isinstance(prompt, (list, tuple)):
  217. if len(prompt) != batch_size:
  218. raise ValueError(
  219. f"Received {len(prompt)} prompt(s) for {batch_size} audio sample(s); counts must match."
  220. )
  221. prompts = []
  222. for item in prompt:
  223. if item is None:
  224. prompts.append(self.default_transcription_prompt)
  225. elif isinstance(item, str):
  226. prompts.append(item)
  227. else:
  228. raise TypeError("Each prompt must be a string or `None`.")
  229. else:
  230. raise TypeError("`prompt` must be a string, a sequence of strings, or `None`.")
  231. conversations = [
  232. [
  233. {
  234. "role": "user",
  235. "content": [
  236. {"type": "audio", "path": audio_item}
  237. if isinstance(audio_item, str)
  238. else {"type": "audio", "audio": audio_item},
  239. {"type": "text", "text": prompt_text},
  240. ],
  241. }
  242. ]
  243. for prompt_text, audio_item in zip(prompts, audio_items)
  244. ]
  245. return self.apply_chat_template(
  246. conversations,
  247. tokenize=True,
  248. add_generation_prompt=True,
  249. return_dict=True,
  250. **kwargs,
  251. )
  252. def decode(self, *args, strip_prefix=False, **kwargs):
  253. """
  254. Forward arguments to [`~PreTrainedTokenizer.decode`] and optionally remove the assistant framing the model
  255. was trained to produce.
  256. AF3 transcription requests respond with sentences such as `"The spoken content of the audio is \"...\"."`.
  257. Setting `strip_prefix=True` trims the fixed prefix for just the transcription text.
  258. """
  259. decoded = self.tokenizer.decode(*args, **kwargs)
  260. if strip_prefix:
  261. decoded = [self._strip_assistant_prefix_and_quotes(text) for text in decoded]
  262. return decoded
  263. def batch_decode(self, *args, **kwargs):
  264. """BC as previous examples used batch_decode"""
  265. return self.decode(*args, **kwargs)
  266. def _strip_assistant_prefix_and_quotes(self, text: str) -> str:
  267. """
  268. Remove the assistant prefix and surrounding quotes from a decoded transcription string.
  269. """
  270. stripped = text.strip()
  271. for prefix in (
  272. "The spoken content of the audio is",
  273. "The transcription of the audio is",
  274. "The content of the input audio is",
  275. ):
  276. if stripped.startswith(prefix):
  277. stripped = stripped[len(prefix) :].strip()
  278. break
  279. if stripped.endswith("."):
  280. stripped = stripped[:-1].strip()
  281. if len(stripped) >= 2 and stripped[0] == stripped[-1] and stripped[0] in {"'", '"'}:
  282. stripped = stripped[1:-1].strip()
  283. return stripped
# Public API of this module: only the processor class is exported.
__all__ = ["GlmAsrProcessor"]