| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623 |
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Transformers Xcodec model."""
- import math
- from dataclasses import dataclass
- from functools import lru_cache
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ... import initialization as init
- from ...audio_utils import conv1d_output_length
- from ...modeling_utils import PreTrainedAudioTokenizerBase
- from ...utils import ModelOutput, auto_docstring
- from ..auto import AutoModel
- from .configuration_xcodec import XcodecConfig
@dataclass
class XcodecOutput(ModelOutput):
    """
    Output container returned by `XcodecModel.forward`: the discrete codes together with the audio reconstructed
    from them.

    Args:
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`.
        audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*):
            Decoded audio values obtained using the decoder part of Xcodec.
    """

    # Field order is part of the interface: `ModelOutput` instances are also indexed positionally in this file.
    audio_codes: torch.LongTensor | None = None
    audio_values: torch.FloatTensor | None = None
@dataclass
class XcodecEncoderOutput(ModelOutput):
    """
    Output container returned by `XcodecModel.encode` when `return_dict` is enabled.

    Args:
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`.
    """

    audio_codes: torch.LongTensor | None = None
@dataclass
class XcodecDecoderOutput(ModelOutput):
    """
    Output container returned by `XcodecModel.decode` when `return_dict` is enabled.

    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*):
            Decoded audio values obtained using the decoder part of Xcodec.
    """

    audio_values: torch.FloatTensor | None = None
class XcodecResidualUnit(nn.Module):
    """Residual unit (ELU -> dilated conv -> ELU -> 1x1 conv, plus skip) used by the semantic encoder/decoder."""

    def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, dilation: int):
        super().__init__()
        kernel_size = config.unit_kernel_size

        self.activation = nn.ELU()
        # With an odd kernel this padding keeps the sequence length unchanged for any dilation.
        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=((kernel_size - 1) // 2) * dilation,
            dilation=dilation,
            groups=1,
            bias=False,
        )
        # Point-wise (kernel size 1) mixing convolution.
        self.conv2 = nn.Conv1d(out_channels, out_channels, 1, bias=False)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return `hidden_state` plus the output of the two activation/convolution stages."""
        branch = self.conv1(self.activation(hidden_state))
        branch = self.conv2(self.activation(branch))
        return hidden_state + branch
class XcodecSemanticEncoderBlock(nn.Module):
    """One residual unit per entry of `config.block_dilations`, followed by a strided projection convolution."""

    def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int):
        super().__init__()
        residual_units = []
        for dilation in config.block_dilations:
            residual_units.append(XcodecResidualUnit(config, in_channels, in_channels, dilation))
        self.res_units = nn.ModuleList(residual_units)

        # The kernel is normally twice the stride; stride 1 is special-cased to a kernel of 3 instead of 2.
        kernel_size = 2 * stride if stride != 1 else 3
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            bias=True,
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Run the residual units in order, then the projection convolution."""
        for res_unit in self.res_units:
            hidden_state = res_unit(hidden_state)
        return self.conv(hidden_state)
class SemanticEncoder(nn.Module):
    """
    Input convolution followed by one `XcodecSemanticEncoderBlock` per entry of `config.strides`.

    Raises:
        ValueError: if `config.strides` and `config.channel_ratios` differ in length.
    """

    def __init__(self, config):
        super().__init__()
        strides = config.strides
        ratios = config.channel_ratios
        if len(strides) != len(ratios):
            raise ValueError("Number of strides must match the number of channel_ratios.")

        hidden_size = config.semantic_hidden_size
        self.conv = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size=config.kernel_size,
            stride=1,
            padding=config.kernel_size // 2,
            bias=False,
        )

        # Each block consumes the channel count produced by the previous one.
        blocks = []
        current_channels = hidden_size
        for stride, ratio in zip(strides, ratios):
            next_channels = int(hidden_size * ratio)
            blocks.append(XcodecSemanticEncoderBlock(config, current_channels, next_channels, stride))
            current_channels = next_channels
        self.conv_blocks = nn.ModuleList(blocks)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the input convolution, then every encoder block in order."""
        hidden_state = self.conv(hidden_state)
        for conv_block in self.conv_blocks:
            hidden_state = conv_block(hidden_state)
        return hidden_state
