| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- # Copyright 2023 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any, TypedDict, overload
- from ..audio_utils import AudioInput
- from ..generation import GenerationConfig
- from ..utils import is_torch_available
- from ..utils.chat_template_utils import Chat, ChatType
- from .base import Pipeline
- if is_torch_available():
- import torch
- from ..models.auto.modeling_auto import MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
- from ..models.speecht5.modeling_speecht5 import SpeechT5HifiGan
# Checkpoint id passed to `SpeechT5HifiGan.from_pretrained` when the model is a text-to-spectrogram model and the
# caller did not supply a `vocoder` to `TextToAudioPipeline`.
DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan"
class AudioOutput(TypedDict, total=False):
    """
    Dictionary returned by the text-to-audio pipeline: a waveform together with its sampling rate.

    audio (`AudioInput`):
        The generated audio waveform.
    sampling_rate (`int`):
        The sampling rate of the generated audio waveform.
    """

    audio: AudioInput
    sampling_rate: int
class TextToAudioPipeline(Pipeline):
    """
    Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. This
    pipeline generates an audio file from an input text and optional other conditional inputs.

    Unless the model you're using explicitly sets these generation parameters in its configuration files
    (`generation_config.json`), the following default values will be used:
    - max_new_tokens: 256

    Example:

    ```python
    >>> from transformers import pipeline

    >>> pipe = pipeline(model="suno/bark-small")
    >>> output = pipe("Hey it's HuggingFace on the phone!")

    >>> audio = output["audio"]
    >>> sampling_rate = output["sampling_rate"]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    <Tip>

    You can specify parameters passed to the model by using [`TextToAudioPipeline.__call__.forward_params`] or
    [`TextToAudioPipeline.__call__.generate_kwargs`].

    Example:

    ```python
    >>> from transformers import pipeline

    >>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small")

    >>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length
    >>> generate_kwargs = {
    ...     "do_sample": True,
    ...     "temperature": 0.7,
    ...     "max_new_tokens": 35,
    ... }

    >>> outputs = music_generator("Techno music with high melodic riffs", generate_kwargs=generate_kwargs)
    ```

    </Tip>

    This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
    `"text-to-audio"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
    """

    _pipeline_calls_generate = True

    _load_processor = None  # prioritize processors as some models require it
    _load_image_processor = False
    _load_feature_extractor = False
    _load_tokenizer = True

    # Make sure the docstring is updated when the default generation config is changed
    _default_generation_config = GenerationConfig(
        max_new_tokens=256,
    )

    def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
        """
        Build the pipeline and determine the sampling rate reported with the generated audio.

        Args:
            vocoder (*optional*):
                Callable applied to the model output to turn a spectrogram into a waveform. Only used when the model
                class is one of the values of `MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING`; if omitted for such a model,
                the `DEFAULT_VOCODER_ID` checkpoint is loaded and moved to the model's device.
            sampling_rate (`int`, *optional*):
                Sampling rate to report in the output. When a vocoder is in use, the vocoder's configured rate takes
                precedence over this value. When left to `None`, the rate is looked up in the model configuration
                (and its `codec_config`, if any) and finally in the processor's feature extractor.
        """
        super().__init__(*args, **kwargs)

        self.vocoder = None
        if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
            if vocoder is None:
                vocoder = SpeechT5HifiGan.from_pretrained(DEFAULT_VOCODER_ID).to(self.model.device)
            self.vocoder = vocoder

        if self.model.config.model_type in ["musicgen", "speecht5"]:
            # MusicGen and SpeechT5 expect to use their tokenizer instead
            self.processor = None

        self.sampling_rate = sampling_rate
        if self.vocoder is not None:
            # the vocoder's configured rate overrides any user-supplied `sampling_rate`
            self.sampling_rate = self.vocoder.config.sampling_rate

        if self.sampling_rate is None:
            # get sampling_rate from config and generation config
            config = self.model.config
            gen_config = self.model.__dict__.get("generation_config", None)
            if gen_config is not None:
                # NOTE(review): this presumably writes the generation-config values into the model config object
                # itself, not into a copy -- confirm this side effect is intended
                config.update(gen_config.to_dict())

            # Both attribute names are probed in order; a value found under the later name wins.
            for rate_attr in ("sample_rate", "sampling_rate"):
                rate = getattr(config, rate_attr, None)
                if rate is None and getattr(config, "codec_config", None) is not None:
                    rate = getattr(config.codec_config, rate_attr, None)
                if rate is not None:
                    self.sampling_rate = rate

        # last fallback to get the sampling rate based on processor
        if self.sampling_rate is None and self.processor is not None and hasattr(self.processor, "feature_extractor"):
            self.sampling_rate = self.processor.feature_extractor.sampling_rate

    def preprocess(self, text, **kwargs):
        """
        Turn the raw input into model inputs using the processor if there is one, otherwise the tokenizer.

        Args:
            text (`str`, `list[str]` or `Chat`):
                A single string is wrapped into a one-element list. A `Chat` is formatted through the preprocessor's
                `apply_chat_template`; anything else is passed to the preprocessor call directly.
            kwargs:
                Extra arguments forwarded to the preprocessor. For Bark models they are merged on top of a set of
                default tokenizer arguments.

        Returns:
            Whatever the preprocessor (or its `apply_chat_template`) returns.
        """
        if isinstance(text, str):
            text = [text]

        model_type = self.model.config.model_type

        if model_type == "bark":
            # bark Tokenizer is called with BarkProcessor which uses those kwargs
            # Check if generation_config has semantic_config (BarkGenerationConfig) or use default
            max_length = 256
            if hasattr(self.generation_config, "semantic_config"):
                max_length = getattr(self.generation_config.semantic_config, "max_input_semantic_length", 256)
            bark_defaults = {
                "max_length": max_length,
                "add_special_tokens": False,
                "return_attention_mask": True,
                "return_token_type_ids": False,
            }
            # priority is given to kwargs
            kwargs = {**bark_defaults, **kwargs}

        preprocessor = self.processor if self.processor is not None else self.tokenizer

        if isinstance(text, Chat):
            return preprocessor.apply_chat_template(
                text.messages,
                tokenize=True,
                return_dict=True,
                **kwargs,
            )

        # Add speaker ID if needed and user didn't insert at start of text
        if model_type == "csm":
            text = [t if t.startswith("[") else f"[0]{t}" for t in text]
        if model_type == "dia":
            text = [t if t.startswith("[") else f"[S1] {t}" for t in text]
        return preprocessor(text, **kwargs, return_tensors="pt")

    def _forward(self, model_inputs, **kwargs):
        """
        Run the model on the preprocessed inputs.

        Args:
            model_inputs:
                Mapping of preprocessed inputs; it is unpacked as keyword arguments into the model call.
            kwargs:
                Must contain the `forward_params` and `generate_kwargs` dictionaries.

        Returns:
            The result of `self.model.generate(...)` when the model can generate, otherwise the first element of the
            forward output. If a vocoder is set, the result is passed through it before being returned.

        Raises:
            ValueError: If the model cannot generate but `generate_kwargs` is not empty.
        """
        # we expect some kwargs to be additional tensors which need to be on the right device
        kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
        forward_params = kwargs["forward_params"]
        generate_kwargs = kwargs["generate_kwargs"]

        if self.model.can_generate():
            # we expect some kwargs to be additional tensors which need to be on the right device
            generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)

            # User-defined `generation_config` passed to the pipeline call take precedence
            generate_kwargs.setdefault("generation_config", self.generation_config)

            # generate_kwargs get priority over forward_params
            forward_params.update(generate_kwargs)

            # ensure dict output to facilitate postprocessing
            forward_params["return_dict_in_generate"] = True

            if self.model.config.model_type == "csm":
                # NOTE (ebezzam): CSM does not have the audio tokenizer in the processor therefore `output_audio=True`
                # needed for decoding to audio
                forward_params.setdefault("output_audio", True)

            output = self.model.generate(**model_inputs, **forward_params)
        else:
            if generate_kwargs:
                raise ValueError(
                    "You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
                    "empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
                    f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
                )
            output = self.model(**model_inputs, **forward_params)[0]

        if self.vocoder is not None:
            # in that case, the output is a spectrogram that needs to be converted into a waveform
            output = self.vocoder(output)

        return output

    @overload
    def __call__(self, text_inputs: str, **forward_params: Any) -> AudioOutput: ...

    @overload
    def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[AudioOutput]: ...

    @overload
    def __call__(self, text_inputs: ChatType, **forward_params: Any) -> AudioOutput: ...

    @overload
    def __call__(self, text_inputs: list[ChatType], **forward_params: Any) -> list[AudioOutput]: ...

    def __call__(self, text_inputs, **forward_params):
        """
        Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.

        Args:
            text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`):
                One or several texts to generate. If strings or a list of string are passed, this pipeline will
                generate the corresponding audio. Alternatively, a "chat", in the form of a list of dicts with "role"
                and "content" keys, can be passed, or a list of such chats. When chats are passed, the model's chat
                template will be used to format them before passing them to the model.
            forward_params (`dict`, *optional*):
                Parameters passed to the model generation/forward method. `forward_params` are always passed to the
                underlying model.
            generate_kwargs (`dict`, *optional*):
                The dictionary of ad-hoc parametrization of the generation config to be used for the generation call.
                For a complete overview of generate, check the [following
                guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). `generate_kwargs` are
                only passed to the underlying model if the latter is a generative model.

        Return:
            `AudioOutput` or a list of `AudioOutput`, which is a `TypedDict` with two keys:

            - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
            - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
        """
        return super().__call__(text_inputs, **forward_params)

    def _sanitize_parameters(
        self,
        preprocess_params=None,
        forward_params=None,
        generate_kwargs=None,
    ):
        """
        Split the call-time arguments into preprocess, forward and postprocess parameter dictionaries.

        Args:
            preprocess_params (`dict`, *optional*): Returned as is (or as `{}` when `None`).
            forward_params (`dict`, *optional*): Parameters for the model call.
            generate_kwargs (`dict`, *optional*): Parameters for generation. The pipeline's `assistant_model`,
                `tokenizer` and `assistant_tokenizer` are added to it when the corresponding attributes are set.

        Returns:
            A `(preprocess_params, forward_kwargs, postprocess_params)` tuple where `forward_kwargs` has the two keys
            `"forward_params"` and `"generate_kwargs"` and `postprocess_params` is always empty.
        """
        # Work on shallow copies: `generate_kwargs` defaults to `None` (which cannot be subscripted below) and
        # dictionaries owned by the caller must not be modified behind their back.
        forward_params = dict(forward_params) if forward_params else {}
        generate_kwargs = dict(generate_kwargs) if generate_kwargs else {}

        if getattr(self, "assistant_model", None) is not None:
            generate_kwargs["assistant_model"] = self.assistant_model
        if getattr(self, "assistant_tokenizer", None) is not None:
            generate_kwargs["tokenizer"] = self.tokenizer
            generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer

        params = {
            "forward_params": forward_params,
            "generate_kwargs": generate_kwargs,
        }

        if preprocess_params is None:
            preprocess_params = {}
        postprocess_params = {}

        return preprocess_params, params, postprocess_params

    def postprocess(self, audio):
        """
        Convert the model output into an `AudioOutput` holding numpy waveform(s) and the sampling rate.

        Args:
            audio:
                Model output. A dict is read through its `"audio"` key, or through `"sequences"` (which is then decoded
                with the processor, if there is one); for a tuple the first element is used; anything else is used
                as is.

        Returns:
            `AudioOutput` with `audio` converted to float numpy array(s) on CPU with singleton dimensions squeezed,
            and `sampling_rate` set to the pipeline's sampling rate. A one-element list is unwrapped to its element.
        """
        needs_decoding = False
        if isinstance(audio, dict):
            if "audio" in audio:
                audio = audio["audio"]
            else:
                needs_decoding = True
                audio = audio["sequences"]
        elif isinstance(audio, tuple):
            audio = audio[0]

        if needs_decoding and self.processor is not None:
            audio = self.processor.decode(audio)

        def to_numpy(tensor):
            # cast to float32 on CPU before handing over to numpy
            return tensor.to(device="cpu", dtype=torch.float).numpy().squeeze()

        if isinstance(audio, list):
            waveforms = [to_numpy(el) for el in audio]
            audio = waveforms if len(waveforms) > 1 else waveforms[0]
        else:
            audio = to_numpy(audio)

        return AudioOutput(
            audio=audio,
            sampling_rate=self.sampling_rate,
        )
|