tokenization_pop2piano.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
  1. # Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tokenization class for Pop2Piano."""
  15. import json
  16. import os
  17. import numpy as np
  18. from ...feature_extraction_utils import BatchFeature
  19. from ...tokenization_python import AddedToken, BatchEncoding, PaddingStrategy, PreTrainedTokenizer, TruncationStrategy
  20. from ...utils import TensorType, is_pretty_midi_available, logging, requires_backends, to_numpy
  21. from ...utils.import_utils import requires
  22. if is_pretty_midi_available():
  23. import pretty_midi
  24. logger = logging.get_logger(__name__)
  25. VOCAB_FILES_NAMES = {
  26. "vocab": "vocab.json",
  27. }
  28. def token_time_to_note(number, cutoff_time_idx, current_idx):
  29. current_idx += number
  30. if cutoff_time_idx is not None:
  31. current_idx = min(current_idx, cutoff_time_idx)
  32. return current_idx
  33. def token_note_to_note(number, current_velocity, default_velocity, note_onsets_ready, current_idx, notes):
  34. if note_onsets_ready[number] is not None:
  35. # offset with onset
  36. onset_idx = note_onsets_ready[number]
  37. if onset_idx < current_idx:
  38. # Time shift after previous note_on
  39. offset_idx = current_idx
  40. notes.append([onset_idx, offset_idx, number, default_velocity])
  41. onsets_ready = None if current_velocity == 0 else current_idx
  42. note_onsets_ready[number] = onsets_ready
  43. else:
  44. note_onsets_ready[number] = current_idx
  45. return notes
  46. @requires(backends=("pretty_midi", "torch"))
  47. class Pop2PianoTokenizer(PreTrainedTokenizer):
  48. """
  49. Constructs a Pop2Piano tokenizer. This tokenizer does not require training.
  50. This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
  51. this superclass for more information regarding those methods.
  52. Args:
  53. vocab (`str`):
  54. Path to the vocab file which contains the vocabulary.
  55. default_velocity (`int`, *optional*, defaults to 77):
  56. Determines the default velocity to be used while creating midi Notes.
  57. num_bars (`int`, *optional*, defaults to 2):
  58. Determines cutoff_time_idx in for each token.
  59. unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"-1"`):
  60. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  61. token instead.
  62. eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 1):
  63. The end of sequence token.
  64. pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 0):
  65. A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
  66. attention mechanisms or loss computation.
  67. bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 2):
  68. The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
  69. """
  70. model_input_names = ["token_ids", "attention_mask"]
  71. vocab_files_names = VOCAB_FILES_NAMES
  72. def __init__(
  73. self,
  74. vocab,
  75. default_velocity=77,
  76. num_bars=2,
  77. unk_token="-1",
  78. eos_token="1",
  79. pad_token="0",
  80. bos_token="2",
  81. **kwargs,
  82. ):
  83. unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
  84. eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
  85. pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
  86. bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
  87. self.default_velocity = default_velocity
  88. self.num_bars = num_bars
  89. # Load the vocab
  90. with open(vocab, "rb") as file:
  91. self.encoder = json.load(file)
  92. # create mappings for encoder
  93. self.decoder = {v: k for k, v in self.encoder.items()}
  94. super().__init__(
  95. unk_token=unk_token,
  96. eos_token=eos_token,
  97. pad_token=pad_token,
  98. bos_token=bos_token,
  99. **kwargs,
  100. )
  101. @property
  102. def vocab_size(self):
  103. """Returns the vocabulary size of the tokenizer."""
  104. return len(self.encoder)
  105. def get_vocab(self):
  106. """Returns the vocabulary of the tokenizer."""
  107. return dict(self.encoder, **self.added_tokens_encoder)
  108. def _convert_id_to_token(self, token_id: int) -> list:
  109. """
  110. Decodes the token ids generated by the transformer into notes.
  111. Args:
  112. token_id (`int`):
  113. This denotes the ids generated by the transformers to be converted to Midi tokens.
  114. Returns:
  115. `List`: A list consists of token_type (`str`) and value (`int`).
  116. """
  117. token_type_value = self.decoder.get(token_id, f"{self.unk_token}_TOKEN_TIME")
  118. token_type_value = token_type_value.split("_")
  119. token_type, value = "_".join(token_type_value[1:]), int(token_type_value[0])
  120. return [token_type, value]
  121. def _convert_token_to_id(self, token, token_type="TOKEN_TIME") -> int:
  122. """
  123. Encodes the Midi tokens to transformer generated token ids.
  124. Args:
  125. token (`int`):
  126. This denotes the token value.
  127. token_type (`str`):
  128. This denotes the type of the token. There are four types of midi tokens such as "TOKEN_TIME",
  129. "TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL".
  130. Returns:
  131. `int`: returns the id of the token.
  132. """
  133. return self.encoder.get(f"{token}_{token_type}", int(self.unk_token))
  134. def relative_batch_tokens_ids_to_notes(
  135. self,
  136. tokens: np.ndarray,
  137. beat_offset_idx: int,
  138. bars_per_batch: int,
  139. cutoff_time_idx: int,
  140. ):
  141. """
  142. Converts relative tokens to notes which are then used to generate pretty midi object.
  143. Args:
  144. tokens (`numpy.ndarray`):
  145. Tokens to be converted to notes.
  146. beat_offset_idx (`int`):
  147. Denotes beat offset index for each note in generated Midi.
  148. bars_per_batch (`int`):
  149. A parameter to control the Midi output generation.
  150. cutoff_time_idx (`int`):
  151. Denotes the cutoff time index for each note in generated Midi.
  152. """
  153. notes = None
  154. for index in range(len(tokens)):
  155. _tokens = tokens[index]
  156. _start_idx = beat_offset_idx + index * bars_per_batch * 4
  157. _cutoff_time_idx = cutoff_time_idx + _start_idx
  158. _notes = self.relative_tokens_ids_to_notes(
  159. _tokens,
  160. start_idx=_start_idx,
  161. cutoff_time_idx=_cutoff_time_idx,
  162. )
  163. if len(_notes) == 0:
  164. pass
  165. elif notes is None:
  166. notes = _notes
  167. else:
  168. notes = np.concatenate((notes, _notes), axis=0)
  169. if notes is None:
  170. return []
  171. return notes
  172. def relative_batch_tokens_ids_to_midi(
  173. self,
  174. tokens: np.ndarray,
  175. beatstep: np.ndarray,
  176. beat_offset_idx: int = 0,
  177. bars_per_batch: int = 2,
  178. cutoff_time_idx: int = 12,
  179. ):
  180. """
  181. Converts tokens to Midi. This method calls `relative_batch_tokens_ids_to_notes` method to convert batch tokens
  182. to notes then uses `notes_to_midi` method to convert them to Midi.
  183. Args:
  184. tokens (`numpy.ndarray`):
  185. Denotes tokens which alongside beatstep will be converted to Midi.
  186. beatstep (`np.ndarray`):
  187. We get beatstep from feature extractor which is also used to get Midi.
  188. beat_offset_idx (`int`, *optional*, defaults to 0):
  189. Denotes beat offset index for each note in generated Midi.
  190. bars_per_batch (`int`, *optional*, defaults to 2):
  191. A parameter to control the Midi output generation.
  192. cutoff_time_idx (`int`, *optional*, defaults to 12):
  193. Denotes the cutoff time index for each note in generated Midi.
  194. """
  195. beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx
  196. notes = self.relative_batch_tokens_ids_to_notes(
  197. tokens=tokens,
  198. beat_offset_idx=beat_offset_idx,
  199. bars_per_batch=bars_per_batch,
  200. cutoff_time_idx=cutoff_time_idx,
  201. )
  202. midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx])
  203. return midi
  204. # Taken from the original code
  205. # Please see https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L257
  206. def relative_tokens_ids_to_notes(self, tokens: np.ndarray, start_idx: float, cutoff_time_idx: float | None = None):
  207. """
  208. Converts relative tokens to notes which will then be used to create Pretty Midi objects.
  209. Args:
  210. tokens (`numpy.ndarray`):
  211. Relative Tokens which will be converted to notes.
  212. start_idx (`float`):
  213. A parameter which denotes the starting index.
  214. cutoff_time_idx (`float`, *optional*):
  215. A parameter used while converting tokens to notes.
  216. """
  217. words = [self._convert_id_to_token(token) for token in tokens]
  218. current_idx = start_idx
  219. current_velocity = 0
  220. note_onsets_ready = [None for i in range(sum(k.endswith("NOTE") for k in self.encoder) + 1)]
  221. notes = []
  222. for token_type, number in words:
  223. if token_type == "TOKEN_SPECIAL":
  224. if number == 1:
  225. break
  226. elif token_type == "TOKEN_TIME":
  227. current_idx = token_time_to_note(
  228. number=number, cutoff_time_idx=cutoff_time_idx, current_idx=current_idx
  229. )
  230. elif token_type == "TOKEN_VELOCITY":
  231. current_velocity = number
  232. elif token_type == "TOKEN_NOTE":
  233. notes = token_note_to_note(
  234. number=number,
  235. current_velocity=current_velocity,
  236. default_velocity=self.default_velocity,
  237. note_onsets_ready=note_onsets_ready,
  238. current_idx=current_idx,
  239. notes=notes,
  240. )
  241. else:
  242. raise ValueError("Token type not understood!")
  243. for pitch, note_onset in enumerate(note_onsets_ready):
  244. # force offset if no offset for each pitch
  245. if note_onset is not None:
  246. if cutoff_time_idx is None:
  247. cutoff = note_onset + 1
  248. else:
  249. cutoff = max(cutoff_time_idx, note_onset + 1)
  250. offset_idx = max(current_idx, cutoff)
  251. notes.append([note_onset, offset_idx, pitch, self.default_velocity])
  252. if len(notes) == 0:
  253. return []
  254. else:
  255. notes = np.array(notes)
  256. note_order = notes[:, 0] * 128 + notes[:, 1]
  257. notes = notes[note_order.argsort()]
  258. return notes
  259. def notes_to_midi(self, notes: np.ndarray, beatstep: np.ndarray, offset_sec: int = 0.0):
  260. """
  261. Converts notes to Midi.
  262. Args:
  263. notes (`numpy.ndarray`):
  264. This is used to create Pretty Midi objects.
  265. beatstep (`numpy.ndarray`):
  266. This is the extrapolated beatstep that we get from feature extractor.
  267. offset_sec (`int`, *optional*, defaults to 0.0):
  268. This represents the offset seconds which is used while creating each Pretty Midi Note.
  269. """
  270. requires_backends(self, ["pretty_midi"])
  271. new_pm = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=120.0)
  272. new_inst = pretty_midi.Instrument(program=0)
  273. new_notes = []
  274. for onset_idx, offset_idx, pitch, velocity in notes:
  275. new_note = pretty_midi.Note(
  276. velocity=velocity,
  277. pitch=pitch,
  278. start=beatstep[onset_idx] - offset_sec,
  279. end=beatstep[offset_idx] - offset_sec,
  280. )
  281. new_notes.append(new_note)
  282. new_inst.notes = new_notes
  283. new_pm.instruments.append(new_inst)
  284. new_pm.remove_invalid_notes()
  285. return new_pm
  286. def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
  287. """
  288. Saves the tokenizer's vocabulary dictionary to the provided save_directory.
  289. Args:
  290. save_directory (`str`):
  291. A path to the directory where to saved. It will be created if it doesn't exist.
  292. filename_prefix (`Optional[str]`, *optional*):
  293. A prefix to add to the names of the files saved by the tokenizer.
  294. """
  295. if not os.path.isdir(save_directory):
  296. logger.error(f"Vocabulary path ({save_directory}) should be a directory")
  297. return
  298. # Save the encoder.
  299. out_vocab_file = os.path.join(
  300. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
  301. )
  302. with open(out_vocab_file, "w") as file:
  303. file.write(json.dumps(self.encoder))
  304. return (out_vocab_file,)
  305. def encode_plus(
  306. self,
  307. notes: np.ndarray | list[pretty_midi.Note],
  308. truncation_strategy: TruncationStrategy | None = None,
  309. max_length: int | None = None,
  310. **kwargs,
  311. ) -> BatchEncoding:
  312. r"""
  313. This is the `encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
  314. generated token ids. It only works on a single batch, to process multiple batches please use
  315. `batch_encode_plus` or `__call__` method.
  316. Args:
  317. notes (`numpy.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
  318. This represents the midi notes. If `notes` is a `numpy.ndarray`:
  319. - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
  320. If `notes` is a `list` containing `pretty_midi.Note` objects:
  321. - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
  322. truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
  323. Indicates the truncation strategy that is going to be used during truncation.
  324. max_length (`int`, *optional*):
  325. Maximum length of the returned list and optionally padding length (see above).
  326. Returns:
  327. `BatchEncoding` containing the tokens ids.
  328. """
  329. requires_backends(self, ["pretty_midi"])
  330. # check if notes is a pretty_midi object or not, if yes then extract the attributes and put them into a numpy
  331. # array.
  332. if isinstance(notes[0], pretty_midi.Note):
  333. notes = np.array(
  334. [[each_note.start, each_note.end, each_note.pitch, each_note.velocity] for each_note in notes]
  335. ).reshape(-1, 4)
  336. # to round up all the values to the closest int values.
  337. notes = np.round(notes).astype(np.int32)
  338. max_time_idx = notes[:, :2].max()
  339. times = [[] for i in range(max_time_idx + 1)]
  340. for onset, offset, pitch, velocity in notes:
  341. times[onset].append([pitch, velocity])
  342. times[offset].append([pitch, 0])
  343. tokens = []
  344. current_velocity = 0
  345. for i, time in enumerate(times):
  346. if len(time) == 0:
  347. continue
  348. tokens.append(self._convert_token_to_id(i, "TOKEN_TIME"))
  349. for pitch, velocity in time:
  350. velocity = int(velocity > 0)
  351. if current_velocity != velocity:
  352. current_velocity = velocity
  353. tokens.append(self._convert_token_to_id(velocity, "TOKEN_VELOCITY"))
  354. tokens.append(self._convert_token_to_id(pitch, "TOKEN_NOTE"))
  355. total_len = len(tokens)
  356. # truncation
  357. if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
  358. tokens, _, _ = self.truncate_sequences(
  359. ids=tokens,
  360. num_tokens_to_remove=total_len - max_length,
  361. truncation_strategy=truncation_strategy,
  362. **kwargs,
  363. )
  364. return BatchEncoding({"token_ids": tokens})
  365. def batch_encode_plus(
  366. self,
  367. notes: np.ndarray | list[pretty_midi.Note],
  368. truncation_strategy: TruncationStrategy | None = None,
  369. max_length: int | None = None,
  370. **kwargs,
  371. ) -> BatchEncoding:
  372. r"""
  373. This is the `batch_encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
  374. generated token ids. It works on multiple batches by calling `encode_plus` multiple times in a loop.
  375. Args:
  376. notes (`numpy.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
  377. This represents the midi notes. If `notes` is a `numpy.ndarray`:
  378. - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
  379. If `notes` is a `list` containing `pretty_midi.Note` objects:
  380. - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
  381. truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
  382. Indicates the truncation strategy that is going to be used during truncation.
  383. max_length (`int`, *optional*):
  384. Maximum length of the returned list and optionally padding length (see above).
  385. Returns:
  386. `BatchEncoding` containing the tokens ids.
  387. """
  388. encoded_batch_token_ids = []
  389. for i in range(len(notes)):
  390. encoded_batch_token_ids.append(
  391. self.encode_plus(
  392. notes[i],
  393. truncation_strategy=truncation_strategy,
  394. max_length=max_length,
  395. **kwargs,
  396. )["token_ids"]
  397. )
  398. return BatchEncoding({"token_ids": encoded_batch_token_ids})
  399. def __call__(
  400. self,
  401. notes: np.ndarray | list[pretty_midi.Note] | list[list[pretty_midi.Note]],
  402. padding: bool | str | PaddingStrategy = False,
  403. truncation: bool | str | TruncationStrategy = None,
  404. max_length: int | None = None,
  405. pad_to_multiple_of: int | None = None,
  406. return_attention_mask: bool | None = None,
  407. return_tensors: str | TensorType | None = None,
  408. verbose: bool = True,
  409. **kwargs,
  410. ) -> BatchEncoding:
  411. r"""
  412. This is the `__call__` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer generated
  413. token ids.
  414. Args:
  415. notes (`numpy.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
  416. This represents the midi notes.
  417. If `notes` is a `numpy.ndarray`:
  418. - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
  419. If `notes` is a `list` containing `pretty_midi.Note` objects:
  420. - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
  421. padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
  422. Activates and controls padding. Accepts the following values:
  423. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  424. sequence if provided).
  425. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  426. acceptable input length for the model if that argument is not provided.
  427. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  428. lengths).
  429. truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
  430. Activates and controls truncation. Accepts the following values:
  431. - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
  432. to the maximum acceptable input length for the model if that argument is not provided. This will
  433. truncate token by token, removing a token from the longest sequence in the pair if a pair of
  434. sequences (or a batch of pairs) is provided.
  435. - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
  436. maximum acceptable input length for the model if that argument is not provided. This will only
  437. truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  438. - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
  439. maximum acceptable input length for the model if that argument is not provided. This will only
  440. truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  441. - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
  442. greater than the model maximum admissible input size).
  443. max_length (`int`, *optional*):
  444. Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
  445. `None`, this will use the predefined model maximum length if a maximum length is required by one of the
  446. truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
  447. truncation/padding to a maximum length will be deactivated.
  448. pad_to_multiple_of (`int`, *optional*):
  449. If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
  450. the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
  451. return_attention_mask (`bool`, *optional*):
  452. Whether to return the attention mask. If left to the default, will return the attention mask according
  453. to the specific tokenizer's default, defined by the `return_outputs` attribute.
  454. [What are attention masks?](../glossary#attention-mask)
  455. return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
  456. If set, will return tensors instead of list of python integers. Acceptable values are:
  457. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  458. - `'np'`: Return Numpy `np.ndarray` objects.
  459. verbose (`bool`, *optional*, defaults to `True`):
  460. Whether or not to print more information and warnings.
  461. Returns:
  462. `BatchEncoding` containing the token_ids.
  463. """
  464. # check if it is batched or not
  465. # it is batched if its a list containing a list of `pretty_midi.Notes` where the outer list contains all the
  466. # batches and the inner list contains all Notes for a single batch. Otherwise if np.ndarray is passed it will be
  467. # considered batched if it has shape of `[batch_size, sequence_length, 4]` or ndim=3.
  468. is_batched = notes.ndim == 3 if isinstance(notes, np.ndarray) else isinstance(notes[0], list)
  469. # get the truncation and padding strategy
  470. padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  471. padding=padding,
  472. truncation=truncation,
  473. max_length=max_length,
  474. pad_to_multiple_of=pad_to_multiple_of,
  475. verbose=verbose,
  476. **kwargs,
  477. )
  478. if is_batched:
  479. # If the user has not explicitly mentioned `return_attention_mask` as False, we change it to True
  480. return_attention_mask = True if return_attention_mask is None else return_attention_mask
  481. token_ids = self.batch_encode_plus(
  482. notes=notes,
  483. truncation_strategy=truncation_strategy,
  484. max_length=max_length,
  485. **kwargs,
  486. )
  487. else:
  488. token_ids = self.encode_plus(
  489. notes=notes,
  490. truncation_strategy=truncation_strategy,
  491. max_length=max_length,
  492. **kwargs,
  493. )
  494. # since we already have truncated sequnences we are just left to do padding
  495. token_ids = self.pad(
  496. token_ids,
  497. padding=padding_strategy,
  498. max_length=max_length,
  499. pad_to_multiple_of=pad_to_multiple_of,
  500. return_attention_mask=return_attention_mask,
  501. return_tensors=return_tensors,
  502. verbose=verbose,
  503. )
  504. return token_ids
  505. def batch_decode(
  506. self,
  507. token_ids,
  508. feature_extractor_output: BatchFeature,
  509. return_midi: bool = True,
  510. ):
  511. r"""
  512. This is the `batch_decode` method for `Pop2PianoTokenizer`. It converts the token_ids generated by the
  513. transformer to midi_notes and returns them.
  514. Args:
  515. token_ids (`Union[np.ndarray, torch.Tensor]`):
  516. Output token_ids of `Pop2PianoConditionalGeneration` model.
  517. feature_extractor_output (`BatchFeature`):
  518. Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatstep"` and
  519. `"extrapolated_beatstep"`. Also `"attention_mask_beatsteps"` and
  520. `"attention_mask_extrapolated_beatstep"`
  521. should be present if they were returned by the feature extractor.
  522. return_midi (`bool`, *optional*, defaults to `True`):
  523. Whether to return midi object or not.
  524. Returns:
  525. If `return_midi` is True:
  526. - `BatchEncoding` containing both `notes` and `pretty_midi.pretty_midi.PrettyMIDI` objects.
  527. If `return_midi` is False:
  528. - `BatchEncoding` containing `notes`.
  529. """
  530. # check if they have attention_masks(attention_mask, attention_mask_beatsteps, attention_mask_extrapolated_beatstep) or not
  531. attention_masks_present = bool(
  532. hasattr(feature_extractor_output, "attention_mask")
  533. and hasattr(feature_extractor_output, "attention_mask_beatsteps")
  534. and hasattr(feature_extractor_output, "attention_mask_extrapolated_beatstep")
  535. )
  536. # if we are processing batched inputs then we must need attention_masks
  537. if not attention_masks_present and feature_extractor_output["beatsteps"].shape[0] > 1:
  538. raise ValueError(
  539. "attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present "
  540. "for batched inputs! But one of them were not present."
  541. )
  542. # check for length mismatch between inputs_embeds, beatsteps and extrapolated_beatstep
  543. if attention_masks_present:
  544. # since we know about the number of examples in token_ids from attention_mask
  545. if (
  546. sum(feature_extractor_output["attention_mask"][:, 0] == 0)
  547. != feature_extractor_output["beatsteps"].shape[0]
  548. or feature_extractor_output["beatsteps"].shape[0]
  549. != feature_extractor_output["extrapolated_beatstep"].shape[0]
  550. ):
  551. raise ValueError(
  552. "Length mistamtch between token_ids, beatsteps and extrapolated_beatstep! Found "
  553. f"token_ids length - {token_ids.shape[0]}, beatsteps shape - {feature_extractor_output['beatsteps'].shape[0]} "
  554. f"and extrapolated_beatsteps shape - {feature_extractor_output['extrapolated_beatstep'].shape[0]}"
  555. )
  556. if feature_extractor_output["attention_mask"].shape[0] != token_ids.shape[0]:
  557. raise ValueError(
  558. f"Found attention_mask of length - {feature_extractor_output['attention_mask'].shape[0]} but token_ids of length - {token_ids.shape[0]}"
  559. )
  560. else:
  561. # if there is no attention mask present then it's surely a single example
  562. if (
  563. feature_extractor_output["beatsteps"].shape[0] != 1
  564. or feature_extractor_output["extrapolated_beatstep"].shape[0] != 1
  565. ):
  566. raise ValueError(
  567. "Length mistamtch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, "
  568. f"But found beatsteps length - {feature_extractor_output['beatsteps'].shape[0]}, extrapolated_beatsteps length - {feature_extractor_output['extrapolated_beatstep'].shape[0]}."
  569. )
  570. if attention_masks_present:
  571. # check for zeros(since token_ids are separated by zero arrays)
  572. batch_idx = np.where(feature_extractor_output["attention_mask"][:, 0] == 0)[0]
  573. else:
  574. batch_idx = [token_ids.shape[0]]
  575. notes_list = []
  576. pretty_midi_objects_list = []
  577. start_idx = 0
  578. for index, end_idx in enumerate(batch_idx):
  579. each_tokens_ids = token_ids[start_idx:end_idx]
  580. # check where the whole example ended by searching for eos_token_id and getting the upper bound
  581. each_tokens_ids = each_tokens_ids[:, : np.max(np.where(each_tokens_ids == int(self.eos_token))[1]) + 1]
  582. beatsteps = feature_extractor_output["beatsteps"][index]
  583. extrapolated_beatstep = feature_extractor_output["extrapolated_beatstep"][index]
  584. # if attention mask is present then mask out real array/tensor
  585. if attention_masks_present:
  586. attention_mask_beatsteps = feature_extractor_output["attention_mask_beatsteps"][index]
  587. attention_mask_extrapolated_beatstep = feature_extractor_output[
  588. "attention_mask_extrapolated_beatstep"
  589. ][index]
  590. beatsteps = beatsteps[: np.max(np.where(attention_mask_beatsteps == 1)[0]) + 1]
  591. extrapolated_beatstep = extrapolated_beatstep[
  592. : np.max(np.where(attention_mask_extrapolated_beatstep == 1)[0]) + 1
  593. ]
  594. each_tokens_ids = to_numpy(each_tokens_ids)
  595. beatsteps = to_numpy(beatsteps)
  596. extrapolated_beatstep = to_numpy(extrapolated_beatstep)
  597. pretty_midi_object = self.relative_batch_tokens_ids_to_midi(
  598. tokens=each_tokens_ids,
  599. beatstep=extrapolated_beatstep,
  600. bars_per_batch=self.num_bars,
  601. cutoff_time_idx=(self.num_bars + 1) * 4,
  602. )
  603. for note in pretty_midi_object.instruments[0].notes:
  604. note.start += beatsteps[0]
  605. note.end += beatsteps[0]
  606. notes_list.append(note)
  607. pretty_midi_objects_list.append(pretty_midi_object)
  608. start_idx += end_idx + 1 # 1 represents the zero array
  609. if return_midi:
  610. return BatchEncoding({"notes": notes_list, "pretty_midi_objects": pretty_midi_objects_list})
  611. return BatchEncoding({"notes": notes_list})
# Public API of this module (names exported by `from ... import *`).
__all__ = ["Pop2PianoTokenizer"]