processing_dia.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Processor class for Dia"""
  15. import math
  16. from pathlib import Path
  17. from ...audio_utils import AudioInput, make_list_of_audio
  18. from ...feature_extraction_utils import BatchFeature
  19. from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
  20. from ...utils import auto_docstring, is_soundfile_available, is_torch_available
  21. if is_torch_available():
  22. import torch
  23. if is_soundfile_available():
  24. import soundfile as sf
  25. class DiaAudioKwargs(AudioKwargs, total=False):
  26. """
  27. bos_token_id (`int`, *optional*, defaults to `1026`):
  28. The token ID used as the beginning-of-sequence token for audio codebooks. This token is prepended to each
  29. audio sequence during encoding.
  30. eos_token_id (`int`, *optional*, defaults to `1024`):
  31. The token ID used as the end-of-sequence token for audio codebooks. This token is appended to audio sequences
  32. during training (when `generation=False`) to mark the end of the audio.
  33. pad_token_id (`int`, *optional*, defaults to `1025`):
  34. The token ID used for padding audio codebook sequences. This token is used to fill positions in the delay
  35. pattern where no valid audio token exists.
  36. delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
  37. A list of delay values (in frames) for each codebook channel. The delay pattern creates temporal offsets
  38. between different codebook channels, allowing the model to capture dependencies across channels. Each value
  39. represents the number of frames to delay that specific channel.
  40. generation (`bool`, *optional*, defaults to `True`):
  41. Whether the processor is being used for generation (text-to-speech) or training. When `True`, the processor
  42. prepares inputs for generation mode where audio is generated from text. When `False`, it prepares inputs for
  43. training where both text and audio are provided.
  44. """
  45. bos_token_id: int
  46. eos_token_id: int
  47. pad_token_id: int
  48. delay_pattern: list[int]
  49. generation: bool
  50. class DiaProcessorKwargs(ProcessingKwargs, total=False):
  51. audio_kwargs: DiaAudioKwargs
  52. _defaults = {
  53. "text_kwargs": {
  54. "padding": True,
  55. "padding_side": "right",
  56. "add_special_tokens": False,
  57. },
  58. "audio_kwargs": {
  59. "eos_token_id": 1024,
  60. "pad_token_id": 1025,
  61. "bos_token_id": 1026,
  62. "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
  63. "generation": True,
  64. "sampling_rate": 44100,
  65. },
  66. "common_kwargs": {
  67. "return_tensors": "pt",
  68. },
  69. }
  70. @auto_docstring
  71. class DiaProcessor(ProcessorMixin):
  72. audio_tokenizer_class = "DacModel"
  73. def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
  74. r"""
  75. audio_tokenizer (`DacModel`):
  76. An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
  77. """
  78. super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)
  79. @auto_docstring
  80. def __call__(
  81. self,
  82. text: str | list[str],
  83. audio: AudioInput | None = None,
  84. output_labels: bool | None = False,
  85. **kwargs: Unpack[DiaProcessorKwargs],
  86. ):
  87. r"""
  88. output_labels (`bool`, *optional*, defaults to `False`):
  89. Whether to return labels for training. When `True`, the processor generates labels from the decoder input
  90. sequence by shifting it by one position. Labels use special values: `-100` for tokens to ignore in loss
  91. computation (padding and BOS tokens), and `-101` for audio frames used only for the backbone model (when
  92. `depth_decoder_labels_ratio < 1.0`). Cannot be used together with `generation=True`.
  93. """
  94. if not is_torch_available():
  95. raise ValueError(
  96. "The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
  97. "find it in your environment. You can install torch via `pip install torch`."
  98. )
  99. if text is None:
  100. raise ValueError("You need to specify the `text` input to process.")
  101. output_kwargs = self._merge_kwargs(
  102. DiaProcessorKwargs,
  103. **kwargs,
  104. )
  105. text_kwargs = output_kwargs["text_kwargs"]
  106. audio_kwargs = output_kwargs["audio_kwargs"]
  107. return_tensors = text_kwargs.get("return_tensors", None)
  108. if return_tensors != "pt":
  109. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  110. data = {}
  111. # Text
  112. if isinstance(text, str):
  113. text = [text]
  114. elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
  115. raise ValueError("Invalid input text. Please provide a string, or a list of strings")
  116. encodings = self.tokenizer(text, **text_kwargs)
  117. data.update(encodings)
  118. # Audio
  119. delay_pattern = audio_kwargs.pop("delay_pattern", None)
  120. audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
  121. audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
  122. audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
  123. generation = audio_kwargs.pop("generation", True)
  124. if (
  125. audio_bos_token_id is None
  126. or audio_eos_token_id is None
  127. or audio_pad_token_id is None
  128. or delay_pattern is None
  129. ):
  130. raise ValueError(
  131. "To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
  132. "`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
  133. )
  134. if generation and output_labels:
  135. raise ValueError(
  136. f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
  137. )
  138. batch_size = data["input_ids"].shape[0]
  139. num_channels = len(delay_pattern)
  140. max_delay = max(delay_pattern)
  141. # Voice cloning generation / general training
  142. if audio is not None:
  143. audio = make_list_of_audio(audio)
  144. input_audios = self.feature_extractor(audio, **audio_kwargs)
  145. compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
  146. max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
  147. decoder_input_ids = []
  148. decoder_attention_mask = []
  149. # TODO: dac with batching is currently broken, but non-batch is working
  150. # refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
  151. for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
  152. # get current length with hop length in mind (as if it were sampled as a single audio)
  153. base_pad_len = self.feature_extractor.hop_length
  154. current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len
  155. encoded_sequence_len = current_audio_len // compression_rate
  156. padding_len = max_encoded_sequence_len - encoded_sequence_len
  157. # compute non-padded forward pass; one extra bos (and eos if training) is added
  158. with torch.no_grad():
  159. audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
  160. input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)
  161. if not generation:
  162. input_ids = torch.nn.functional.pad(
  163. input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
  164. )
  165. # apply padding
  166. # +1 for the bos within the real sequence
  167. input_ids = torch.nn.functional.pad(
  168. input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
  169. )
  170. num_valid_inputs = encoded_sequence_len + 1 + max_delay # sequence + bos + delay
  171. num_valid_inputs += 0 if generation else 1 # eos if training
  172. attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :]
  173. decoder_input_ids.append(input_ids)
  174. decoder_attention_mask.append(attention_mask)
  175. decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
  176. decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
  177. # TTS generation
  178. elif generation:
  179. # all bos to start with TTS
  180. decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)
  181. # we preemptively add the delay
  182. decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
  183. else:
  184. raise ValueError("If you try to train, you should provide audio data as well.")
  185. if batch_size != decoder_input_ids.shape[0]:
  186. raise ValueError(
  187. f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
  188. f"audio samples = {decoder_input_ids.shape[0]} instead."
  189. )
  190. # prepare shift indices per delay
  191. max_seq_len = decoder_attention_mask.shape[-1]
  192. max_audio_len = max_seq_len - max_delay
  193. precomputed_idx = self.build_indices(
  194. bsz=batch_size,
  195. seq_len=max_seq_len,
  196. num_channels=num_channels,
  197. delay_pattern=delay_pattern,
  198. revert=False,
  199. )
  200. # create delay pattern input
  201. # the pad token will be used for masking which input is valid for prediction during generation
  202. prefill = torch.full(
  203. (batch_size, max_seq_len, num_channels),
  204. fill_value=audio_pad_token_id,
  205. dtype=torch.int,
  206. )
  207. prefill[:, :max_audio_len] = decoder_input_ids
  208. delayed_decoder_input_ids = self.apply_audio_delay(
  209. audio=prefill,
  210. pad_token_id=audio_pad_token_id,
  211. bos_token_id=audio_bos_token_id,
  212. precomputed_idx=precomputed_idx,
  213. )
  214. data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})
  215. if output_labels:
  216. # Base idea is to shift on the sequence dim
  217. labels = data["decoder_input_ids"].clone()[:, 1:]
  218. labels[labels == audio_pad_token_id] = -100
  219. labels[labels == audio_bos_token_id] = -100
  220. data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
  221. data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
  222. data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]
  223. return BatchFeature(data=data, tensor_type=return_tensors)
  224. def batch_decode(
  225. self,
  226. decoder_input_ids: "torch.Tensor",
  227. audio_prompt_len: int | None = None,
  228. **kwargs: Unpack[DiaProcessorKwargs],
  229. ) -> list["torch.Tensor"]:
  230. """
  231. Decodes a batch of audio codebook sequences into their respective audio waveforms via the
  232. `audio_tokenizer`. See [`~DacModel.decode`] for more information.
  233. Args:
  234. decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
  235. audio_prompt_len (`int`): The audio prefix length (e.g. when using voice cloning).
  236. """
  237. output_kwargs = self._merge_kwargs(
  238. DiaProcessorKwargs,
  239. **kwargs,
  240. )
  241. audio_kwargs = output_kwargs["audio_kwargs"]
  242. delay_pattern = audio_kwargs.pop("delay_pattern", None)
  243. audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
  244. audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
  245. if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
  246. raise ValueError(
  247. "To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
  248. "and `delay_pattern`. You may have accidentally overwritten one of those."
  249. )
  250. # either decode the whole audio sequence or only the generated parts
  251. if audio_prompt_len is not None:
  252. audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
  253. start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
  254. else:
  255. start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
  256. # -1 for the eos token
  257. end_of_generation_idx = (
  258. decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
  259. )
  260. # revert delay
  261. bsz, seq_len, num_channels = decoder_input_ids.shape
  262. precomputed_idx = self.build_indices(
  263. bsz=bsz,
  264. seq_len=seq_len,
  265. num_channels=num_channels,
  266. delay_pattern=delay_pattern,
  267. revert=True,
  268. )
  269. output_sequences = self.apply_audio_delay(
  270. audio=decoder_input_ids,
  271. # We do not care about these values as we cut them out
  272. # with `start_of_generation_idx` and `end_of_generation_idx`
  273. pad_token_id=-1,
  274. bos_token_id=-1,
  275. precomputed_idx=precomputed_idx,
  276. ).transpose(1, 2)
  277. # retrieve the correct sequences each
  278. audios = []
  279. # TODO: see above, dac doesn't work in batches yet
  280. with torch.no_grad():
  281. for i in range(start_of_generation_idx.shape[0]):
  282. output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...]
  283. output_i = output_i.to(self.audio_tokenizer.device)
  284. audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze()
  285. audios.append(audio_i)
  286. return audios
  287. def decode(
  288. self,
  289. decoder_input_ids: "torch.Tensor",
  290. audio_prompt_len: int | None = None,
  291. **kwargs: Unpack[DiaProcessorKwargs],
  292. ) -> "torch.Tensor":
  293. """
  294. Decodes a single sequence of audio codebooks into the respective audio waveform via the
  295. `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
  296. """
  297. if decoder_input_ids.shape[0] != 1:
  298. raise ValueError(
  299. f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
  300. )
  301. return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]
  302. def get_audio_prompt_len(
  303. self,
  304. decoder_attention_mask: "torch.Tensor",
  305. **kwargs: Unpack[DiaProcessorKwargs],
  306. ) -> int:
  307. """Utility function to get the audio prompt length."""
  308. output_kwargs = self._merge_kwargs(
  309. DiaProcessorKwargs,
  310. **kwargs,
  311. )
  312. audio_kwargs = output_kwargs["audio_kwargs"]
  313. delay_pattern = audio_kwargs.pop("delay_pattern", None)
  314. if delay_pattern is None:
  315. raise ValueError(
  316. "To enable the utility of retrieving the prompt length for Dia, we need the "
  317. "`delay_pattern`. You may have accidentally overwritten this."
  318. )
  319. return decoder_attention_mask.shape[1] - max(delay_pattern)
  320. # Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
  321. def save_audio(
  322. self,
  323. audio: AudioInput,
  324. saving_path: str | Path | list[str | Path],
  325. **kwargs: Unpack[DiaProcessorKwargs],
  326. ):
  327. # TODO: @eustlb, this should be in AudioProcessor
  328. if not is_soundfile_available():
  329. raise ImportError("Please install `soundfile` to save audio files.")
  330. # ensure correct audio input
  331. audio = make_list_of_audio(audio)
  332. # ensure correct saving path
  333. if isinstance(saving_path, (str, Path)):
  334. saving_path = [saving_path]
  335. elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
  336. raise ValueError("Invalid input path. Please provide a string, or a list of strings")
  337. if len(audio) != len(saving_path):
  338. raise ValueError("The number of audio and saving paths must be the same")
  339. output_kwargs = self._merge_kwargs(
  340. DiaProcessorKwargs,
  341. **kwargs,
  342. )
  343. audio_kwargs = output_kwargs["audio_kwargs"]
  344. sampling_rate = audio_kwargs["sampling_rate"]
  345. for audio_value, p in zip(audio, saving_path):
  346. if isinstance(audio_value, torch.Tensor):
  347. audio_value = audio_value.cpu().float().numpy()
  348. sf.write(p, audio_value, sampling_rate)
  349. @staticmethod
  350. def build_indices(
  351. bsz: int,
  352. seq_len: int,
  353. num_channels: int,
  354. delay_pattern: list[int],
  355. revert: bool = False,
  356. ) -> tuple["torch.Tensor", "torch.Tensor"]:
  357. """
  358. Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel]
  359. or in[seq, channel] = out[seq + delay[channel], channel] if `revert`.
  360. Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
  361. """
  362. delay_array = torch.tensor(delay_pattern, dtype=torch.int32)
  363. # (0..seq_len-1)
  364. sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
  365. # + or - delay depending if we delay or revert the delay
  366. if not revert:
  367. sequence_idx = sequence_idx - delay_array[None, None, :]
  368. else:
  369. sequence_idx = sequence_idx + delay_array[None, None, :]
  370. # if delay goes over the range we clamp back to valid values
  371. valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)
  372. batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
  373. channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)
  374. all_idx = torch.stack(
  375. [batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
  376. dim=1,
  377. ).long()
  378. return sequence_idx, all_idx
  379. @staticmethod
  380. def apply_audio_delay(
  381. audio: "torch.Tensor",
  382. pad_token_id: int,
  383. bos_token_id: int,
  384. precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
  385. ) -> "torch.Tensor":
  386. """
  387. Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
  388. inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.
  389. Args:
  390. audio: audio tokens of shape [bsz, seq_len, num_channels]
  391. pad_token_id: the PAD token
  392. bos_token_id: the BOS token
  393. precomputed_idx: from `build_indices`
  394. Returns:
  395. final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
  396. """
  397. # Move everything to the same device
  398. device = audio.device
  399. sequence_idx, all_idx = precomputed_idx
  400. sequence_idx = sequence_idx.to(device)
  401. all_idx = all_idx.to(device)
  402. # Gather per precomputed indices
  403. batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
  404. gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
  405. # Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
  406. mask_bos = sequence_idx < 0
  407. mask_pad = sequence_idx >= audio.shape[1]
  408. final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
  409. return final_audio
  410. __all__ = ["DiaProcessor"]