processing_musicgen.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright 2023 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Text/audio processor class for MusicGen
  16. """
  17. from typing import Any
  18. import numpy as np
  19. from ...processing_utils import ProcessorMixin
  20. from ...utils import auto_docstring, to_numpy
  21. @auto_docstring
  22. class MusicgenProcessor(ProcessorMixin):
  23. def __init__(self, feature_extractor, tokenizer):
  24. super().__init__(feature_extractor, tokenizer)
  25. def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
  26. return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
  27. @auto_docstring
  28. def __call__(self, *args, **kwargs):
  29. if len(args) > 0:
  30. kwargs["audio"] = args[0]
  31. return super().__call__(*args, **kwargs)
  32. def batch_decode(self, *args, **kwargs):
  33. """
  34. This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
  35. from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
  36. [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
  37. """
  38. audio_values = kwargs.pop("audio", None)
  39. padding_mask = kwargs.pop("padding_mask", None)
  40. if len(args) > 0:
  41. audio_values = args[0]
  42. args = args[1:]
  43. if audio_values is not None:
  44. return self._decode_audio(audio_values, padding_mask=padding_mask)
  45. else:
  46. return self.tokenizer.batch_decode(*args, **kwargs)
  47. def _decode_audio(self, audio_values, padding_mask: Any = None) -> list[np.ndarray]:
  48. """
  49. This method strips any padding from the audio values to return a list of numpy audio arrays.
  50. """
  51. audio_values = to_numpy(audio_values)
  52. bsz, channels, seq_len = audio_values.shape
  53. if padding_mask is None:
  54. return list(audio_values)
  55. padding_mask = to_numpy(padding_mask)
  56. # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
  57. # token (so that the generated audio values are **not** treated as padded tokens)
  58. difference = seq_len - padding_mask.shape[-1]
  59. padding_value = 1 - self.feature_extractor.padding_value
  60. padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
  61. audio_values = audio_values.tolist()
  62. for i in range(bsz):
  63. sliced_audio = np.asarray(audio_values[i])[
  64. padding_mask[i][None, :] != self.feature_extractor.padding_value
  65. ]
  66. audio_values[i] = sliced_audio.reshape(channels, -1)
  67. return audio_values
  68. __all__ = ["MusicgenProcessor"]