| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822 |
- # Copyright 2023 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch EnCodec model."""
- import math
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ... import initialization as init
- from ...modeling_utils import PreTrainedAudioTokenizerBase
- from ...utils import (
- ModelOutput,
- auto_docstring,
- logging,
- )
- from .configuration_encodec import EncodecConfig
- logger = logging.get_logger(__name__)
- # General docstring
# Container returned by `EncodecModel.forward` when a dict-style output is requested: it bundles the
# discrete codes with the reconstructed waveform.
# NOTE: the docstring below is kept verbatim because the `auto_docstring` decorator receives this class.
@dataclass
@auto_docstring
class EncodecOutput(ModelOutput):
    r"""
    audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
        Discrete code embeddings computed using `model.encode`.
    audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
        Decoded audio values, obtained using the decoder part of Encodec.
    """

    # Both fields default to `None`, so an instance can be built with any subset of them.
    audio_codes: torch.LongTensor | None = None
    audio_values: torch.FloatTensor | None = None
# Container returned by `EncodecModel.encode` when a dict-style output is requested.
# NOTE: the docstring below is kept verbatim because the `auto_docstring` decorator receives this class.
@dataclass
@auto_docstring
class EncodecEncoderOutput(ModelOutput):
    r"""
    audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
        Discrete code embeddings computed using `model.encode`.
    audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
        Scaling factor for each `audio_codes` input. This is used to unscale each chunk of audio when decoding.
    last_frame_pad_length (`int`, *optional*):
        The length of the padding in the last frame, if any. This is used to ensure that the encoded frames can be
        outputted as a tensor. This value should be passed during decoding to ensure padding is removed from the
        encoded frames.
    """

    # Stacked per-frame codes.
    audio_codes: torch.LongTensor | None = None
    # NOTE(review): `EncodecModel.encode` stores a Python list here (one entry per frame, each entry a
    # tensor or `None`), which is what the docstring describes; the annotation is narrower than that.
    audio_scales: torch.FloatTensor | None = None
    # Number of zero code positions appended to the last frame so that all frames could be stacked.
    last_frame_pad_length: int | None = None
# Container returned by `EncodecModel.decode` when a dict-style output is requested.
# NOTE: the docstring below is kept verbatim because the `auto_docstring` decorator receives this class.
@dataclass
@auto_docstring
class EncodecDecoderOutput(ModelOutput):
    r"""
    audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
        Decoded audio values, obtained using the decoder part of Encodec.
    """

    # Reconstructed waveform; defaults to `None` so the container can be created empty.
    audio_values: torch.FloatTensor | None = None
