| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- # Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from pathlib import Path
- from typing import Any
- import numpy as np
- from ...utils import auto_docstring, is_soundfile_available, is_torch_available
- if is_torch_available():
- import torch
- if is_soundfile_available():
- import soundfile as sf
- from ...audio_utils import AudioInput, make_list_of_audio
- from ...feature_extraction_utils import BatchFeature
- from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
- from ...tokenization_utils_base import PreTokenizedInput, TextInput
class CsmAudioKwargs(AudioKwargs, total=False):
    """
    encoded_length_kwargs (`dict[str, Any]`, *optional*):
        Keyword arguments describing the convolutional layers used in audio encoding: `kernel_sizes`, `strides`,
        `dilations` and `use_causal_conv`. They are forwarded to `CsmProcessor._get_encoded_length` to compute
        the encoded audio sequence length, which in turn sets how many audio tokens are written into the text
        sequence for each audio input.
    """

    encoded_length_kwargs: dict[str, Any] | None
class CsmProcessorKwargs(ProcessingKwargs, total=False):
    # Keyword arguments accepted by `CsmProcessor`, grouped per modality, together with their default values.
    audio_kwargs: CsmAudioKwargs

    _defaults = {
        # Defaults for the text (tokenizer) kwargs.
        "text_kwargs": {
            "padding": True,
            "padding_side": "left",
            "add_special_tokens": False,
        },
        # Defaults for the audio kwargs.
        "audio_kwargs": {
            # The three lists are consumed in lockstep (one entry per layer) by
            # `CsmProcessor._get_encoded_length`, so they must have the same length.
            "encoded_length_kwargs": {
                "kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
                "strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
                "dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                "use_causal_conv": True,
            },
            "sampling_rate": 24000,
        },
        # `CsmProcessor.__call__` rejects anything other than PyTorch tensors.
        "common_kwargs": {"return_tensors": "pt"},
    }
- @auto_docstring
- class CsmProcessor(ProcessorMixin):
def __init__(self, feature_extractor, tokenizer, chat_template=None):
    # Audio placeholder token: use the tokenizer's own definition when it exposes one,
    # otherwise fall back to the default token string and look up its id.
    if hasattr(tokenizer, "audio_token"):
        self.audio_token = tokenizer.audio_token
        self.audio_token_id = tokenizer.audio_token_id
    else:
        self.audio_token = "<|AUDIO|>"
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)

    # End-of-audio token: same resolution strategy as above.
    if hasattr(tokenizer, "audio_eos_token"):
        self.audio_eos_token = tokenizer.audio_eos_token
        self.audio_eos_token_id = tokenizer.audio_eos_token_id
    else:
        self.audio_eos_token = "<|audio_eos|>"
        self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)

    super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
- @staticmethod
- def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
- """
- Compute the length of the encoded audio sequence.
- Args:
- audio_length (int): The length of the audio sequence.
- kernel_sizes (list[int]): The kernel sizes for the convolutional layers.
- strides (list[int]): The strides for the convolutional layers.
- use_causal_conv (bool): Whether to use causal convolutions.
- """
- cur_length = audio_length
- if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
- return cur_length
- for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
- effective_kernel_size = (kernel_size - 1) * dilation + 1
- padding_total = kernel_size - stride
- padding_right = padding_total // 2
- padding_left = padding_total - padding_right
- n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
- n_frames = math.ceil(n_frames) - 1
- ideal_length = n_frames * stride + kernel_size - padding_total
- extra_padding = ideal_length - cur_length
- if use_causal_conv:
- padding_left = padding_total
- padding_right = extra_padding
- else:
- padding_right = padding_right + extra_padding
- cur_length = cur_length + padding_left + padding_right
- cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1
- return cur_length
def save_audio(
    self,
    audio: AudioInput,
    saving_path: str | Path | list[str | Path],
    **kwargs: Unpack[CsmProcessorKwargs],
):
    """
    Write one or several audio arrays to disk using `soundfile`.

    Args:
        audio: A single audio or a list of audios. Torch tensors are moved to CPU and cast to float32 numpy
            arrays before being written.
        saving_path: One destination path, or a list/tuple with one destination path per audio.
        **kwargs: Processor keyword arguments; only the audio `sampling_rate` is read here.

    Raises:
        ImportError: If `soundfile` is not installed.
        ValueError: If `saving_path` has an invalid type, or if the number of paths differs from the number of
            audios.
    """
    # TODO: @eustlb, this should be in AudioProcessor
    if not is_soundfile_available():
        raise ImportError("Please install `soundfile` to save audio files.")

    # normalise the audio input to a list
    audio = make_list_of_audio(audio)

    # normalise the destination(s) to a list and validate their types
    if isinstance(saving_path, (str, Path)):
        saving_path = [saving_path]
    else:
        is_path_sequence = isinstance(saving_path, (list, tuple)) and all(
            isinstance(p, (str, Path)) for p in saving_path
        )
        if not is_path_sequence:
            raise ValueError("Invalid input path. Please provide a string, or a list of strings")

    if len(audio) != len(saving_path):
        raise ValueError("The number of audio and saving paths must be the same")

    merged_kwargs = self._merge_kwargs(
        CsmProcessorKwargs,
        **kwargs,
    )
    sampling_rate = merged_kwargs["audio_kwargs"]["sampling_rate"]

    for waveform, destination in zip(audio, saving_path):
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.cpu().float().numpy()
        sf.write(destination, waveform, sampling_rate)
