text_to_audio.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, TypedDict, overload
  15. from ..audio_utils import AudioInput
  16. from ..generation import GenerationConfig
  17. from ..utils import is_torch_available
  18. from ..utils.chat_template_utils import Chat, ChatType
  19. from .base import Pipeline
# torch and the torch-backed model classes are only imported when torch is installed,
# so that importing this module does not fail in a torch-free environment.
if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
    from ..models.speecht5.modeling_speecht5 import SpeechT5HifiGan

# Checkpoint loaded as the vocoder when the pipeline wraps a text-to-spectrogram model
# and the caller did not supply a vocoder of their own.
DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan"
class AudioOutput(TypedDict, total=False):
    """
    Typed dictionary describing one generated audio sample together with its sampling rate.

    Declared with `total=False`, so type checkers treat both keys as optional.

    audio (`AudioInput`):
        The generated audio waveform.
    sampling_rate (`int`):
        The sampling rate of the generated audio waveform.
    """

    audio: AudioInput
    sampling_rate: int
  34. class TextToAudioPipeline(Pipeline):
  35. """
  36. Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. This
  37. pipeline generates an audio file from an input text and optional other conditional inputs.
  38. Unless the model you're using explicitly sets these generation parameters in its configuration files
  39. (`generation_config.json`), the following default values will be used:
  40. - max_new_tokens: 256
  41. Example:
  42. ```python
  43. >>> from transformers import pipeline
  44. >>> pipe = pipeline(model="suno/bark-small")
  45. >>> output = pipe("Hey it's HuggingFace on the phone!")
  46. >>> audio = output["audio"]
  47. >>> sampling_rate = output["sampling_rate"]
  48. ```
  49. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  50. <Tip>
  51. You can specify parameters passed to the model by using [`TextToAudioPipeline.__call__.forward_params`] or
  52. [`TextToAudioPipeline.__call__.generate_kwargs`].
  53. Example:
  54. ```python
  55. >>> from transformers import pipeline
  56. >>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small")
  57. >>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length
  58. >>> generate_kwargs = {
  59. ... "do_sample": True,
  60. ... "temperature": 0.7,
  61. ... "max_new_tokens": 35,
  62. ... }
  63. >>> outputs = music_generator("Techno music with high melodic riffs", generate_kwargs=generate_kwargs)
  64. ```
  65. </Tip>
  66. This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
  67. `"text-to-audio"`.
  68. See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
  69. """
  70. _pipeline_calls_generate = True
  71. _load_processor = None # prioritize processors as some models require it
  72. _load_image_processor = False
  73. _load_feature_extractor = False
  74. _load_tokenizer = True
  75. # Make sure the docstring is updated when the default generation config is changed
  76. _default_generation_config = GenerationConfig(
  77. max_new_tokens=256,
  78. )
  79. def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
  80. super().__init__(*args, **kwargs)
  81. self.vocoder = None
  82. if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
  83. self.vocoder = (
  84. SpeechT5HifiGan.from_pretrained(DEFAULT_VOCODER_ID).to(self.model.device)
  85. if vocoder is None
  86. else vocoder
  87. )
  88. if self.model.config.model_type in ["musicgen", "speecht5"]:
  89. # MusicGen and SpeechT5 expect to use their tokenizer instead
  90. self.processor = None
  91. self.sampling_rate = sampling_rate
  92. if self.vocoder is not None:
  93. self.sampling_rate = self.vocoder.config.sampling_rate
  94. if self.sampling_rate is None:
  95. # get sampling_rate from config and generation config
  96. config = self.model.config
  97. gen_config = self.model.__dict__.get("generation_config", None)
  98. if gen_config is not None:
  99. config.update(gen_config.to_dict())
  100. for sampling_rate_name in ["sample_rate", "sampling_rate"]:
  101. sampling_rate = getattr(config, sampling_rate_name, None)
  102. if sampling_rate is not None:
  103. self.sampling_rate = sampling_rate
  104. elif getattr(config, "codec_config", None) is not None:
  105. sampling_rate = getattr(config.codec_config, sampling_rate_name, None)
  106. if sampling_rate is not None:
  107. self.sampling_rate = sampling_rate
  108. # last fallback to get the sampling rate based on processor
  109. if self.sampling_rate is None and self.processor is not None and hasattr(self.processor, "feature_extractor"):
  110. self.sampling_rate = self.processor.feature_extractor.sampling_rate
  111. def preprocess(self, text, **kwargs):
  112. if isinstance(text, str):
  113. text = [text]
  114. if self.model.config.model_type == "bark":
  115. # bark Tokenizer is called with BarkProcessor which uses those kwargs
  116. # Check if generation_config has semantic_config (BarkGenerationConfig) or use default
  117. max_length = 256
  118. if hasattr(self.generation_config, "semantic_config"):
  119. max_length = getattr(self.generation_config.semantic_config, "max_input_semantic_length", 256)
  120. new_kwargs = {
  121. "max_length": max_length,
  122. "add_special_tokens": False,
  123. "return_attention_mask": True,
  124. "return_token_type_ids": False,
  125. }
  126. # priority is given to kwargs
  127. new_kwargs.update(kwargs)
  128. kwargs = new_kwargs
  129. preprocessor = self.processor if self.processor is not None else self.tokenizer
  130. if isinstance(text, Chat):
  131. output = preprocessor.apply_chat_template(
  132. text.messages,
  133. tokenize=True,
  134. return_dict=True,
  135. **kwargs,
  136. )
  137. else:
  138. # Add speaker ID if needed and user didn't insert at start of text
  139. if self.model.config.model_type == "csm":
  140. text = [f"[0]{t}" if not t.startswith("[") else t for t in text]
  141. if self.model.config.model_type == "dia":
  142. text = [f"[S1] {t}" if not t.startswith("[") else t for t in text]
  143. output = preprocessor(text, **kwargs, return_tensors="pt")
  144. return output
  145. def _forward(self, model_inputs, **kwargs):
  146. # we expect some kwargs to be additional tensors which need to be on the right device
  147. kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
  148. forward_params = kwargs["forward_params"]
  149. generate_kwargs = kwargs["generate_kwargs"]
  150. if self.model.can_generate():
  151. # we expect some kwargs to be additional tensors which need to be on the right device
  152. generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
  153. # User-defined `generation_config` passed to the pipeline call take precedence
  154. if "generation_config" not in generate_kwargs:
  155. generate_kwargs["generation_config"] = self.generation_config
  156. # generate_kwargs get priority over forward_params
  157. forward_params.update(generate_kwargs)
  158. # ensure dict output to facilitate postprocessing
  159. forward_params.update({"return_dict_in_generate": True})
  160. if self.model.config.model_type in ["csm"]:
  161. # NOTE (ebezzam): CSM does not have the audio tokenizer in the processor therefore `output_audio=True`
  162. # needed for decoding to audio
  163. if "output_audio" not in forward_params:
  164. forward_params["output_audio"] = True
  165. output = self.model.generate(**model_inputs, **forward_params)
  166. else:
  167. if len(generate_kwargs):
  168. raise ValueError(
  169. "You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
  170. "empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
  171. f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
  172. )
  173. output = self.model(**model_inputs, **forward_params)[0]
  174. if self.vocoder is not None:
  175. # in that case, the output is a spectrogram that needs to be converted into a waveform
  176. output = self.vocoder(output)
  177. return output
  178. @overload
  179. def __call__(self, text_inputs: str, **forward_params: Any) -> AudioOutput: ...
  180. @overload
  181. def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[AudioOutput]: ...
  182. @overload
  183. def __call__(self, text_inputs: ChatType, **forward_params: Any) -> AudioOutput: ...
  184. @overload
  185. def __call__(self, text_inputs: list[ChatType], **forward_params: Any) -> list[AudioOutput]: ...
  186. def __call__(self, text_inputs, **forward_params):
  187. """
  188. Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.
  189. Args:
  190. text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`):
  191. One or several texts to generate. If strings or a list of string are passed, this pipeline will
  192. generate the corresponding text. Alternatively, a "chat", in the form of a list of dicts with "role"
  193. and "content" keys, can be passed, or a list of such chats. When chats are passed, the model's chat
  194. template will be used to format them before passing them to the model.
  195. forward_params (`dict`, *optional*):
  196. Parameters passed to the model generation/forward method. `forward_params` are always passed to the
  197. underlying model.
  198. generate_kwargs (`dict`, *optional*):
  199. The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
  200. complete overview of generate, check the [following
  201. guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). `generate_kwargs` are
  202. only passed to the underlying model if the latter is a generative model.
  203. Return:
  204. `AudioOutput` or a list of `AudioOutput`, which is a `TypedDict` with two keys:
  205. - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
  206. - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
  207. """
  208. return super().__call__(text_inputs, **forward_params)
  209. def _sanitize_parameters(
  210. self,
  211. preprocess_params=None,
  212. forward_params=None,
  213. generate_kwargs=None,
  214. ):
  215. if getattr(self, "assistant_model", None) is not None:
  216. generate_kwargs["assistant_model"] = self.assistant_model
  217. if getattr(self, "assistant_tokenizer", None) is not None:
  218. generate_kwargs["tokenizer"] = self.tokenizer
  219. generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer
  220. params = {
  221. "forward_params": forward_params if forward_params else {},
  222. "generate_kwargs": generate_kwargs if generate_kwargs else {},
  223. }
  224. if preprocess_params is None:
  225. preprocess_params = {}
  226. postprocess_params = {}
  227. return preprocess_params, params, postprocess_params
  228. def postprocess(self, audio):
  229. needs_decoding = False
  230. if isinstance(audio, dict):
  231. if "audio" in audio:
  232. audio = audio["audio"]
  233. else:
  234. needs_decoding = True
  235. audio = audio["sequences"]
  236. elif isinstance(audio, tuple):
  237. audio = audio[0]
  238. if needs_decoding and self.processor is not None:
  239. audio = self.processor.decode(audio)
  240. if isinstance(audio, list):
  241. audio = [el.to(device="cpu", dtype=torch.float).numpy().squeeze() for el in audio]
  242. audio = audio if len(audio) > 1 else audio[0]
  243. else:
  244. audio = audio.to(device="cpu", dtype=torch.float).numpy().squeeze()
  245. return AudioOutput(
  246. audio=audio,
  247. sampling_rate=self.sampling_rate,
  248. )