processing_pop2piano.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright 2023 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Processor class for Pop2Piano."""
  15. import os
  16. import numpy as np
  17. from ...feature_extraction_utils import BatchFeature
  18. from ...processing_utils import ProcessorMixin
  19. from ...tokenization_python import BatchEncoding, PaddingStrategy, TruncationStrategy
  20. from ...utils import TensorType, auto_docstring
  21. from ...utils.import_utils import requires
  22. @requires(backends=("essentia", "librosa", "pretty_midi", "scipy", "torch"))
  23. @auto_docstring
  24. class Pop2PianoProcessor(ProcessorMixin):
  25. def __init__(self, feature_extractor, tokenizer):
  26. super().__init__(feature_extractor, tokenizer)
  27. @auto_docstring
  28. def __call__(
  29. self,
  30. audio: np.ndarray | list[float] | list[np.ndarray] = None,
  31. sampling_rate: int | list[int] | None = None,
  32. steps_per_beat: int = 2,
  33. resample: bool | None = True,
  34. notes: list | TensorType = None,
  35. padding: bool | str | PaddingStrategy = False,
  36. truncation: bool | str | TruncationStrategy = None,
  37. max_length: int | None = None,
  38. pad_to_multiple_of: int | None = None,
  39. verbose: bool = True,
  40. **kwargs,
  41. ) -> BatchFeature | BatchEncoding:
  42. # Since Feature Extractor needs both audio and sampling_rate and tokenizer needs both token_ids and
  43. # feature_extractor_output, we must check for both.
  44. r"""
  45. sampling_rate (`int` or `list[int]`, *optional*):
  46. The sampling rate of the input audio in Hz. This should match the sampling rate used by the feature
  47. extractor. If not provided, the default sampling rate from the processor configuration will be used.
  48. steps_per_beat (`int`, *optional*, defaults to `2`):
  49. The number of time steps per musical beat. This parameter controls the temporal resolution of the
  50. musical representation. A higher value provides finer temporal granularity but increases the sequence
  51. length. Used when processing audio to extract musical features.
  52. notes (`list` or `TensorType`, *optional*):
  53. Pre-extracted musical notes in MIDI format. When provided, the processor skips audio feature extraction
  54. and directly processes the notes through the tokenizer. Each note should be represented as a list or
  55. tensor containing pitch, velocity, and timing information.
  56. """
  57. if (audio is None and sampling_rate is None) and (notes is None):
  58. raise ValueError(
  59. "You have to specify at least audios and sampling_rate in order to use feature extractor or "
  60. "notes to use the tokenizer part."
  61. )
  62. if audio is not None and sampling_rate is not None:
  63. inputs = self.feature_extractor(
  64. audio=audio,
  65. sampling_rate=sampling_rate,
  66. steps_per_beat=steps_per_beat,
  67. resample=resample,
  68. **kwargs,
  69. )
  70. if notes is not None:
  71. encoded_token_ids = self.tokenizer(
  72. notes=notes,
  73. padding=padding,
  74. truncation=truncation,
  75. max_length=max_length,
  76. pad_to_multiple_of=pad_to_multiple_of,
  77. verbose=verbose,
  78. **kwargs,
  79. )
  80. if notes is None:
  81. return inputs
  82. elif audio is None or sampling_rate is None:
  83. return encoded_token_ids
  84. else:
  85. inputs["token_ids"] = encoded_token_ids["token_ids"]
  86. return inputs
  87. def batch_decode(
  88. self,
  89. token_ids,
  90. feature_extractor_output: BatchFeature,
  91. return_midi: bool = True,
  92. ) -> BatchEncoding:
  93. """
  94. This method uses [`Pop2PianoTokenizer.batch_decode`] method to convert model generated token_ids to midi_notes.
  95. Please refer to the docstring of the above two methods for more information.
  96. """
  97. return self.tokenizer.batch_decode(
  98. token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi
  99. )
  100. def save_pretrained(self, save_directory, **kwargs):
  101. if os.path.isfile(save_directory):
  102. raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
  103. os.makedirs(save_directory, exist_ok=True)
  104. return super().save_pretrained(save_directory, **kwargs)
  105. @classmethod
  106. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  107. args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
  108. return cls(*args)
# Public API of this module. Kept as a literal list of strings: export lists like this may be read
# textually by lazy-import tooling (NOTE(review): confirm before changing its form).
__all__ = ["Pop2PianoProcessor"]