class SemanticDecoderBlock(nn.Module):
    """Projection layer (plain conv for stride 1, transposed conv otherwise) followed by residual units."""

    def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int):
        super().__init__()
        if stride != 1:
            self.conv = nn.ConvTranspose1d(
                in_channels,
                out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=(stride + 1) // 2,
                # Odd strides need one extra output sample.
                output_padding=stride % 2,
                bias=False,
            )
        else:
            self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)

        residual_units = []
        for dilation in config.block_dilations:
            residual_units.append(XcodecResidualUnit(config, out_channels, out_channels, dilation))
        self.res_units = nn.ModuleList(residual_units)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the projection layer, then the residual units in order."""
        hidden_state = self.conv(hidden_state)
        for res_unit in self.res_units:
            hidden_state = res_unit(hidden_state)
        return hidden_state
class SemanticDecoder(nn.Module):
    """Input convolution, one `SemanticDecoderBlock` per entry of `config.strides`, then an output convolution."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.semantic_hidden_size
        kernel_size = config.kernel_size
        ratios = config.channel_ratios

        self.conv1 = nn.Conv1d(
            hidden_size,
            int(hidden_size * ratios[0]),
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=False,
        )

        last_ratio_index = len(ratios) - 1
        blocks = []
        for index, stride in enumerate(config.strides):
            block_in = int(hidden_size * ratios[index])
            # Every block feeds the next ratio; the block at the last ratio maps back to `semantic_hidden_size`.
            if index >= last_ratio_index:
                block_out = hidden_size
            else:
                block_out = int(hidden_size * ratios[index + 1])
            blocks.append(SemanticDecoderBlock(config, block_in, block_out, stride))
        self.conv_blocks = nn.ModuleList(blocks)

        self.conv2 = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=False,
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the input convolution, every decoder block in order, then the output convolution."""
        hidden_state = self.conv1(hidden_state)
        for conv_block in self.conv_blocks:
            hidden_state = conv_block(hidden_state)
        return self.conv2(hidden_state)
class XcodecEuclideanCodebook(nn.Module):
    """Codebook that maps vectors to the index of their nearest entry under Euclidean distance."""

    def __init__(self, config):
        super().__init__()
        num_codes = config.codebook_size
        self.codebook_size = num_codes

        zero_table = torch.zeros(num_codes, config.codebook_dim)
        # Registration order is kept stable because it determines the `state_dict` key order.
        self.register_buffer("inited", torch.Tensor([True]))
        self.register_buffer("cluster_size", torch.zeros(num_codes))
        self.register_buffer("embed", zero_table)
        self.register_buffer("embed_avg", zero_table.clone())

    # Adapted from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.quantize
    def quantize(self, hidden_states):
        """Return, for each row of the 2D `hidden_states`, the index of the closest codebook entry."""
        codebook_t = self.embed.t()
        states_sq = hidden_states.pow(2).sum(1, keepdim=True)
        cross_term = 2 * hidden_states @ codebook_t
        codebook_sq = codebook_t.pow(2).sum(0, keepdim=True)
        # Negated squared distance ||x||^2 - 2 x.e + ||e||^2, so the maximum marks the nearest entry.
        negative_distance = -(states_sq - cross_term + codebook_sq)
        return negative_distance.max(dim=-1).indices

    def encode(self, hidden_states):
        """Quantize along the last dimension and return indices shaped like the leading dimensions."""
        original_shape = hidden_states.shape
        flattened = hidden_states.reshape((-1, original_shape[-1]))
        return self.quantize(flattened).view(*original_shape[:-1])

    def decode(self, embed_ind):
        """Look up the codebook vectors for `embed_ind` (moved to the codebook's device first)."""
        return F.embedding(embed_ind.to(self.embed.device), self.embed)
class XcodecVectorQuantization(nn.Module):
    """
    Single-codebook vector quantizer. Only Euclidean distance is supported at the moment.
    """

    def __init__(self, config: XcodecConfig):
        super().__init__()
        self.codebook = XcodecEuclideanCodebook(config)

    # Adapted from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.encode
    def encode(self, hidden_states):
        """Swap the last two axes of the 3D input so the codebook quantizes along the last one, then encode."""
        channels_last = hidden_states.permute(0, 2, 1)
        return self.codebook.encode(channels_last)

    # Adapted from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.decode
    def decode(self, embed_ind):
        """Look up the codebook vectors for `embed_ind` and swap the last two axes back."""
        return self.codebook.decode(embed_ind).permute(0, 2, 1)
class XcodecResidualVectorQuantization(nn.Module):
    """
    Residual vector quantizer built from `config.num_quantizers` stacked codebooks.
    Follows Algorithm 1 in https://huggingface.co/papers/2107.03312
    """

    def __init__(self, config: XcodecConfig):
        super().__init__()
        self.quantizers = nn.ModuleList([XcodecVectorQuantization(config) for _ in range(config.num_quantizers)])
        self.frame_rate = config.frame_rate
        self.codebook_size = config.codebook_size
        self.num_quantizers = config.num_quantizers

    def get_bandwidth_per_quantizer(self):
        """Return the bandwidth contributed by one quantizer: bits per code times frame rate, divided by 1000."""
        bits_per_code = math.log2(self.codebook_size)
        return bits_per_code * self.frame_rate / 1000

    def get_num_quantizers_for_bandwidth(self, bandwidth=None) -> int:
        """Return how many quantizers fit in `bandwidth` (at least 1), or all of them if it is unset or not positive."""
        per_quantizer = self.get_bandwidth_per_quantizer()
        if bandwidth is None or not bandwidth > 0.0:
            return self.num_quantizers
        return int(max(1, math.floor(bandwidth / per_quantizer)))

    def encode(self, embeddings: torch.Tensor, bandwidth=None) -> torch.Tensor:
        """
        Encode `embeddings` into discrete indices with residual vector quantization.

        The number of codebooks used is derived from `bandwidth`. Each codebook quantizes what the previous ones
        left over and contributes the indices of its nearest entries in terms of Euclidean distance. The
        per-codebook indices are stacked along a new leading dimension.
        """
        active_quantizers = self.quantizers[: self.get_num_quantizers_for_bandwidth(bandwidth)]

        remainder = embeddings
        stage_indices = []
        for stage in active_quantizers:
            codes = stage.encode(remainder)
            stage_indices.append(codes)
            # Subtract this stage's reconstruction so the next stage only sees the remaining error.
            remainder = remainder - stage.decode(codes)
        return torch.stack(stage_indices)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Sum the per-quantizer reconstructions of `codes` (indexed by quantizer along the first dimension)."""
        quantized_out = torch.tensor(0.0, device=codes.device)
        for stage_index, stage_codes in enumerate(codes):
            stage_output = self.quantizers[stage_index].decode(stage_codes)
            quantized_out = quantized_out + stage_output.to(codes.device)
        return quantized_out
@auto_docstring
class XcodecPreTrainedModel(PreTrainedAudioTokenizerBase):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = XcodecConfig
    base_model_prefix = "xcodec"
    main_input_name = "input_values"
    input_modalities = "audio"
    _no_split_modules = ["XcodecResidualVectorQuantization"]

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights of `module` according to its type."""
        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        elif isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight)
            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                init.uniform_(module.bias, a=-k, b=k)
        elif module.__class__.__name__ == "Snake1d":
            init.ones_(module.alpha)
        elif isinstance(module, nn.ConvTranspose1d):
            module.reset_parameters()
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, XcodecModel):
            # The conv1d are not handled correctly, as `self.acoustic_encoder/decoder` are initialized from a
            # PreTrainedModel, but then only the submodules are used (which are not PreTrainedModels...)
            # -> here we reinit them as in DacModel
            for acoustic_part in (module.acoustic_encoder, module.acoustic_decoder):
                for submodule in acoustic_part.modules():
                    if isinstance(submodule, nn.Conv1d):
                        init.trunc_normal_(submodule.weight, std=0.02)
                        # Guard against bias-less convolutions, as the generic Conv1d branch above does.
                        if submodule.bias is not None:
                            init.constant_(submodule.bias, 0)
        elif isinstance(module, XcodecEuclideanCodebook):
            init.copy_(module.inited, torch.Tensor([True]))
            init.zeros_(module.cluster_size)
            init.zeros_(module.embed)
            init.zeros_(module.embed_avg)

    def apply_weight_norm(self):
        """Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied."""
        weight_norm = torch.nn.utils.parametrizations.weight_norm

        weight_norm(self.acoustic_encoder.conv1)
        weight_norm(self.acoustic_encoder.conv2)

        for block in self.acoustic_encoder.block:
            weight_norm(block.conv1)
            for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3):
                weight_norm(res_unit.conv1)
                weight_norm(res_unit.conv2)

        weight_norm(self.acoustic_decoder.conv1, name="weight")
        weight_norm(self.acoustic_decoder.conv2, name="weight")

        for block in self.acoustic_decoder.block:
            weight_norm(block.conv_t1, name="weight")
            for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3):
                weight_norm(res_unit.conv1, name="weight")
                weight_norm(res_unit.conv2, name="weight")

    def remove_weight_norm(self):
        """Remove the weight norm from the acoustic encoder and decoder."""
        for module in (self.acoustic_encoder, self.acoustic_decoder):
            for m in module.modules():
                # Best effort on purpose: most submodules carry no weight norm at all, so the "nothing to
                # remove" errors from the function-style API are expected and ignored.
                try:
                    torch.nn.utils.remove_weight_norm(m, name="weight")
                except (ValueError, AttributeError):
                    pass
                # Weight norm applied through `torch.nn.utils.parametrizations` is removed separately,
                # keeping the current (parametrized) weight value.
                if hasattr(m, "parametrizations") and "weight" in m.parametrizations:
                    torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)

    @staticmethod
    def _collect_conv1d_layers(module: nn.Module) -> list:
        """Return `module` (if it is a `nn.Conv1d`) followed by the Conv1d layers of its children, depth first."""
        found = [module] if isinstance(module, nn.Conv1d) else []
        for child in module.children():
            found.extend(XcodecPreTrainedModel._collect_conv1d_layers(child))
        return found

    def _get_conv1d_layers(self, module):
        """
        Recursively iterate to fetch all Conv1d layers of `module`, returned as a tuple.

        Results are memoized per model instance. A `functools.lru_cache` on the method itself would store
        `(self, module)` in a cache owned by the class, which keeps every model that ever called it (and its
        weights) alive after the model is otherwise dropped.
        """
        # Written through `__dict__` so the lookup never goes through `nn.Module.__getattr__`.
        cache = self.__dict__.setdefault("_conv1d_layers_cache", {})
        layers = cache.get(module)
        if layers is None:
            layers = tuple(self._collect_conv1d_layers(module))
            cache[module] = layers
        return layers

    def _get_conv1d_output_lengths(self, input_length, module=None):
        """
        For a given module, compute the output length that would be obtained after all Conv1d layers.

        `module` defaults to the model itself. The layers are applied to `input_length` in the order returned by
        `_get_conv1d_layers`.
        """
        if module is None:
            module = self

        conv1d_layers = self._get_conv1d_layers(module)

        for layer in conv1d_layers:
            input_length = conv1d_output_length(layer, input_length)

        return input_length
