processing_csm.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from pathlib import Path
  16. from typing import Any
  17. import numpy as np
  18. from ...utils import auto_docstring, is_soundfile_available, is_torch_available
  19. if is_torch_available():
  20. import torch
  21. if is_soundfile_available():
  22. import soundfile as sf
  23. from ...audio_utils import AudioInput, make_list_of_audio
  24. from ...feature_extraction_utils import BatchFeature
  25. from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
  26. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  27. class CsmAudioKwargs(AudioKwargs, total=False):
  28. """
  29. encoded_length_kwargs (`dict[str, Any]`, *optional*):
  30. Dictionary of keyword arguments used to compute the encoded audio sequence length. This includes parameters
  31. such as `kernel_sizes`, `strides`, `dilations`, and `use_causal_conv` that define the convolutional layers
  32. used in audio encoding. The encoded length is used to determine how many audio tokens to generate for each
  33. audio input in the text sequence.
  34. """
  35. encoded_length_kwargs: dict[str, Any] | None
  36. class CsmProcessorKwargs(ProcessingKwargs, total=False):
  37. audio_kwargs: CsmAudioKwargs
  38. _defaults = {
  39. "text_kwargs": {
  40. "padding": True,
  41. "padding_side": "left",
  42. "add_special_tokens": False,
  43. },
  44. "audio_kwargs": {
  45. "encoded_length_kwargs": {
  46. "kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
  47. "strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
  48. "dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  49. "use_causal_conv": True,
  50. },
  51. "sampling_rate": 24000,
  52. },
  53. "common_kwargs": {"return_tensors": "pt"},
  54. }
  55. @auto_docstring
  56. class CsmProcessor(ProcessorMixin):
  57. def __init__(
  58. self,
  59. feature_extractor,
  60. tokenizer,
  61. chat_template=None,
  62. ):
  63. if not hasattr(tokenizer, "audio_token"):
  64. self.audio_token = "<|AUDIO|>"
  65. self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
  66. else:
  67. self.audio_token = tokenizer.audio_token
  68. self.audio_token_id = tokenizer.audio_token_id
  69. if not hasattr(tokenizer, "audio_eos_token"):
  70. self.audio_eos_token = "<|audio_eos|>"
  71. self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)
  72. else:
  73. self.audio_eos_token = tokenizer.audio_eos_token
  74. self.audio_eos_token_id = tokenizer.audio_eos_token_id
  75. super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
  76. @staticmethod
  77. def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
  78. """
  79. Compute the length of the encoded audio sequence.
  80. Args:
  81. audio_length (int): The length of the audio sequence.
  82. kernel_sizes (list[int]): The kernel sizes for the convolutional layers.
  83. strides (list[int]): The strides for the convolutional layers.
  84. use_causal_conv (bool): Whether to use causal convolutions.
  85. """
  86. cur_length = audio_length
  87. if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
  88. return cur_length
  89. for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
  90. effective_kernel_size = (kernel_size - 1) * dilation + 1
  91. padding_total = kernel_size - stride
  92. padding_right = padding_total // 2
  93. padding_left = padding_total - padding_right
  94. n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
  95. n_frames = math.ceil(n_frames) - 1
  96. ideal_length = n_frames * stride + kernel_size - padding_total
  97. extra_padding = ideal_length - cur_length
  98. if use_causal_conv:
  99. padding_left = padding_total
  100. padding_right = extra_padding
  101. else:
  102. padding_right = padding_right + extra_padding
  103. cur_length = cur_length + padding_left + padding_right
  104. cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1
  105. return cur_length
  106. def save_audio(
  107. self,
  108. audio: AudioInput,
  109. saving_path: str | Path | list[str | Path],
  110. **kwargs: Unpack[CsmProcessorKwargs],
  111. ):
  112. # TODO: @eustlb, this should be in AudioProcessor
  113. if not is_soundfile_available():
  114. raise ImportError("Please install `soundfile` to save audio files.")
  115. # ensure correct audio input
  116. audio = make_list_of_audio(audio)
  117. # ensure correct saving path
  118. if isinstance(saving_path, (str, Path)):
  119. saving_path = [saving_path]
  120. elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
  121. raise ValueError("Invalid input path. Please provide a string, or a list of strings")
  122. if len(audio) != len(saving_path):
  123. raise ValueError("The number of audio and saving paths must be the same")
  124. output_kwargs = self._merge_kwargs(
  125. CsmProcessorKwargs,
  126. **kwargs,
  127. )
  128. audio_kwargs = output_kwargs["audio_kwargs"]
  129. sampling_rate = audio_kwargs["sampling_rate"]
  130. for audio_value, p in zip(audio, saving_path):
  131. if isinstance(audio_value, torch.Tensor):
  132. audio_value = audio_value.cpu().float().numpy()
  133. sf.write(p, audio_value, sampling_rate)
  134. @auto_docstring
  135. def __call__(
  136. self,
  137. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None,
  138. audio: AudioInput | None = None,
  139. output_labels: bool | None = False,
  140. depth_decoder_labels_ratio: float | None = 1.0,
  141. **kwargs: Unpack[CsmProcessorKwargs],
  142. ):
  143. r"""
  144. output_labels (bool, *optional*, default=False):
  145. Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`.
  146. - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
  147. - `-100` will be ignored in the loss computation
  148. - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
  149. depth_decoder_labels_ratio (float, *optional*, default=1.0):
  150. The ratio of audio frames to keep for the depth decoder labels.
  151. Returns:
  152. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  153. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  154. - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
  155. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  156. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  157. `None`).
  158. - **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
  159. """
  160. output_kwargs = self._merge_kwargs(
  161. CsmProcessorKwargs,
  162. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  163. **kwargs,
  164. )
  165. text_kwargs = output_kwargs["text_kwargs"]
  166. audio_kwargs = output_kwargs["audio_kwargs"]
  167. return_tensors = text_kwargs.get("return_tensors", None)
  168. if return_tensors != "pt":
  169. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  170. if isinstance(text, str):
  171. text = [text]
  172. elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
  173. raise ValueError("Invalid input text. Please provide a string, or a list of strings")
  174. n_audio_in_text = [t.count(self.audio_token) for t in text]
  175. n_audio = 0
  176. if audio is not None:
  177. audio = make_list_of_audio(audio)
  178. n_audio = len(audio)
  179. if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text):
  180. if audio is None:
  181. raise ValueError("No audio were provided, but there are audio tokens in the prompt")
  182. else:
  183. raise ValueError(
  184. f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the "
  185. f"number of provided audios ({n_audio})."
  186. )
  187. if audio is not None:
  188. encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {})
  189. num_audio_tokens_list = [
  190. self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
  191. ]
  192. num_audio_tokens_list_copy = num_audio_tokens_list.copy()
  193. # expand the text to repeat the audio token for the corresponding number of frames
  194. expanded_text = []
  195. for sample in text:
  196. replace_str = []
  197. while self.audio_token in sample:
  198. num_audio_tokens = num_audio_tokens_list_copy.pop(0)
  199. expanded_audio_token = self.audio_token * num_audio_tokens
  200. replace_str.append(expanded_audio_token)
  201. sample = sample.replace(self.audio_token, "<placeholder>", 1)
  202. while "<placeholder>" in sample:
  203. sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
  204. expanded_text.append(sample)
  205. text = expanded_text
  206. encoding = self.tokenizer(text, **text_kwargs)
  207. data = {}
  208. data.update(encoding)
  209. if audio is not None:
  210. audio_kwargs.pop("return_attention_mask", None) # not supported by the feature extractor
  211. concatenated_audio, input_values_cutoffs = [], []
  212. offset = 0
  213. for n_audio in n_audio_in_text:
  214. if n_audio == 0:
  215. concatenated_audio.append(np.zeros(0))
  216. input_values_cutoffs.append(torch.tensor([-1]))
  217. else:
  218. concatenated_audio.append(
  219. np.concatenate(
  220. [
  221. el.cpu().numpy() if isinstance(el, torch.Tensor) else el
  222. for el in audio[offset : offset + n_audio]
  223. ],
  224. axis=-1,
  225. )
  226. )
  227. input_values_cutoffs.append(
  228. torch.tensor([el.shape[-1] for el in audio[offset : offset + n_audio]]).cumsum(dim=-1)
  229. )
  230. offset += n_audio
  231. audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
  232. audio_inputs.pop("padding_mask", None) # not applicable here
  233. data.update(audio_inputs)
  234. # pad and stack the audio cut idxs
  235. max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
  236. input_values_cutoffs = [
  237. torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
  238. for cut_idxs in input_values_cutoffs
  239. ]
  240. data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)
  241. if output_labels:
  242. audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
  243. n_audio_frames = audio_frame_idxs.shape[0]
  244. if depth_decoder_labels_ratio <= 1.0:
  245. rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
  246. skip_frames_idxs = audio_frame_idxs[rand_idxs]
  247. else:
  248. skip_frames_idxs = audio_frame_idxs
  249. labels = torch.where(
  250. (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
  251. data["input_ids"],
  252. -100,
  253. )
  254. labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
  255. data["labels"] = labels
  256. return BatchFeature(data=data, tensor_type=return_tensors)
  257. @property
  258. def model_input_names(self):
  259. tokenizer_input_names = self.tokenizer.model_input_names
  260. feature_extractor_input_names = self.feature_extractor.model_input_names
  261. # Remove `padding_mask`, it is popped and not used when processing. Make a copy of list when removing
  262. # otherwise `self.feature_extractor.model_input_names` is also modified
  263. feature_extractor_input_names = [name for name in feature_extractor_input_names if name != "padding_mask"]
  264. return list(tokenizer_input_names + feature_extractor_input_names + ["input_values_cutoffs"])
# Public API of this module: only the processor class is exported.
__all__ = ["CsmProcessor"]