class EncodecConv1d(nn.Module):
    """Conv1d with asymmetric or causal padding and normalization.

    The input is padded explicitly (instead of relying on the convolution's own `padding` argument) so
    that the amount of left/right padding can depend on `config.use_causal_conv` and on the input length.

    Args:
        config: Object exposing `use_causal_conv`, `pad_mode` and `norm_type`.
        in_channels (`int`): Number of input channels of the convolution.
        out_channels (`int`): Number of output channels of the convolution.
        kernel_size (`int`): Kernel size of the convolution (before dilation).
        stride (`int`, *optional*, defaults to 1): Stride of the convolution.
        dilation (`int`, *optional*, defaults to 1): Dilation of the convolution.

    Raises:
        ValueError: If `config.norm_type` is neither `"weight_norm"` nor `"time_group_norm"`.
    """

    def __init__(
        self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.pad_mode = config.pad_mode
        self.norm_type = config.norm_type

        if self.norm_type not in ["weight_norm", "time_group_norm"]:
            raise ValueError(
                f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}'
            )

        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            logger.warning(
                "EncodecConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)

        # Prefer the parametrization-based weight norm when this torch version provides it.
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        if self.norm_type == "weight_norm":
            self.conv = weight_norm(self.conv)
        elif self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels)

        kernel_size = self.conv.kernel_size[0]
        stride = torch.tensor(self.conv.stride[0], dtype=torch.int64)
        dilation = self.conv.dilation[0]

        # Effective kernel size with dilations.
        kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)

        # Kept as non-persistent buffers: they follow the module across devices but are not saved in
        # the state dict.
        self.register_buffer("stride", stride, persistent=False)
        self.register_buffer("kernel_size", kernel_size, persistent=False)
        self.register_buffer("padding_total", kernel_size - stride, persistent=False)

    def _get_extra_padding_for_conv1d(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return how many extra samples must be appended on the right so that the last convolution
        window is complete, i.e. so that no input sample is dropped by the strided convolution."""
        length = hidden_states.shape[-1]
        n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
        n_frames = torch.ceil(n_frames).to(torch.int64) - 1
        ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total

        return ideal_length - length

    @staticmethod
    def _pad1d(hidden_states: torch.Tensor, paddings: tuple[int, int], mode: str = "zero", value: float = 0.0):
        """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input.
        If this is the case, we insert extra 0 padding to the right before the reflection happens.

        Args:
            hidden_states (`torch.Tensor`): Tensor padded along its last dimension.
            paddings (`tuple[int, int]`): `(left, right)` amount of padding.
            mode (`str`, *optional*, defaults to `"zero"`): `"zero"` pads with zeros; any other value is
                forwarded unchanged to `torch.nn.functional.pad`.
            value (`float`, *optional*, defaults to 0.0): Fill value forwarded to
                `torch.nn.functional.pad`; ignored when `mode="zero"`.
        """
        # `torch.nn.functional.pad` has no "zero" mode: zero padding is constant padding with value 0.
        # Without this translation the default arguments of this helper would raise.
        if mode == "zero":
            mode, value = "constant", 0.0

        length = hidden_states.shape[-1]
        padding_left, padding_right = paddings
        if mode != "reflect":
            return nn.functional.pad(hidden_states, paddings, mode, value)

        # Reflection needs an input strictly longer than the padding; grow short inputs with zeros
        # first and cut the same amount off again afterwards.
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            hidden_states = nn.functional.pad(hidden_states, (0, extra_pad))
        padded = nn.functional.pad(hidden_states, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]

    def forward(self, hidden_states):
        """Pad `hidden_states` along the last dimension, apply the convolution and, for
        `"time_group_norm"`, the group normalization."""
        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)

        if self.causal:
            # Left padding for causal
            hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
            hidden_states = self._pad1d(
                hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )

        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        return hidden_states
class EncodecConvTranspose1d(nn.Module):
    """Transposed 1D convolution followed by removal of the fixed padding, either causally or
    asymmetrically, with optional normalization."""

    def __init__(self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1):
        super().__init__()
        self.causal = config.use_causal_conv
        self.trim_right_ratio = config.trim_right_ratio
        self.norm_type = config.norm_type

        supported_norms = ["weight_norm", "time_group_norm"]
        if self.norm_type not in supported_norms:
            raise ValueError(
                f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}'
            )

        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)

        # Use the parametrization-based weight norm whenever torch ships it, else the legacy one.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            apply_weight_norm = nn.utils.parametrizations.weight_norm
        else:
            apply_weight_norm = nn.utils.weight_norm

        if config.norm_type == "weight_norm":
            self.conv = apply_weight_norm(self.conv)
        elif config.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels)

        # A partial right trim is only accepted together with causal convolutions.
        if not self.causal and self.trim_right_ratio != 1.0:
            raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions")

    def forward(self, hidden_states):
        total_padding = self.conv.kernel_size[0] - self.conv.stride[0]

        upsampled = self.conv(hidden_states)
        if self.norm_type == "time_group_norm":
            upsampled = self.norm(upsampled)

        # Only the fixed padding is trimmed here. Any extra padding added by the matching encoder
        # convolution is left in place, because removing it would require knowing the length seen
        # at that encoder layer.
        if self.causal:
            # The share of the padding removed on the right is set by `trim_right_ratio`
            # (1.0 removes all of it from the right).
            trim_right = math.ceil(total_padding * self.trim_right_ratio)
        else:
            # Split as evenly as possible; the left side takes the extra sample for odd totals.
            trim_right = total_padding // 2
        trim_left = total_padding - trim_right

        stop = upsampled.shape[-1] - trim_right
        return upsampled[..., trim_left:stop]
class EncodecLSTM(nn.Module):
    """
    Residual LSTM that takes and returns tensors in convolutional layout; the recurrent state is
    handled internally and never exposed.
    """

    def __init__(self, config: EncodecConfig, dimension: int):
        super().__init__()
        self.lstm = nn.LSTM(dimension, dimension, config.num_lstm_layers)

    def forward(self, hidden_states):
        # Move the last (time) axis first, which is the layout the LSTM is fed with here.
        time_major = hidden_states.permute(2, 0, 1)
        lstm_output, _ = self.lstm(time_major)
        # Skip connection around the LSTM, then restore the original axis order.
        return (lstm_output + time_major).permute(1, 2, 0)
class EncodecResnetBlock(nn.Module):
    """
    SEANet residual block used by EnCodec: two ELU + convolution stages with a channel bottleneck,
    added to a shortcut branch.
    """

    def __init__(self, config: EncodecConfig, dim: int, dilations: list[int]):
        super().__init__()
        kernel_sizes = (config.residual_kernel_size, 1)
        if len(kernel_sizes) != len(dilations):
            raise ValueError("Number of kernel sizes should match number of dilations")

        bottleneck = dim // config.compress
        last_index = len(kernel_sizes) - 1

        stages = []
        for index, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            # The first stage reads `dim` channels and the last one writes `dim` channels;
            # everything in between runs at the bottleneck width.
            in_chs = bottleneck if index else dim
            out_chs = dim if index == last_index else bottleneck
            stages.append(nn.ELU())
            stages.append(EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation))
        self.block = nn.ModuleList(stages)

        self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1) if config.use_conv_shortcut else nn.Identity()

    def forward(self, hidden_states):
        transformed = hidden_states
        for stage in self.block:
            transformed = stage(transformed)

        return self.shortcut(hidden_states) + transformed
class EncodecEncoder(nn.Module):
    """SEANet encoder used by EnCodec: input convolution, a stack of residual + strided
    convolution stages, an LSTM and an output convolution."""

    def __init__(self, config: EncodecConfig):
        super().__init__()
        stack = [EncodecConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]

        # Channel multiplier, doubled after every downsampling stage.
        multiplier = 1
        for ratio in reversed(config.upsampling_ratios):
            width = multiplier * config.num_filters
            # Residual layers with growing dilation
            stack.extend(
                EncodecResnetBlock(config, width, [config.dilation_growth_rate**j, 1])
                for j in range(config.num_residual_layers)
            )
            # Strided convolution that downsamples time by `ratio` and doubles the channels
            stack.append(nn.ELU())
            stack.append(EncodecConv1d(config, width, width * 2, kernel_size=ratio * 2, stride=ratio))
            multiplier *= 2

        final_width = multiplier * config.num_filters
        stack.append(EncodecLSTM(config, final_width))
        stack.append(nn.ELU())
        stack.append(EncodecConv1d(config, final_width, config.hidden_size, config.last_kernel_size))

        self.layers = nn.ModuleList(stack)

    def forward(self, hidden_states):
        output = hidden_states
        for module in self.layers:
            output = module(output)
        return output
class EncodecDecoder(nn.Module):
    """SEANet decoder used by EnCodec: input convolution, an LSTM, a stack of transposed
    convolution + residual stages and an output convolution."""

    def __init__(self, config: EncodecConfig):
        super().__init__()
        # Channel multiplier, halved after every upsampling stage.
        multiplier = int(2 ** len(config.upsampling_ratios))
        entry_width = multiplier * config.num_filters

        stack = [
            EncodecConv1d(config, config.hidden_size, entry_width, config.kernel_size),
            EncodecLSTM(config, entry_width),
        ]

        for ratio in config.upsampling_ratios:
            width = multiplier * config.num_filters
            half_width = width // 2
            # Transposed convolution that upsamples time by `ratio` and halves the channels
            stack.append(nn.ELU())
            stack.append(EncodecConvTranspose1d(config, width, half_width, kernel_size=ratio * 2, stride=ratio))
            # Residual layers with growing dilation
            stack.extend(
                EncodecResnetBlock(config, half_width, (config.dilation_growth_rate**j, 1))
                for j in range(config.num_residual_layers)
            )
            multiplier //= 2

        # Final projection to the audio channels
        stack.append(nn.ELU())
        stack.append(EncodecConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size))

        self.layers = nn.ModuleList(stack)

    def forward(self, hidden_states):
        output = hidden_states
        for module in self.layers:
            output = module(output)
        return output
class EncodecEuclideanCodebook(nn.Module):
    """Codebook whose entries are selected by smallest squared Euclidean distance."""

    def __init__(self, config: EncodecConfig):
        super().__init__()
        initial_embed = torch.zeros(config.codebook_size, config.codebook_dim)

        self.codebook_size = config.codebook_size

        # All codebook state lives in buffers.
        self.register_buffer("inited", torch.Tensor([True]))
        self.register_buffer("cluster_size", torch.zeros(config.codebook_size))
        self.register_buffer("embed", initial_embed)
        self.register_buffer("embed_avg", initial_embed.clone())

    def quantize(self, hidden_states):
        """Return, for every row of `hidden_states`, the index of the closest codebook entry."""
        codebook = self.embed.t()
        row_sq_norm = hidden_states.pow(2).sum(1, keepdim=True)
        # Negated squared distance ||x||^2 - 2 x.e + ||e||^2, so the nearest entry is the maximum.
        neg_sq_dist = -(row_sq_norm - 2 * hidden_states @ codebook + codebook.pow(2).sum(0, keepdim=True))
        return neg_sq_dist.max(dim=-1).indices

    def encode(self, hidden_states):
        """Quantize over the last dimension and return indices shaped like the leading dimensions."""
        original_shape = hidden_states.shape
        flattened = hidden_states.reshape((-1, original_shape[-1]))
        flat_indices = self.quantize(flattened)
        return flat_indices.view(*original_shape[:-1])

    def decode(self, embed_ind):
        """Look up the codebook vectors for the given indices."""
        return nn.functional.embedding(embed_ind, self.embed)
class EncodecVectorQuantization(nn.Module):
    """
    Single vector-quantization layer backed by a Euclidean codebook (the only distance supported).
    """

    def __init__(self, config: EncodecConfig):
        super().__init__()
        self.codebook = EncodecEuclideanCodebook(config)

    def encode(self, hidden_states):
        # The codebook quantizes over the last axis, so swap the last two axes first.
        swapped = hidden_states.permute(0, 2, 1)
        return self.codebook.encode(swapped)

    def decode(self, embed_ind):
        # Swap the last two axes of the looked-up vectors back, mirroring `encode`.
        return self.codebook.decode(embed_ind).permute(0, 2, 1)
class EncodecResidualVectorQuantizer(nn.Module):
    """Residual vector quantizer: a chain of quantizers, each one encoding what the previous ones
    left over."""

    def __init__(self, config: EncodecConfig):
        super().__init__()
        self.codebook_size = config.codebook_size
        self.frame_rate = config.frame_rate
        self.num_quantizers = config.num_quantizers
        self.layers = nn.ModuleList([EncodecVectorQuantization(config) for _ in range(config.num_quantizers)])

    def get_num_quantizers_for_bandwidth(self, bandwidth: float | None = None) -> int:
        """Return how many quantizers to use for the target bandwidth.

        Falls back to all configured quantizers when `bandwidth` is `None` or not positive.
        """
        # Bits per second contributed by a single quantizer.
        bits_per_quantizer = math.log2(self.codebook_size) * self.frame_rate
        if bandwidth is None or not bandwidth > 0.0:
            return self.num_quantizers
        # `bandwidth` is multiplied by 1000 before being divided by the per-quantizer rate;
        # at least one quantizer is always used.
        return int(max(1, math.floor(bandwidth * 1000 / bits_per_quantizer)))

    def encode(self, embeddings: torch.Tensor, bandwidth: float | None = None) -> torch.Tensor:
        """
        Quantize `embeddings` with as many quantizers as the requested bandwidth allows and return
        the indices of every quantizer used, stacked along a new first dimension.
        """
        active_layers = self.layers[: self.get_num_quantizers_for_bandwidth(bandwidth)]

        remainder = embeddings
        collected = []
        for quantizer in active_layers:
            layer_indices = quantizer.encode(remainder)
            # Whatever this quantizer could not represent is handed to the next one.
            remainder = remainder - quantizer.decode(layer_indices)
            collected.append(layer_indices)
        return torch.stack(collected)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Sum the codebook vectors selected by each quantizer's indices."""
        total = torch.tensor(0.0, device=codes.device)
        for position, layer_codes in enumerate(codes):
            total = total + self.layers[position].decode(layer_codes)
        return total
@auto_docstring
class EncodecPreTrainedModel(PreTrainedAudioTokenizerBase):
    config: EncodecConfig
    base_model_prefix = "encodec"
    main_input_name = "input_values"

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the parameters and buffers of `module` according to its type."""
        if isinstance(module, nn.GroupNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            return

        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight)
            if module.bias is not None:
                bound = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                init.uniform_(module.bias, a=-bound, b=bound)
            return

        if isinstance(module, nn.ConvTranspose1d):
            module.reset_parameters()
            return

        if isinstance(module, nn.LSTM):
            for param_name, parameter in module.named_parameters():
                if "weight" in param_name:
                    init.xavier_uniform_(parameter)
                elif "bias" in param_name:
                    init.constant_(parameter, 0.0)
            return

        if isinstance(module, EncodecConv1d):
            # Recompute the non-persistent geometry buffers from the wrapped convolution.
            conv = module.conv
            stride = torch.tensor(conv.stride[0], dtype=torch.int64)
            # Effective kernel size with dilations.
            kernel_size = torch.tensor((conv.kernel_size[0] - 1) * conv.dilation[0] + 1, dtype=torch.int64)
            init.copy_(module.stride, stride)
            init.copy_(module.kernel_size, kernel_size)
            init.copy_(module.padding_total, kernel_size - stride)
            return

        if isinstance(module, EncodecEuclideanCodebook):
            init.copy_(module.inited, torch.Tensor([True]))
            init.zeros_(module.cluster_size)
            init.zeros_(module.embed)
            init.zeros_(module.embed_avg)
@auto_docstring(
    custom_intro="""
    The EnCodec neural audio codec model.
    """
)
class EncodecModel(EncodecPreTrainedModel):
    # No class or `__init__` docstring is added here: the class is passed to `auto_docstring` together
    # with a `custom_intro`.
    def __init__(self, config: EncodecConfig):
        super().__init__(config)
        self.config = config

        self.encoder = EncodecEncoder(config)
        self.decoder = EncodecDecoder(config)

        self.quantizer = EncodecResidualVectorQuantizer(config)

        self.bits_per_codebook = int(math.log2(self.config.codebook_size))
        if 2**self.bits_per_codebook != self.config.codebook_size:
            raise ValueError("The codebook_size must be a power of 2.")

        # Initialize weights and apply final processing
        self.post_init()

    def _encode_frame(self, input_values: torch.Tensor, bandwidth: float) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Encodes a single frame with the encoder and the residual quantizer. If `config.normalize` is set to
        `True`, the frame is first divided by the root-mean-square of its channel average, and that scale is
        returned so the decoder can undo it.

        Returns:
            A tuple `(codes, scale)` where `codes` has the batch dimension first and `scale` is a tensor of shape
            `(batch_size, 1)`, or `None` when `config.normalize` is `False`.

        Raises:
            RuntimeError: If `config.chunk_length_s` is set and the frame is longer than that duration.
        """
        length = input_values.shape[-1]
        duration = length / self.config.sampling_rate

        if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
            raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")

        scale = None
        if self.config.normalize:
            # Average over the channel dimension, then take the RMS over time; the epsilon keeps the
            # division below finite for silent frames.
            mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
            scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
            input_values = input_values / scale
            scale = scale.view(-1, 1)

        embeddings = self.encoder(input_values)
        codes = self.quantizer.encode(embeddings, bandwidth)
        # The quantizer stacks its outputs quantizer-first; put the batch dimension first instead.
        codes = codes.transpose(0, 1)
        return codes, scale

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        bandwidth: float | None = None,
        return_dict: bool | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, int] | EncodecEncoderOutput:
        """
        Encodes the input audio waveform into discrete codes of shape
        `(nb_frames, batch_size, nb_quantizers, frame_len)`.

        - `nb_frames=1` if `self.config.chunk_length=None` (as the encoder is applied on the full audio), which is the
        case for the 24kHz model. Otherwise, `nb_frames=ceil(input_length/self.config.chunk_stride)`, which is the case
        for the 48kHz model.
        - `frame_len` is the length of each frame, which is equal to `ceil(input_length/self.config.hop_length)` if
        `self.config.chunk_length=None` (e.g., for the 24kHz model). Otherwise, if `self.config.chunk_length` is
        defined, `frame_len=self.config.chunk_length/self.config.hop_length`, e.g., the case for the 48kHz model with
        `frame_len=150`.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            bandwidth (`float`, *optional*):
                The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
                bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented
                as bandwidth == 6.0
            return_dict (`bool`, *optional*):
                Whether to return an `EncodecEncoderOutput` instead of a plain tuple. Defaults to
                `config.return_dict`.

        Returns:
            EncodecEncoderOutput dict or a tuple containing:
            - audio_codes (`torch.LongTensor`  of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*),
            - audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*),
            - last_frame_pad_length (`int`, *optional*).

        Raises:
            ValueError: If `bandwidth` is not supported, if the number of channels is not 1 or 2, or if
                `input_values` contains no samples.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if bandwidth is None:
            bandwidth = self.config.target_bandwidths[0]
        if bandwidth not in self.config.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
            )

        _, channels, input_length = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

        # Fail early with a clear message: an empty input would otherwise surface as a zero `range` step
        # (unchunked models) or as an index error on the empty frame list (chunked models).
        if input_length == 0:
            raise ValueError("`input_values` must contain at least one audio sample.")

        chunk_length = self.config.chunk_length
        if chunk_length is None:
            # No chunking: the whole input is a single frame.
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.config.chunk_stride

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()
        else:
            # ensure that channel dimension is present
            padding_mask = padding_mask.view(padding_mask.shape[0], -1, padding_mask.shape[-1])

        encoded_frames = []
        scales = []
        for offset in range(0, input_length, stride):
            mask = padding_mask[..., offset : offset + chunk_length].bool()
            # Zero out padded samples before encoding the frame.
            frame = mask * input_values[..., offset : offset + chunk_length]
            encoded_frame, scale = self._encode_frame(frame, bandwidth)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        # pad last frame (if necessary) to be able to apply `torch.stack`
        last_frame_pad_length = encoded_frames[0].shape[-1] - encoded_frames[-1].shape[-1]
        if last_frame_pad_length > 0:
            last_frame = nn.functional.pad(encoded_frames[-1], (0, last_frame_pad_length), value=0)
            encoded_frames[-1] = last_frame
        encoded_frames = torch.stack(encoded_frames)

        if not return_dict:
            return (encoded_frames, scales, last_frame_pad_length)

        return EncodecEncoderOutput(encoded_frames, scales, last_frame_pad_length)

    @staticmethod
    def _linear_overlap_add(frames: list[torch.Tensor], stride: int):
        """Combine overlapping frames into one signal with a triangular cross-fade.

        Args:
            frames (`list[torch.Tensor]`): Frames to combine; consecutive frames start `stride` samples apart.
            stride (`int`): Offset in samples between the starts of consecutive frames.

        Raises:
            ValueError: If `frames` is empty or if some output position receives zero total weight.
        """
        # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
        # e.g., more than 2 frames per position.
        # The core idea is to use a weight function that is a triangle,
        # with a maximum value at the middle of the chunk.
        # We use this weighting when summing the frames, and divide by the sum of weights
        # for each positions at the end. Thus:
        #   - if a frame is the only one to cover a position, the weighting is a no-op.
        #   - if 2 frames cover a position:
        #          ...  ...
        #         /   \/   \
        #        /    /\    \
        #            S  T       , i.e. S offset of second frame starts, T end of first frame.
        # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
        # After the final normalization, the weight of the second frame at position `t` is
        # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
        #
        #   - if more than 2 frames overlap at a given point, we hope that by induction
        #      something sensible happens.
        if len(frames) == 0:
            raise ValueError("`frames` cannot be an empty list.")

        device = frames[0].device
        dtype = frames[0].dtype
        shape = frames[0].shape[:-1]
        total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]

        frame_length = frames[0].shape[-1]
        # Dropping both endpoints of the linspace keeps every weight strictly positive.
        time_vec = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1:-1]
        weight = 0.5 - (time_vec - 0.5).abs()

        sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
        out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
        offset: int = 0

        for frame in frames:
            frame_length = frame.shape[-1]
            out[..., offset : offset + frame_length] += weight[:frame_length] * frame
            sum_weight[offset : offset + frame_length] += weight[:frame_length]
            offset += stride

        if sum_weight.min() == 0:
            raise ValueError(f"`sum_weight` minimum element must be bigger than zero: {sum_weight}`")

        return out / sum_weight

    def _decode_frame(self, codes: torch.Tensor, scale: torch.Tensor | None = None) -> torch.Tensor:
        """Decode the codes of one frame into audio and, when `scale` is given, multiply the result by it."""
        # Mirror of `_encode_frame`: put the quantizer dimension first again.
        codes = codes.transpose(0, 1)
        embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale.view(-1, 1, 1)
        return outputs

    def decode(
        self,
        audio_codes: torch.LongTensor,
        audio_scales: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        return_dict: bool | None = None,
        last_frame_pad_length: int | None = 0,
    ) -> tuple[torch.Tensor] | EncodecDecoderOutput:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor`  of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
                Scaling factor for each `audio_codes` input.
            padding_mask (`torch.Tensor` of shape `(channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            last_frame_pad_length (`int`, *optional*):
                Integer representing the length of the padding in the last frame, which is removed during decoding.
                `None` is treated as `0` (no padding).

        Returns:
            `EncodecDecoderOutput`, or a 1-tuple holding the decoded audio values.

        Raises:
            ValueError: If the model does not use chunking and `audio_codes` holds more or fewer than one frame.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # The parameter is declared optional, so `None` must not reach the `> 0` comparisons below.
        last_frame_pad_length = last_frame_pad_length or 0

        chunk_length = self.config.chunk_length
        if chunk_length is None:
            if len(audio_codes) != 1:
                raise ValueError(f"Expected one frame, got {len(audio_codes)}")
            frame = audio_codes[0]
            if last_frame_pad_length > 0:
                frame = frame[..., :-last_frame_pad_length]
            audio_values = self._decode_frame(frame, audio_scales[0])
        else:
            decoded_frames = []

            for i, (frame, scale) in enumerate(zip(audio_codes, audio_scales)):
                # Only the last frame can carry stacking padding.
                if i == len(audio_codes) - 1 and last_frame_pad_length > 0:
                    frame = frame[..., :-last_frame_pad_length]
                frames = self._decode_frame(frame, scale)
                decoded_frames.append(frames)

            audio_values = self._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1)

        # truncate based on padding mask
        if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
            audio_values = audio_values[..., : padding_mask.shape[-1]]

        if not return_dict:
            return (audio_values,)
        return EncodecDecoderOutput(audio_values)

    @auto_docstring
    def forward(
        self,
        input_values: torch.FloatTensor,
        padding_mask: torch.BoolTensor | None = None,
        bandwidth: float | None = None,
        audio_codes: torch.LongTensor | None = None,
        audio_scales: torch.Tensor | None = None,
        return_dict: bool | None = None,
        last_frame_pad_length: int | None = 0,
    ) -> tuple[torch.Tensor, torch.Tensor] | EncodecOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
            Raw audio input converted to Float and padded to the appropriate length in order to be encoded using chunks
            of length `config.chunk_length` and a stride of `config.chunk_stride`.
        padding_mask (`torch.BoolTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
            Mask to avoid computing scaling factors on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            <Tip warning={true}>

             `padding_mask` should always be passed, unless the input was truncated or not padded. This is because in
             order to process tensors effectively, the input audio should be padded so that `input_length % stride =
             step` with `step = chunk_length-stride`. This ensures that all chunks are of the same shape

            </Tip>

        bandwidth (`float`, *optional*):
            The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
            bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
            `bandwidth == 6.0`
        audio_codes (`torch.LongTensor`  of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
            Discrete code embeddings computed using `model.encode`.
        audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
            Scaling factor for each `audio_codes` input.
        return_dict (`bool`, *optional*):
            Whether to return outputs as a dict.
        last_frame_pad_length (`int`, *optional*):
            The length of the padding in the last frame, if any. This is used to ensure that the encoded frames can be
            outputted as a tensor. This value should be passed during decoding to ensure padding is removed from the
            encoded frames.

        Examples:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import AutoProcessor, EncodecModel

        >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> model_id = "facebook/encodec_24khz"
        >>> model = EncodecModel.from_pretrained(model_id)
        >>> processor = AutoProcessor.from_pretrained(model_id)

        >>> inputs = processor(raw_audio=audio_sample, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> audio_codes = outputs.audio_codes
        >>> audio_values = outputs.audio_values
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()
        else:
            # ensure that channel dimension is present
            padding_mask = padding_mask.view(padding_mask.shape[0], -1, padding_mask.shape[-1])

        # Codes and scales are only meaningful together.
        if audio_codes is not None and audio_scales is None:
            raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`")

        if audio_scales is not None and audio_codes is None:
            raise ValueError("You specified `audio_scales` but did not specify the `audio_codes`")

        if audio_scales is None and audio_codes is None:
            # Nothing was pre-encoded: run the encoder, which also reports the padding of the last frame.
            audio_codes, audio_scales, last_frame_pad_length = self.encode(
                input_values, padding_mask, bandwidth, False
            )

        audio_values = self.decode(
            audio_codes,
            audio_scales,
            padding_mask,
            return_dict=return_dict,
            last_frame_pad_length=last_frame_pad_length,
        )[0]
        if not return_dict:
            return (audio_codes, audio_values)

        return EncodecOutput(audio_codes=audio_codes, audio_values=audio_values)
- __all__ = ["EncodecModel", "EncodecPreTrainedModel"]
|