@auto_docstring(custom_intro="""The Xcodec neural audio codec model.""")
class XcodecModel(XcodecPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Number of samples padded on each side of the waveform where padding is applied.
        self.pad = config.hop_length // 2

        # Only the encoder and decoder submodules of the acoustic model are kept.
        acoustic_model = AutoModel.from_config(config.acoustic_model_config)
        self.acoustic_encoder = acoustic_model.encoder
        self.acoustic_decoder = acoustic_model.decoder
        self._adjust_dac_decoder(self.acoustic_decoder)

        self.encoder_semantic = SemanticEncoder(config)
        self.decoder_semantic = SemanticDecoder(config)
        self.semantic_model = AutoModel.from_config(config.semantic_model_config).eval()

        self.fc = nn.Linear(config.hidden_size, config.hidden_size)
        self.fc1 = nn.Linear(config.hidden_size, config.semantic_model_config.hidden_size)
        self.fc2 = nn.Linear(config.hidden_size, config.acoustic_model_config.hidden_size)
        self.quantizer = XcodecResidualVectorQuantization(config)

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _adjust_dac_decoder(decoder: nn.Module):
        r"""
        DAC implemented in Xcodec is slightly different from the HF version.
        DAC in Xcodec adjusts the output padding in every ConvTranspose1d in the decoder and removes
        the final `nn.Tanh` activation function.
        """
        for module in decoder.modules():
            if isinstance(module, nn.ConvTranspose1d):
                stride = module.stride[0] if isinstance(module.stride, tuple) else module.stride
                module.output_padding = (stride % 2,)

        if hasattr(decoder, "tanh") and isinstance(decoder.tanh, nn.Tanh):
            decoder.tanh = nn.Identity()

    def _extract_semantic_features(self, input_values: torch.FloatTensor) -> torch.FloatTensor:
        """
        Run the semantic model on channel 0 of `input_values` (padded by `self.pad` samples on both sides) without
        tracking gradients, and return the mean over all hidden states it reports.
        """
        input_values = input_values[:, 0, :]
        input_values = F.pad(input_values, (self.pad, self.pad))
        with torch.no_grad():
            outputs = self.semantic_model(input_values, output_hidden_states=True)
            hidden_states = outputs.hidden_states

        stacked = torch.stack(hidden_states, dim=1)
        return stacked.mean(dim=1)

    @auto_docstring
    def encode(
        self,
        input_values: torch.Tensor,
        bandwidth: float | None = None,
        return_dict: bool | None = None,
    ) -> torch.Tensor | XcodecEncoderOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
            Float values of the input audio waveform. Only mono audio (`channels == 1`) is accepted.
        bandwidth (`float`, *optional*):
            The target bandwidth in kbps. Must be one of the values in `config.target_bandwidths`.
            Defaults to the last entry of `config.target_bandwidths`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`].

        Returns:
            `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        channels = input_values.shape[1]
        if channels != 1:
            raise ValueError(f"Audio must be mono, but got {channels}")

        if bandwidth is None:
            bandwidth = self.config.target_bandwidths[-1]
        elif bandwidth not in self.config.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
            )

        e_semantic_input = self._extract_semantic_features(input_values).detach()
        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))

        # original codebase infer to get the output length, but we can directly infer it
        # from the model and know whether we should pad
        if self._get_conv1d_output_lengths(input_values.shape[2], self.acoustic_encoder) != e_semantic.shape[2]:
            e_acoustic = self.acoustic_encoder(F.pad(input_values, (self.pad, self.pad)))
        else:
            e_acoustic = self.acoustic_encoder(input_values)

        # Concatenate acoustic and semantic features along dim 1, then project them with `self.fc`.
        embeddings = torch.cat([e_acoustic.to(e_semantic.device), e_semantic], dim=1)
        embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2)

        # The quantizer stacks codes with the quantizer index first; swap so the batch comes first.
        audio_codes = self.quantizer.encode(embeddings, bandwidth)
        audio_codes = audio_codes.transpose(0, 1)

        if not return_dict:
            return audio_codes

        return XcodecEncoderOutput(audio_codes)

    @auto_docstring
    def decode(
        self,
        audio_codes: torch.Tensor,
        return_dict: bool | None = None,
    ) -> torch.Tensor | XcodecDecoderOutput:
        r"""
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
            Discrete code indices computed using `model.encode`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`].

        Returns:
            Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of
            Xcodec.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # The quantizer iterates over its first dimension, so the quantizer index has to come first.
        audio_codes = audio_codes.transpose(0, 1)
        quantized = self.quantizer.decode(audio_codes)
        quantized_acoustic = self.fc2(quantized.transpose(1, 2)).transpose(1, 2)
        audio_values = self.acoustic_decoder(quantized_acoustic)

        if not return_dict:
            return audio_values

        return XcodecDecoderOutput(audio_values)

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor,
        audio_codes: torch.Tensor | None = None,
        bandwidth: float | None = None,
        return_dict: bool | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor] | XcodecOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
            The raw float values of the input audio waveform.
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`. When not provided, they are computed from
            `input_values`.
        bandwidth (`float`, *optional*):
            Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the last entry of
            `config.target_bandwidths`.
        return_dict (`bool`, *optional*):
            Whether to return a [`XcodecOutput`] instead of a plain tuple.

        Returns:
            `XcodecOutput` or tuple `(audio_codes, audio_values)`:
            - `audio_codes` of shape `(batch_size, num_quantizers, codes_length)`: the quantized discrete codes.
            - `audio_values` of shape `(batch_size, channels, num_samples)`: the reconstructed audio waveform given the codes.

        Example:

        ```python
        >>> from datasets import Audio, load_dataset
        >>> from transformers import AutoFeatureExtractor, XcodecModel

        >>> model_id = "hf-audio/xcodec-hubert-librispeech"
        >>> model = XcodecModel.from_pretrained(model_id)
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

        >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
        >>> audio_sample = dataset[0]['audio']['array']

        >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> audio_codes = outputs.audio_codes
        >>> audio_values = outputs.audio_values
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        length = input_values.shape[-1]

        if audio_codes is None:
            audio_codes = self.encode(input_values, bandwidth, return_dict=False)

        # Always ask `decode` for the bare tensor. Indexing its result with `[0]` is only correct for the
        # `ModelOutput` form; on the bare tensor it would select batch element 0 and drop the batch dimension.
        audio_values = self.decode(audio_codes, return_dict=False)
        # Trim the reconstruction to the length of the input waveform.
        audio_values = audio_values[..., :length]

        if not return_dict:
            return (audio_codes, audio_values)

        return XcodecOutput(audio_codes=audio_codes, audio_values=audio_values)
# Public API of this module: only the model classes are exported; the helper layers, quantizers and
# output dataclasses defined above are not listed.
__all__ = ["XcodecModel", "XcodecPreTrainedModel"]
|