@auto_docstring
def __call__(
    self,
    text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None,
    audio: AudioInput | None = None,
    output_labels: bool | None = False,
    depth_decoder_labels_ratio: float | None = 1.0,
    **kwargs: Unpack[CsmProcessorKwargs],
):
    r"""
    output_labels (bool, *optional*, default=False):
        Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`.
        - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
        - `-100` will be ignored in the loss computation
        - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
    depth_decoder_labels_ratio (float, *optional*, default=1.0):
        The ratio of audio frames to keep for the depth decoder labels.

    Returns:
        [`BatchFeature`]: A [`BatchFeature`] with the following fields:

        - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
        - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
          `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
          `None`).
        - **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
    """
    output_kwargs = self._merge_kwargs(
        CsmProcessorKwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )
    text_kwargs = output_kwargs["text_kwargs"]
    audio_kwargs = output_kwargs["audio_kwargs"]

    return_tensors = text_kwargs.get("return_tensors", None)
    if return_tensors != "pt":
        raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")

    if isinstance(text, str):
        text = [text]
    elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
        raise ValueError("Invalid input text. Please provide a string, or a list of strings")

    # number of audio placeholders per text sample
    n_audio_in_text = [t.count(self.audio_token) for t in text]
    total_audio_in_text = sum(n_audio_in_text)

    n_audio = 0
    if audio is not None:
        audio = make_list_of_audio(audio)
        n_audio = len(audio)

    # NOTE(review): audio passed alongside text that contains no audio token is not rejected here — confirm
    # whether that should be an error.
    if total_audio_in_text > 0 and n_audio != total_audio_in_text:
        if audio is None:
            raise ValueError("No audio were provided, but there are audio tokens in the prompt")
        else:
            raise ValueError(
                f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the "
                f"number of provided audios ({n_audio})."
            )

    if audio is not None:
        # `encoded_length_kwargs` may legitimately be `None`; treat that like "not provided"
        encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", None) or {}
        num_audio_tokens_list = [
            self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
        ]
        remaining_num_audio_tokens = iter(num_audio_tokens_list)

        # Expand the text to repeat the audio token for the corresponding number of frames. Splitting on the
        # audio token (instead of substituting a temporary marker string) keeps the expansion correct even when
        # the prompt itself contains arbitrary marker-like text.
        expanded_text = []
        for sample in text:
            segments = sample.split(self.audio_token)
            pieces = [segments[0]]
            for segment in segments[1:]:
                pieces.append(self.audio_token * next(remaining_num_audio_tokens))
                pieces.append(segment)
            expanded_text.append("".join(pieces))

        text = expanded_text

    encoding = self.tokenizer(text, **text_kwargs)
    data = {}
    data.update(encoding)

    if audio is not None:
        audio_kwargs.pop("return_attention_mask", None)  # not supported by the feature extractor

        # One concatenated waveform per text sample, plus the cumulative end index of each audio inside it
        concatenated_audio, input_values_cutoffs = [], []
        offset = 0
        for n_sample_audio in n_audio_in_text:
            if n_sample_audio == 0:
                # no audio in this sample: empty waveform and a single padding cutoff
                concatenated_audio.append(np.zeros(0))
                input_values_cutoffs.append(torch.tensor([-1]))
            else:
                sample_audio = audio[offset : offset + n_sample_audio]
                concatenated_audio.append(
                    np.concatenate(
                        [el.cpu().numpy() if isinstance(el, torch.Tensor) else el for el in sample_audio],
                        axis=-1,
                    )
                )
                input_values_cutoffs.append(torch.tensor([el.shape[-1] for el in sample_audio]).cumsum(dim=-1))
                offset += n_sample_audio

        audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
        audio_inputs.pop("padding_mask", None)  # not applicable here
        data.update(audio_inputs)

        # pad (with -1) and stack the audio cut idxs
        max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
        input_values_cutoffs = [
            torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
            for cut_idxs in input_values_cutoffs
        ]
        data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)

    if output_labels:
        audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
        n_audio_frames = audio_frame_idxs.shape[0]

        if depth_decoder_labels_ratio <= 1.0:
            # randomly pick the fraction of frames that is NOT kept for the depth decoder
            rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
            skip_frames_idxs = audio_frame_idxs[rand_idxs]
        else:
            # NOTE(review): a ratio above 1.0 marks every audio frame as -101 (backbone only), which looks at
            # odds with "ratio of audio frames to keep" — confirm the intended behaviour.
            skip_frames_idxs = audio_frame_idxs

        # keep audio / audio-eos token ids as labels, ignore everything else
        labels = torch.where(
            (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
            data["input_ids"],
            -100,
        )
        labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101

        data["labels"] = labels

    return BatchFeature(data=data, tensor_type=return_tensors)
- @property
- def model_input_names(self):
- tokenizer_input_names = self.tokenizer.model_input_names
- feature_extractor_input_names = self.feature_extractor.model_input_names
- # Remove `padding_mask`, it is popped and not used when processing. Make a copy of list when removing
- # otherwise `self.feature_extractor.model_input_names` is also modified
- feature_extractor_input_names = [name for name in feature_extractor_input_names if name != "padding_mask"]
- return list(tokenizer_input_names + feature_extractor_input_names + ["input_values_cutoffs"])
- __all__ = ["CsmProcessor"]
|