| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482 |
- # Copyright 2025 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Processor class for Dia"""
- import math
- from pathlib import Path
- from ...audio_utils import AudioInput, make_list_of_audio
- from ...feature_extraction_utils import BatchFeature
- from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
- from ...utils import auto_docstring, is_soundfile_available, is_torch_available
- if is_torch_available():
- import torch
- if is_soundfile_available():
- import soundfile as sf
class DiaAudioKwargs(AudioKwargs, total=False):
    """
    bos_token_id (`int`, *optional*, defaults to `1026`):
        The token ID used as the beginning-of-sequence token for audio codebooks. This token is prepended to each
        audio sequence during encoding. It is also used to left-pad shorter audio prompts within a batch and to fill
        the leading positions that the delay pattern shifts in front of each delayed channel.
    eos_token_id (`int`, *optional*, defaults to `1024`):
        The token ID used as the end-of-sequence token for audio codebooks. This token is appended to audio sequences
        during training (when `generation=False`) to mark the end of the audio.
    pad_token_id (`int`, *optional*, defaults to `1025`):
        The token ID used for padding audio codebook sequences. This token is used to fill positions in the delay
        pattern where no valid audio token exists (the trailing positions left over after the delay shift).
    delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
        A list of delay values (in frames) for each codebook channel. The delay pattern creates temporal offsets
        between different codebook channels, allowing the model to capture dependencies across channels. Each value
        represents the number of frames to delay that specific channel.
    generation (`bool`, *optional*, defaults to `True`):
        Whether the processor is being used for generation (text-to-speech) or training. When `True`, the processor
        prepares inputs for generation mode where audio is generated from text. When `False`, it prepares inputs for
        training where both text and audio are provided.
    """

    # Audio-codebook special tokens (defaults live in `DiaProcessorKwargs._defaults`)
    bos_token_id: int
    eos_token_id: int
    pad_token_id: int
    # One delay (in frames) per codebook channel; its length defines the number of channels
    delay_pattern: list[int]
    # Generation (TTS / voice cloning) vs. training input preparation
    generation: bool
class DiaProcessorKwargs(ProcessingKwargs, total=False):
    # Typed audio kwargs carrying the Dia-specific options (special tokens, delay pattern, generation flag)
    audio_kwargs: DiaAudioKwargs
    # Per-modality default values; presumably merged with user kwargs by `_merge_kwargs`, which `DiaProcessor`
    # calls with this class — NOTE(review): merge precedence is defined in `ProcessorMixin`, confirm there.
    _defaults = {
        "text_kwargs": {
            # Text is padded on the right and tokenized without automatically added special tokens
            "padding": True,
            "padding_side": "right",
            "add_special_tokens": False,
        },
        "audio_kwargs": {
            # Audio-codebook special token ids
            "eos_token_id": 1024,
            "pad_token_id": 1025,
            "bos_token_id": 1026,
            # Per-channel delays; the first channel is undelayed
            "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
            "generation": True,
            # Also used by `DiaProcessor.save_audio` when writing waveforms to disk
            "sampling_rate": 44100,
        },
        "common_kwargs": {
            # `DiaProcessor.__call__` only supports PyTorch tensors
            "return_tensors": "pt",
        },
    }
