processing_voxtral.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import io
  15. from ...utils import auto_docstring, is_mistral_common_available, is_soundfile_available, is_torch_available, logging
  16. if is_torch_available():
  17. import torch
  18. if is_soundfile_available():
  19. import soundfile as sf
  20. if is_mistral_common_available():
  21. from mistral_common.protocol.transcription.request import TranscriptionRequest
  22. from ...audio_utils import AudioInput, load_audio_as, make_list_of_audio
  23. from ...feature_extraction_utils import BatchFeature
  24. from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
  25. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  26. from ...utils.chat_template_utils import _get_template_variables
  27. logger = logging.get_logger(__name__)
  28. class VoxtralAudioKwargs(AudioKwargs, total=False):
  29. """
  30. max_source_positions (`int`, *optional*, defaults to `3000`):
  31. Maximum number of positions per chunk when splitting mel spectrogram features along the time dimension.
  32. """
  33. max_source_positions: int | None
  34. class VoxtralProcessorKwargs(ProcessingKwargs, total=False):
  35. audio_kwargs: VoxtralAudioKwargs
  36. _defaults = {
  37. "text_kwargs": {
  38. "padding": True,
  39. },
  40. "audio_kwargs": {
  41. "sampling_rate": 16000,
  42. "padding": True,
  43. "truncation": False,
  44. "pad_to_multiple_of": 480000,
  45. "max_source_positions": 3000,
  46. },
  47. "common_kwargs": {
  48. "return_tensors": "pt",
  49. "return_dict": True,
  50. "tokenize": True,
  51. },
  52. }
  53. @auto_docstring
  54. class VoxtralProcessor(ProcessorMixin):
  55. def __init__(
  56. self,
  57. feature_extractor,
  58. tokenizer,
  59. ):
  60. self.audio_token_id = 24
  61. self.audio_token = tokenizer.convert_ids_to_tokens(self.audio_token_id)
  62. super().__init__(feature_extractor, tokenizer)
  63. def _retrieve_input_features(self, audio, max_source_positions, **kwargs):
  64. """
  65. Handles specific logic of Voxtral expected input features: audio arrays should be padded to next multiple of 480000 (duration is a multiple of 30s), see VoxtralProcessorKwargs' default audio_kwargs.
  66. Then mel input features are extracted and stacked along batch dimension, splitting into chunks of max_source_positions.
  67. """
  68. input_features_list = []
  69. for audio_array in audio:
  70. audio_inputs = self.feature_extractor(audio_array, **kwargs)
  71. # let's split into chunks of max_source_positions, and then stack them along batch dimension
  72. input_features = audio_inputs["input_features"].reshape(
  73. self.feature_extractor.feature_size, -1, max_source_positions
  74. )
  75. input_features_list.append(input_features.transpose(0, 1))
  76. return torch.cat(input_features_list)
  77. def apply_chat_template(
  78. self,
  79. conversation: list[dict[str, str]] | list[list[dict[str, str]]],
  80. chat_template: str | None = None,
  81. tools: list[dict] | None = None,
  82. documents: list[dict[str, str]] | None = None,
  83. add_generation_prompt: bool = False,
  84. continue_final_message: bool = False,
  85. return_assistant_tokens_mask: bool = False,
  86. tokenize: bool = False,
  87. return_tensors: str | None = None,
  88. return_dict: bool = False,
  89. load_audio_from_video: bool = False,
  90. processor_kwargs: dict | None = None,
  91. **kwargs,
  92. ) -> str:
  93. """
  94. This method applies the model's chat completion template given a conversation. It relies on MistralCommonBackend's
  95. [`~MistralCommonBackend.apply_chat_template`] to prepare input ids to the model and on WhisperFeatureExtractor's
  96. [`~WhisperFeatureExtractor.__call__`] to prepare input features to the model.
  97. Note that audio is padded to the nearest 30-second multiple prior to mel feature extraction.
  98. A `conversation` is a list of messages, where each message is a dictionary with a `role` and a `content` field.
  99. For Voxtral, `role` can be `"user"` or `"assistant"`.
  100. The `content` field can be a string or a list of dictionaries with a `type` field. See example below.
  101. ```python
  102. from huggingface_hub import hf_hub_download
  103. from transformers.audio_utils import load_audio_as
  104. audio_url = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3"
  105. audio_path = hf_hub_download(repo_id="hf-internal-testing/dummy-audio-samples", filename="bcn_weather.mp3", repo_type="dataset")
  106. audio_base64 = load_audio_as(audio_path, return_format="base64", force_mono=True)
  107. # audio + text
  108. conversation = [
  109. {
  110. "role": "user",
  111. "content": [
  112. {"type": "audio", "url": audio_url},
  113. {"type": "audio", "path": audio_path},
  114. {"type": "audio", "base64": audio_base64},
  115. {"type": "text", "text": "How many audio do you hear?"},
  116. ],
  117. },
  118. ]
  119. processor = VoxtralProcessor.from_pretrained("mistralai/Voxtral-Mini-3B-2507")
  120. inputs = processor.apply_chat_template(conversation)
  121. ```
  122. Args:
  123. conversation (`Union[list[Dict, [str, str]], list[list[dict[str, str]]]]`):
  124. The conversation to format.
  125. """
  126. if continue_final_message:
  127. if add_generation_prompt:
  128. raise ValueError(
  129. "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
  130. )
  131. if return_assistant_tokens_mask:
  132. raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
  133. if isinstance(conversation, (list, tuple)) and (
  134. isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
  135. ):
  136. is_batched = True
  137. conversations = conversation
  138. else:
  139. is_batched = False
  140. conversations = [conversation]
  141. # Users might still be passing processing kwargs in `**kwargs` so we need to filter
  142. # out additional kwargs that the template expects via Jinja2 template introspection
  143. # We strip unrelated kwargs to avoid passing unrecognized kwargs to `_merge_kwargs`.
  144. processor_kwargs = processor_kwargs or {}
  145. template_kwargs = _get_template_variables(chat_template)
  146. processor_kwargs_from_kwargs = {k: v for k, v in kwargs.items() if k not in template_kwargs}
  147. if processor_kwargs_from_kwargs:
  148. logger.warning(
  149. "Kwargs passed to `processor.__call__` have to be in `processor_kwargs` dict, not in `**kwargs`"
  150. )
  151. processor_kwargs = processor_kwargs_from_kwargs
  152. if return_tensors:
  153. processor_kwargs["return_tensors"] = return_tensors
  154. output_kwargs = self._merge_kwargs(
  155. VoxtralProcessorKwargs,
  156. **processor_kwargs,
  157. )
  158. text_kwargs = output_kwargs["text_kwargs"]
  159. audio_kwargs = output_kwargs["audio_kwargs"]
  160. return_tensors = text_kwargs.get("return_tensors", None)
  161. if return_tensors != "pt":
  162. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  163. tokenizer_kwargs = output_kwargs["text_kwargs"]
  164. tokenizer_kwargs["return_tensors"] = None # let's not return tensors here
  165. encoded_instruct_inputs = self.tokenizer.apply_chat_template(conversations, **tokenizer_kwargs)
  166. if text_kwargs.get("tokenize", False):
  167. if text_kwargs.get("return_dict", False):
  168. audio = encoded_instruct_inputs.pop("audio", None)
  169. data = dict(encoded_instruct_inputs)
  170. if audio is not None:
  171. max_source_positions = audio_kwargs.pop("max_source_positions")
  172. data["input_features"] = self._retrieve_input_features(audio, max_source_positions, **audio_kwargs)
  173. return BatchFeature(data=data, tensor_type=return_tensors)
  174. if not is_batched:
  175. return encoded_instruct_inputs[0]
  176. return encoded_instruct_inputs
  177. @auto_docstring(
  178. custom_intro=r"""
  179. Method to prepare text to be fed as input to the model. This method forwards the `text`
  180. arguments to MistralCommonBackend's [`~MistralCommonBackend.__call__`] to encode
  181. the text. Please refer to the docstring of the above methods for more information.
  182. This method does not support audio. To prepare the audio, please use:
  183. 1. `apply_chat_template` [`~VoxtralProcessor.apply_chat_template`] method.
  184. 2. `apply_transcription_request` [`~VoxtralProcessor.apply_transcription_request`] method.
  185. """
  186. )
  187. def __call__(
  188. self,
  189. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None,
  190. **kwargs: Unpack[VoxtralProcessorKwargs],
  191. ):
  192. if isinstance(text, str):
  193. text = [text]
  194. if any(self.audio_token in t for t in text):
  195. raise ValueError(
  196. f"{self.audio_token} is present in the provided text which is not supported by VoxtralProcessor. Please use the `apply_chat_template` method instead."
  197. )
  198. output_kwargs = self._merge_kwargs(VoxtralProcessorKwargs, **kwargs)
  199. out = self.tokenizer(text, **output_kwargs["text_kwargs"])
  200. return BatchFeature(data=out, tensor_type=output_kwargs["text_kwargs"].get("return_tensors", None))
  201. # TODO: @eustlb, this should be moved to mistral_common + testing
  202. def apply_transcription_request(
  203. self,
  204. audio: str | list[str] | AudioInput,
  205. model_id: str,
  206. language: str | list[str | None] | None = None,
  207. sampling_rate: int | None = None,
  208. format: str | list[str] | None = None,
  209. **kwargs: Unpack[VoxtralProcessorKwargs],
  210. ):
  211. """
  212. This method applies the model's transcription request template given a language and audio.
  213. It relies on MistralCommonBackend and WhisperFeatureExtractor to prepare input ids and input features to the model.
  214. ```python
  215. from transformers import VoxtralProcessor
  216. model_id = "mistralai/Voxtral-Mini-3B-2507"
  217. processor = VoxtralProcessor.from_pretrained(model_id)
  218. language = "en"
  219. audio = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3"
  220. # set the language is already know for better accuracy
  221. inputs = processor.apply_transcription_request(language=language, audio=audio, model_id=model_id)
  222. # but you can also let the model detect the language automatically
  223. inputs = processor.apply_transcription_request(audio=audio, model_id=model_id)
  224. ```
  225. Args:
  226. audio (`str`, `list[str]`, `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
  227. The audio or batch of audio to be prepared. If provided as a string, it should correspond to the path or url of the audio file.
  228. model_id (`str`:
  229. The hub model id of the model to use for transcription.
  230. language (`str`, `list[Union[str, None]]`, *optional*):
  231. The language or languages of the audio.
  232. If not provided or None, automatic language detection will be used for all audio.
  233. If provided as a string (a language code in the [ISO 639-1 alpha-2 format](https://en.wikipedia.org/wiki/ISO_639-1) e.g. `"en"`), it will be applied uniformly to all audio.
  234. If provided as a list of strings/ None values, e.g. `["en", None, "fr"]`, will be applied to each audio individually with a one-to-one mapping,
  235. with a None value indicating automatic language detection for that audio.
  236. sampling_rate (`int`, *optional*):
  237. The sampling rate of the audio. Necessary if it is provided as `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`.
  238. Used to avoid silent errors when passing audio that is not in the expected sampling rate.
  239. format (`str`, `list[str]`, *optional*):
  240. The format of the audio, necessary if is provided as `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`.
  241. """
  242. output_kwargs = self._merge_kwargs(
  243. VoxtralProcessorKwargs,
  244. **kwargs,
  245. )
  246. text_kwargs = output_kwargs["text_kwargs"]
  247. audio_kwargs = output_kwargs["audio_kwargs"]
  248. is_str = isinstance(audio, str)
  249. is_list_of_str = all(isinstance(el, str) for el in audio)
  250. is_list_of_audio = not (is_str or is_list_of_str)
  251. if is_list_of_audio:
  252. if sampling_rate is None:
  253. logger.warning_once(
  254. f"You've provided audio without specifying the sampling rate. It will be assumed to be {audio_kwargs['sampling_rate']}, which can result in silent errors."
  255. )
  256. elif sampling_rate != audio_kwargs["sampling_rate"]:
  257. raise ValueError(
  258. f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({audio_kwargs['sampling_rate']}). Please provide resampled the audio to the expected sampling rate."
  259. )
  260. sampling_rate = audio_kwargs["sampling_rate"]
  261. # make sure to remove from text_kwargs and audio_kwargs
  262. return_dict = text_kwargs.pop("return_dict", False)
  263. tokenize = text_kwargs.pop("tokenize", False)
  264. _ = audio_kwargs.pop("return_dict", False)
  265. _ = audio_kwargs.pop("tokenize", False)
  266. return_tensors = text_kwargs.pop("return_tensors", None)
  267. if return_tensors != "pt":
  268. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  269. # validate audio input
  270. if is_str:
  271. audio = [load_audio_as(audio, return_format="buffer", force_mono=True, sampling_rate=sampling_rate)]
  272. elif is_list_of_str:
  273. audio = [
  274. load_audio_as(el, return_format="buffer", force_mono=True, sampling_rate=sampling_rate) for el in audio
  275. ]
  276. else:
  277. audio = make_list_of_audio(audio)
  278. if len(audio) != len(format):
  279. raise ValueError(
  280. f"When passed as a list of audio, the length ({len(audio)}) must match the number of format ({len(format)})"
  281. )
  282. audio_buffers = []
  283. for array, f in zip(audio, format):
  284. # Create new BytesIO object and write audio data to it
  285. buffer = io.BytesIO()
  286. # Convert to mono if needed
  287. if array.ndim == 2:
  288. array = array.mean(axis=1)
  289. # Write to buffer with default format and sampling rate
  290. sf.write(buffer, array, samplerate=audio_kwargs["sampling_rate"], format=f)
  291. buffer.seek(0)
  292. audio_buffers.append(buffer)
  293. audio = audio_buffers
  294. # validate language input
  295. n_audio = len(audio)
  296. if isinstance(language, str):
  297. language = [language] * n_audio
  298. elif language is None:
  299. language = [None] * n_audio
  300. if len(language) != n_audio:
  301. raise ValueError(
  302. f"When passed as a list of languages, the length ({len(language)}) must match the number of audio ({n_audio})"
  303. )
  304. input_ids = []
  305. texts = []
  306. audio_arrays = []
  307. for audio_el, language_el in zip(audio, language):
  308. openai_transcription_request = {
  309. "model": model_id,
  310. "file": audio_el,
  311. "language": language_el,
  312. }
  313. transcription_request = TranscriptionRequest.from_openai(openai_transcription_request)
  314. tokenized_transcription_request = self.tokenizer.tokenizer.encode_transcription(transcription_request)
  315. input_ids.append(tokenized_transcription_request.tokens)
  316. texts.append(tokenized_transcription_request.text)
  317. audio_arrays.extend([el.audio_array for el in tokenized_transcription_request.audios])
  318. if tokenize:
  319. if return_dict:
  320. # text are already tokenized but we need to pad etc
  321. encoding = self.tokenizer(
  322. input_ids,
  323. add_special_tokens=False,
  324. **text_kwargs,
  325. )
  326. data = dict(encoding)
  327. # extract the input features
  328. max_source_positions = audio_kwargs.pop("max_source_positions")
  329. data["input_features"] = self._retrieve_input_features(
  330. audio_arrays, max_source_positions, **audio_kwargs
  331. )
  332. return BatchFeature(data=data, tensor_type=return_tensors)
  333. return texts
  334. __all__ = ["VoxtralProcessor"]