processing_parakeet.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ...audio_utils import AudioInput, make_list_of_audio
  15. from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
  16. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  17. from ...utils import auto_docstring, logging
  # Module-level logger; the processor's __call__ uses it for the one-time "missing sampling rate" warning.
  18. logger = logging.get_logger(__name__)
  class ParakeetProcessorKwargs(ProcessingKwargs, total=False):
      # Default keyword arguments for each modality of `ParakeetProcessor.__call__`, which hands this
      # class to `_merge_kwargs`. NOTE(review): caller-supplied kwargs presumably override these
      # defaults inside `_merge_kwargs` — confirm against ProcessorMixin.
      _defaults = dict(
          # Forwarded to the feature extractor.
          audio_kwargs=dict(
              sampling_rate=16000,  # expected input sampling rate, in Hz
              padding="longest",
              return_attention_mask=True,
          ),
          # Forwarded to the tokenizer when `text` (the transcription) is given.
          text_kwargs=dict(
              padding=True,
              padding_side="right",
              add_special_tokens=False,  # label ids are produced without special tokens
          ),
          # Shared by every modality.
          common_kwargs=dict(return_tensors="pt"),
      )
  @auto_docstring
  class ParakeetProcessor(ProcessorMixin):
      # Pairs a feature extractor (raw audio -> model input features) with a tokenizer
      # (transcription text -> label ids) behind a single `__call__`.

      # NOTE(review): the explicit constructor only forwards to ProcessorMixin; it is kept because
      # the mixin presumably discovers its sub-components from this signature — confirm before removing.
      def __init__(self, feature_extractor, tokenizer):
          super().__init__(feature_extractor, tokenizer)

      @auto_docstring
      def __call__(
          self,
          audio: AudioInput,
          text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
          sampling_rate: int | None = None,
          **kwargs: Unpack[ParakeetProcessorKwargs],
      ):
          r"""
          sampling_rate (`int`, *optional*):
              The sampling rate of the input audio in Hz. This should match the sampling rate expected by the feature
              extractor (16000 Hz by default). If provided, it is validated against the processor's expected sampling
              rate and a `ValueError` is raised if they don't match. If not provided, a warning is issued and the
              expected sampling rate is assumed.
          """
          # Output: the feature extractor's result for `audio`; when `text` is given, the tokenized
          # `input_ids` are attached to it under the "labels" key.
          # Raises ValueError if `audio` is None or `sampling_rate` disagrees with the expected rate.

          # `audio` is mandatory. Reject None explicitly so the caller gets a clear error instead of a
          # failure further down (a None here could otherwise leave the features undefined).
          if audio is None:
              raise ValueError("`audio` is required and cannot be None.")
          audio = make_list_of_audio(audio)

          output_kwargs = self._merge_kwargs(
              ParakeetProcessorKwargs,
              tokenizer_init_kwargs=self.tokenizer.init_kwargs,
              **kwargs,
          )

          # The rate the feature extractor will be told to use (the merged audio default).
          expected_sampling_rate = output_kwargs["audio_kwargs"]["sampling_rate"]
          if sampling_rate is None:
              logger.warning_once(
                  f"You've provided audio without specifying the sampling rate. It will be assumed to be {expected_sampling_rate}, which can result in silent errors."
              )
          elif sampling_rate != expected_sampling_rate:
              raise ValueError(
                  f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({expected_sampling_rate}). Please resample the audio to the expected sampling rate."
              )

          inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])

          if text is not None:
              # Transcriptions become training targets: only their token ids are kept, as "labels".
              encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
              inputs["labels"] = encodings["input_ids"]

          return inputs

      @property
      def model_input_names(self):
          # Feature-extractor input names plus the "labels" key that __call__ adds when `text` is given.
          # `list(...)` keeps the concatenation working even if the extractor exposes a tuple.
          feature_extractor_input_names = list(self.feature_extractor.model_input_names)
          return feature_extractor_input_names + ["labels"]
  # Public names exported by this module. NOTE(review): keep this a literal list; packaging/lazy-import tooling may read it textually — confirm before changing its form.
  79. __all__ = ["ParakeetProcessor"]