@auto_docstring
class DiaProcessor(ProcessorMixin):
    audio_tokenizer_class = "DacModel"

    def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
        r"""
        audio_tokenizer (`DacModel`):
            An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
        """
        super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)

    @auto_docstring
    def __call__(
        self,
        text: str | list[str],
        audio: AudioInput | None = None,
        output_labels: bool | None = False,
        **kwargs: Unpack[DiaProcessorKwargs],
    ):
        r"""
        output_labels (`bool`, *optional*, defaults to `False`):
            Whether to return labels for training. When `True`, the processor generates labels from the delayed
            decoder input sequence by shifting it by one position, and drops the last position of
            `decoder_input_ids` and `decoder_attention_mask` so that inputs and labels stay aligned. Audio padding
            and BOS tokens are replaced by `-100` so they are ignored in the loss computation. Labels are returned
            with shape `(batch_size * num_codebooks, sequence_length)`. Cannot be used together with
            `generation=True`.
        """
        if not is_torch_available():
            raise ValueError(
                "The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
                "find it in your environment. You can install torch via `pip install torch`."
            )
        if text is None:
            raise ValueError("You need to specify the `text` input to process.")

        output_kwargs = self._merge_kwargs(
            DiaProcessorKwargs,
            **kwargs,
        )
        text_kwargs = output_kwargs["text_kwargs"]
        audio_kwargs = output_kwargs["audio_kwargs"]

        return_tensors = text_kwargs.get("return_tensors", None)
        if return_tensors != "pt":
            raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")

        data = {}

        # Text
        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        encodings = self.tokenizer(text, **text_kwargs)
        data.update(encodings)

        # Audio
        # The Dia-specific options are popped so they are not forwarded to the feature extractor below
        delay_pattern = audio_kwargs.pop("delay_pattern", None)
        audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
        audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
        audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
        generation = audio_kwargs.pop("generation", True)
        if (
            audio_bos_token_id is None
            or audio_eos_token_id is None
            or audio_pad_token_id is None
            or delay_pattern is None
        ):
            raise ValueError(
                "To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
                "`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
            )

        if generation and output_labels:
            raise ValueError(
                f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
            )

        batch_size = data["input_ids"].shape[0]
        num_channels = len(delay_pattern)
        max_delay = max(delay_pattern)

        # Voice cloning generation / general training
        if audio is not None:
            audio = make_list_of_audio(audio)
            input_audios = self.feature_extractor(audio, **audio_kwargs)

            # Number of waveform samples that map onto a single codebook frame
            compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
            # Every decoder sequence is left-padded up to this (batch-wide) encoded length
            max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
            hop_length = self.feature_extractor.hop_length

            decoder_input_ids = []
            decoder_attention_mask = []
            # TODO: dac with batching is currently broken, but non-batch is working
            # refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
            with torch.no_grad():
                for padding_mask, waveform in zip(input_audios["padding_mask"], input_audios["input_values"]):
                    # get current length with hop length in mind (as if it were sampled as a single audio)
                    current_audio_len = math.ceil(padding_mask.sum(dim=-1) / hop_length) * hop_length
                    encoded_sequence_len = current_audio_len // compression_rate
                    padding_len = max_encoded_sequence_len - encoded_sequence_len

                    # compute non-padded forward pass; one extra bos (and eos if training) is added
                    waveform = waveform[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
                    # (1, num_channels, frames) -> (1, frames, num_channels)
                    input_ids = self.audio_tokenizer.encode(waveform).audio_codes.transpose(1, 2)
                    if not generation:
                        # append a single eos frame at the end of the sequence dim
                        input_ids = torch.nn.functional.pad(
                            input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
                        )

                    # left-pad with bos up to the batch maximum
                    # +1 for the bos within the real sequence
                    input_ids = torch.nn.functional.pad(
                        input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
                    )

                    num_valid_inputs = encoded_sequence_len + 1 + max_delay  # sequence + bos + delay
                    num_valid_inputs += 0 if generation else 1  # eos if training
                    attention_mask = torch.tensor(
                        [0] * padding_len + [1] * num_valid_inputs, dtype=torch.long
                    )[None, :]

                    decoder_input_ids.append(input_ids)
                    decoder_attention_mask.append(attention_mask)

            decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
            decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
        # TTS generation
        elif generation:
            # all bos to start with TTS
            decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)

            # we preemptively add the delay
            decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
        else:
            raise ValueError("If you try to train, you should provide audio data as well.")

        if batch_size != decoder_input_ids.shape[0]:
            raise ValueError(
                f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
                f"audio samples = {decoder_input_ids.shape[0]} instead."
            )

        # prepare shift indices per delay
        # the attention mask already accounts for the extra `max_delay` positions
        max_seq_len = decoder_attention_mask.shape[-1]
        max_audio_len = max_seq_len - max_delay
        precomputed_idx = self.build_indices(
            bsz=batch_size,
            seq_len=max_seq_len,
            num_channels=num_channels,
            delay_pattern=delay_pattern,
            revert=False,
        )

        # create delay pattern input
        # the pad token will be used for masking which input is valid for prediction during generation
        prefill = torch.full(
            (batch_size, max_seq_len, num_channels),
            fill_value=audio_pad_token_id,
            dtype=torch.int,
        )
        prefill[:, :max_audio_len] = decoder_input_ids

        delayed_decoder_input_ids = self.apply_audio_delay(
            audio=prefill,
            pad_token_id=audio_pad_token_id,
            bos_token_id=audio_bos_token_id,
            precomputed_idx=precomputed_idx,
        )

        data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})

        if output_labels:
            # Base idea is to shift on the sequence dim
            labels = data["decoder_input_ids"].clone()[:, 1:]
            labels[labels == audio_pad_token_id] = -100
            labels[labels == audio_bos_token_id] = -100

            # (batch, seq, channels) -> (batch * channels, seq)
            data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
            # drop the last position so that inputs and (shifted) labels have the same length
            data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
            data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]

        return BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(
        self,
        decoder_input_ids: "torch.Tensor",
        audio_prompt_len: int | None = None,
        **kwargs: Unpack[DiaProcessorKwargs],
    ) -> list["torch.Tensor"]:
        """
        Decodes a batch of audio codebook sequences into their respective audio waveforms via the
        `audio_tokenizer`. See [`~DacModel.decode`] for more information.

        Args:
            decoder_input_ids (`torch.Tensor`):
                The complete (delayed) output sequence of the decoder, of shape `(batch_size, seq_len, num_channels)`.
            audio_prompt_len (`int`, *optional*):
                The audio prefix length (e.g. when using voice cloning). When given, decoding starts at this position
                for every sample; otherwise it starts after the leading BOS tokens of the first channel.

        Returns:
            `list[torch.Tensor]`: One squeezed waveform per sample, on CPU.

        Raises:
            ValueError: If `bos_token_id`, `pad_token_id` or `delay_pattern` were overwritten with `None`.
        """
        output_kwargs = self._merge_kwargs(
            DiaProcessorKwargs,
            **kwargs,
        )
        audio_kwargs = output_kwargs["audio_kwargs"]

        delay_pattern = audio_kwargs.pop("delay_pattern", None)
        audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
        audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
        if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
            raise ValueError(
                "To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
                "and `delay_pattern`. You may have accidentally overwritten one of those."
            )

        # either decode the whole audio sequence or only the generated parts
        if audio_prompt_len is not None:
            audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
            start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
        else:
            # the first channel is undelayed, so its BOS count equals the left padding + initial BOS
            start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
        # -1 for the eos token
        end_of_generation_idx = (
            decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
        )

        # revert delay
        bsz, seq_len, num_channels = decoder_input_ids.shape
        precomputed_idx = self.build_indices(
            bsz=bsz,
            seq_len=seq_len,
            num_channels=num_channels,
            delay_pattern=delay_pattern,
            revert=True,
        )

        # (batch, seq, channels) -> (batch, channels, seq)
        output_sequences = self.apply_audio_delay(
            audio=decoder_input_ids,
            # We do not care about these values as we cut them out
            # with `start_of_generation_idx` and `end_of_generation_idx`
            pad_token_id=-1,
            bos_token_id=-1,
            precomputed_idx=precomputed_idx,
        ).transpose(1, 2)

        # retrieve the correct sequences each
        audios = []
        # TODO: see above, dac doesn't work in batches yet
        with torch.no_grad():
            for codes, start, end in zip(output_sequences, start_of_generation_idx, end_of_generation_idx):
                codes = codes[None, :, start:end].to(self.audio_tokenizer.device)
                waveform = self.audio_tokenizer.decode(audio_codes=codes).audio_values.cpu().squeeze()
                audios.append(waveform)

        return audios

    def decode(
        self,
        decoder_input_ids: "torch.Tensor",
        audio_prompt_len: int | None = None,
        **kwargs: Unpack[DiaProcessorKwargs],
    ) -> "torch.Tensor":
        """
        Decodes a single sequence of audio codebooks into the respective audio waveform via the
        `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.

        Args:
            decoder_input_ids (`torch.Tensor`): The decoder output for exactly one sample (batch dimension of 1).
            audio_prompt_len (`int`, *optional*): The audio prefix length, forwarded to `batch_decode`.

        Raises:
            ValueError: If `decoder_input_ids` does not contain exactly one sample.
        """
        if decoder_input_ids.shape[0] != 1:
            raise ValueError(
                f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
            )

        return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]

    def get_audio_prompt_len(
        self,
        decoder_attention_mask: "torch.Tensor",
        **kwargs: Unpack[DiaProcessorKwargs],
    ) -> int:
        """
        Utility function to get the audio prompt length.

        The decoder attention mask produced by `__call__` includes `max(delay_pattern)` extra positions for the
        delay, so the prompt length is the mask's sequence length minus that maximum delay.

        Raises:
            ValueError: If `delay_pattern` was overwritten with `None`.
        """
        output_kwargs = self._merge_kwargs(
            DiaProcessorKwargs,
            **kwargs,
        )
        audio_kwargs = output_kwargs["audio_kwargs"]

        delay_pattern = audio_kwargs.pop("delay_pattern", None)
        if delay_pattern is None:
            raise ValueError(
                "To enable the utility of retrieving the prompt length for Dia, we need the "
                "`delay_pattern`. You may have accidentally overwritten this."
            )
        return decoder_attention_mask.shape[1] - max(delay_pattern)

    # Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
    def save_audio(
        self,
        audio: AudioInput,
        saving_path: str | Path | list[str | Path],
        **kwargs: Unpack[DiaProcessorKwargs],
    ):
        # TODO: @eustlb, this should be in AudioProcessor
        if not is_soundfile_available():
            raise ImportError("Please install `soundfile` to save audio files.")

        # ensure correct audio input
        audio = make_list_of_audio(audio)

        # ensure correct saving path
        if isinstance(saving_path, (str, Path)):
            saving_path = [saving_path]
        elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
            raise ValueError("Invalid input path. Please provide a string, or a list of strings")

        if len(audio) != len(saving_path):
            raise ValueError("The number of audio and saving paths must be the same")

        output_kwargs = self._merge_kwargs(
            DiaProcessorKwargs,
            **kwargs,
        )
        audio_kwargs = output_kwargs["audio_kwargs"]
        sampling_rate = audio_kwargs["sampling_rate"]

        for audio_value, p in zip(audio, saving_path):
            if isinstance(audio_value, torch.Tensor):
                audio_value = audio_value.cpu().float().numpy()
            sf.write(p, audio_value, sampling_rate)

    @staticmethod
    def build_indices(
        bsz: int,
        seq_len: int,
        num_channels: int,
        delay_pattern: list[int],
        revert: bool = False,
    ) -> tuple["torch.Tensor", "torch.Tensor"]:
        """
        Precompute `(sequence_idx, all_idx)` for shifting every channel by its delay.

        With `revert=False`: out[seq, channel] = in[seq - delay[channel], channel] (applies the delay).
        With `revert=True`:  out[seq, channel] = in[seq + delay[channel], channel] (undoes the delay).
        Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD (handled in `apply_audio_delay`).

        Returns:
            sequence_idx: unclamped source positions, shape `(bsz, seq_len, num_channels)`.
            all_idx: clamped `(batch, sequence, channel)` gather indices, shape `(bsz * seq_len * num_channels, 3)`.
        """
        delay_array = torch.tensor(delay_pattern, dtype=torch.int32)

        # (0..seq_len-1), shape (bsz, seq_len, 1)
        sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
        # + or - delay depending if we delay or revert the delay; broadcasts to (bsz, seq_len, num_channels)
        if not revert:
            sequence_idx = sequence_idx - delay_array[None, None, :]
        else:
            sequence_idx = sequence_idx + delay_array[None, None, :]
        # if delay goes over the range we clamp back to valid values
        valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)

        batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
        channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)

        all_idx = torch.stack(
            [batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
            dim=1,
        ).long()

        return sequence_idx, all_idx

    @staticmethod
    def apply_audio_delay(
        audio: "torch.Tensor",
        pad_token_id: int,
        bos_token_id: int,
        precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
    ) -> "torch.Tensor":
        """
        Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
        inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.

        Args:
            audio: audio tokens of shape [bsz, seq_len, num_channels]
            pad_token_id: the PAD token
            bos_token_id: the BOS token
            precomputed_idx: from `build_indices`

        Returns:
            final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
        """
        # Move everything to the same device
        device = audio.device
        sequence_idx, all_idx = precomputed_idx
        sequence_idx = sequence_idx.to(device)
        all_idx = all_idx.to(device)

        # Gather per precomputed indices
        batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
        gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())

        # Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
        mask_bos = sequence_idx < 0
        mask_pad = sequence_idx >= audio.shape[1]
        final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))

        return final_audio
# Public names exported by this module
__all__ = ["DiaProcessor"]
|