processing_speecht5.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Speech processor class for SpeechT5."""
  15. from ...processing_utils import ProcessorMixin
  16. from ...utils import auto_docstring
  17. @auto_docstring
  18. class SpeechT5Processor(ProcessorMixin):
  19. def __init__(self, feature_extractor, tokenizer):
  20. super().__init__(feature_extractor, tokenizer)
  21. @auto_docstring
  22. def __call__(self, *args, **kwargs):
  23. audio = kwargs.pop("audio", None)
  24. text = kwargs.pop("text", None)
  25. text_target = kwargs.pop("text_target", None)
  26. audio_target = kwargs.pop("audio_target", None)
  27. sampling_rate = kwargs.pop("sampling_rate", None)
  28. if audio is not None and text is not None:
  29. raise ValueError(
  30. "Cannot process both `audio` and `text` inputs. Did you mean `audio_target` or `text_target`?"
  31. )
  32. if audio_target is not None and text_target is not None:
  33. raise ValueError(
  34. "Cannot process both `audio_target` and `text_target` inputs. Did you mean `audio` or `text`?"
  35. )
  36. if audio is None and audio_target is None and text is None and text_target is None:
  37. raise ValueError(
  38. "You need to specify either an `audio`, `audio_target`, `text`, or `text_target` input to process."
  39. )
  40. if audio is not None:
  41. inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
  42. elif text is not None:
  43. inputs = self.tokenizer(text, **kwargs)
  44. else:
  45. inputs = None
  46. if audio_target is not None:
  47. targets = self.feature_extractor(audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs)
  48. labels = targets["input_values"]
  49. elif text_target is not None:
  50. targets = self.tokenizer(text_target, **kwargs)
  51. labels = targets["input_ids"]
  52. else:
  53. targets = None
  54. if inputs is None:
  55. return targets
  56. if targets is not None:
  57. inputs["labels"] = labels
  58. decoder_attention_mask = targets.get("attention_mask")
  59. if decoder_attention_mask is not None:
  60. inputs["decoder_attention_mask"] = decoder_attention_mask
  61. return inputs
  62. def pad(self, *args, **kwargs):
  63. """
  64. Collates the audio and text inputs, as well as their targets, into a padded batch.
  65. Audio inputs are padded by SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`]. Text inputs are padded
  66. by SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`].
  67. Valid input combinations are:
  68. - `input_ids` only
  69. - `input_values` only
  70. - `labels` only, either log-mel spectrograms or text tokens
  71. - `input_ids` and log-mel spectrogram `labels`
  72. - `input_values` and text `labels`
  73. Please refer to the docstring of the above two methods for more information.
  74. """
  75. input_values = kwargs.pop("input_values", None)
  76. input_ids = kwargs.pop("input_ids", None)
  77. labels = kwargs.pop("labels", None)
  78. if input_values is not None and input_ids is not None:
  79. raise ValueError("Cannot process both `input_values` and `input_ids` inputs.")
  80. if input_values is None and input_ids is None and labels is None:
  81. raise ValueError(
  82. "You need to specify either an `input_values`, `input_ids`, or `labels` input to be padded."
  83. )
  84. if input_values is not None:
  85. inputs = self.feature_extractor.pad(input_values, *args, **kwargs)
  86. elif input_ids is not None:
  87. inputs = self.tokenizer.pad(input_ids, **kwargs)
  88. else:
  89. inputs = None
  90. if labels is not None:
  91. if "input_ids" in labels or (isinstance(labels, list) and "input_ids" in labels[0]):
  92. targets = self.tokenizer.pad(labels, **kwargs)
  93. labels = targets["input_ids"]
  94. else:
  95. feature_size_hack = self.feature_extractor.feature_size
  96. self.feature_extractor.feature_size = self.feature_extractor.num_mel_bins
  97. targets = self.feature_extractor.pad(labels, *args, **kwargs)
  98. self.feature_extractor.feature_size = feature_size_hack
  99. labels = targets["input_values"]
  100. else:
  101. targets = None
  102. if inputs is None:
  103. return targets
  104. if targets is not None:
  105. inputs["labels"] = labels
  106. decoder_attention_mask = targets.get("attention_mask")
  107. if decoder_attention_mask is not None:
  108. inputs["decoder_attention_mask"] = decoder_attention_mask
  109. return inputs
  110. __all__ = ["SpeechT5Processor"]