modeling_xcodec.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Transformers Xcodec model."""
  15. import math
  16. from dataclasses import dataclass
  17. from functools import lru_cache
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from ... import initialization as init
  22. from ...audio_utils import conv1d_output_length
  23. from ...modeling_utils import PreTrainedAudioTokenizerBase
  24. from ...utils import ModelOutput, auto_docstring
  25. from ..auto import AutoModel
  26. from .configuration_xcodec import XcodecConfig
@dataclass
class XcodecOutput(ModelOutput):
    """
    Container returned by `XcodecModel.forward`, holding both the discrete codes and the reconstructed audio.

    Args:
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`.
        audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*):
            Decoded audio values obtained using the decoder part of Xcodec.
    """

    audio_codes: torch.LongTensor | None = None
    audio_values: torch.FloatTensor | None = None
@dataclass
class XcodecEncoderOutput(ModelOutput):
    """
    Container returned by `XcodecModel.encode` when a `ModelOutput` is requested.

    Args:
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`.
    """

    audio_codes: torch.LongTensor | None = None
@dataclass
class XcodecDecoderOutput(ModelOutput):
    """
    Container returned by `XcodecModel.decode` when a `ModelOutput` is requested.

    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*):
            Decoded audio values obtained using the decoder part of Xcodec.
    """

    audio_values: torch.FloatTensor | None = None
  54. class XcodecResidualUnit(nn.Module):
  55. """Residual block for SemanticEncoder and SemanticDecoder used in Xcodec."""
  56. def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, dilation: int):
  57. super().__init__()
  58. self.activation = nn.ELU()
  59. padding = ((config.unit_kernel_size - 1) // 2) * dilation
  60. self.conv1 = nn.Conv1d(
  61. in_channels,
  62. out_channels,
  63. config.unit_kernel_size,
  64. stride=1,
  65. padding=padding,
  66. dilation=dilation,
  67. groups=1,
  68. bias=False,
  69. )
  70. self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, bias=False)
  71. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  72. output_tensor = self.activation(hidden_state)
  73. output_tensor = self.conv1(output_tensor)
  74. output_tensor = self.activation(output_tensor)
  75. output_tensor = self.conv2(output_tensor)
  76. return hidden_state + output_tensor
  77. class XcodecSemanticEncoderBlock(nn.Module):
  78. def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int):
  79. super().__init__()
  80. self.res_units = nn.ModuleList(
  81. [XcodecResidualUnit(config, in_channels, in_channels, dilation) for dilation in config.block_dilations]
  82. )
  83. # special case: stride=1, do not use kernel=2
  84. kernel = 3 if stride == 1 else (2 * stride)
  85. padding = (kernel - 1) // 2
  86. self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding, bias=True)
  87. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  88. for unit in self.res_units:
  89. hidden_state = unit(hidden_state)
  90. hidden_state = self.conv(hidden_state)
  91. return hidden_state
  92. class SemanticEncoder(nn.Module):
  93. def __init__(self, config):
  94. super().__init__()
  95. if len(config.strides) != len(config.channel_ratios):
  96. raise ValueError("Number of strides must match the number of channel_ratios.")
  97. self.conv = nn.Conv1d(
  98. config.semantic_hidden_size,
  99. config.semantic_hidden_size,
  100. config.kernel_size,
  101. 1,
  102. config.kernel_size // 2,
  103. bias=False,
  104. )
  105. in_channels = config.semantic_hidden_size
  106. conv_blocks = []
  107. for i, stride in enumerate(config.strides):
  108. out_channels = int(config.semantic_hidden_size * config.channel_ratios[i])
  109. conv_blocks += [XcodecSemanticEncoderBlock(config, in_channels, out_channels, stride)]
  110. in_channels = out_channels
  111. self.conv_blocks = nn.ModuleList(conv_blocks)
  112. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  113. hidden_state = self.conv(hidden_state)
  114. for block in self.conv_blocks:
  115. hidden_state = block(hidden_state)
  116. return hidden_state
  117. class SemanticDecoderBlock(nn.Module):
  118. def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int):
  119. super().__init__()
  120. if stride == 1:
  121. self.conv = nn.Conv1d(
  122. in_channels,
  123. out_channels,
  124. kernel_size=3,
  125. stride=1,
  126. padding=1,
  127. bias=True,
  128. )
  129. else:
  130. kernel_size = 2 * stride
  131. padding = (stride + 1) // 2
  132. output_padding = 1 if stride % 2 == 1 else 0
  133. self.conv = nn.ConvTranspose1d(
  134. in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=False
  135. )
  136. self.res_units = nn.ModuleList(
  137. [XcodecResidualUnit(config, out_channels, out_channels, dilation) for dilation in config.block_dilations]
  138. )
  139. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  140. hidden_state = self.conv(hidden_state)
  141. for unit in self.res_units:
  142. hidden_state = unit(hidden_state)
  143. return hidden_state
  144. class SemanticDecoder(nn.Module):
  145. def __init__(self, config):
  146. super().__init__()
  147. self.conv1 = nn.Conv1d(
  148. in_channels=config.semantic_hidden_size,
  149. out_channels=int(config.semantic_hidden_size * config.channel_ratios[0]),
  150. kernel_size=config.kernel_size,
  151. stride=1,
  152. padding=config.kernel_size // 2,
  153. bias=False,
  154. )
  155. conv_blocks = []
  156. for i, stride in enumerate(config.strides):
  157. in_channels = int(config.semantic_hidden_size * config.channel_ratios[i])
  158. if i < (len(config.channel_ratios) - 1):
  159. out_channels = int(config.semantic_hidden_size * config.channel_ratios[i + 1])
  160. else:
  161. out_channels = config.semantic_hidden_size
  162. conv_blocks += [SemanticDecoderBlock(config, in_channels, out_channels, stride)]
  163. self.conv_blocks = nn.ModuleList(conv_blocks)
  164. self.conv2 = nn.Conv1d(
  165. config.semantic_hidden_size,
  166. config.semantic_hidden_size,
  167. config.kernel_size,
  168. stride=1,
  169. padding=config.kernel_size // 2,
  170. bias=False,
  171. )
  172. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  173. hidden_state = self.conv1(hidden_state)
  174. for block in self.conv_blocks:
  175. hidden_state = block(hidden_state)
  176. hidden_state = self.conv2(hidden_state)
  177. return hidden_state
class XcodecEuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Stores a `(codebook_size, codebook_dim)` table of code vectors in the `embed` buffer. `encode` maps feature
    vectors to the index of their nearest table row (smallest squared Euclidean distance) and `decode` maps indices
    back to table rows.
    """

    def __init__(self, config):
        super().__init__()
        # All state lives in buffers rather than parameters; the table is created zero-filled.
        embed = torch.zeros(config.codebook_size, config.codebook_dim)
        self.codebook_size = config.codebook_size
        self.register_buffer("inited", torch.Tensor([True]))
        self.register_buffer("cluster_size", torch.zeros(config.codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.quantize
    def quantize(self, hidden_states):
        embed = self.embed.t()
        scaled_states = hidden_states.pow(2).sum(1, keepdim=True)
        dist = -(scaled_states - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def encode(self, hidden_states):
        """
        Return the nearest-code index for every vector along the last dimension of `hidden_states`.

        All leading dimensions are flattened before the lookup and restored afterwards, so the result has shape
        `hidden_states.shape[:-1]`.
        """
        shape = hidden_states.shape
        # Flatten to `(num_vectors, dim)` because `quantize` works on a 2D matrix.
        hidden_states = hidden_states.reshape((-1, shape[-1]))
        embed_ind = self.quantize(hidden_states)
        embed_ind = embed_ind.view(*shape[:-1])
        return embed_ind

    def decode(self, embed_ind):
        """Look up the code vectors for `embed_ind` in `self.embed` (indices are moved to the table's device first)."""
        quantized = F.embedding(embed_ind.to(self.embed.device), self.embed)
        return quantized
class XcodecVectorQuantization(nn.Module):
    """
    Vector quantization implementation. Currently supports only euclidean distance.

    Thin wrapper around a single `XcodecEuclideanCodebook`: `encode` swaps the last two axes of its 3D input before
    the codebook lookup, and `decode` swaps them back after the codebook returns the code vectors.
    """

    def __init__(self, config: XcodecConfig):
        super().__init__()
        # The only state of this layer is its codebook.
        self.codebook = XcodecEuclideanCodebook(config)

    # Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.encode
    def encode(self, hidden_states):
        hidden_states = hidden_states.permute(0, 2, 1)
        embed_in = self.codebook.encode(hidden_states)
        return embed_in

    # Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.decode
    def decode(self, embed_ind):
        quantize = self.codebook.decode(embed_ind)
        quantize = quantize.permute(0, 2, 1)
        return quantize
  221. class XcodecResidualVectorQuantization(nn.Module):
  222. """
  223. Residual vector quantization implementation. Follows Algorithm 1 in https://huggingface.co/papers/2107.03312
  224. """
  225. def __init__(self, config: XcodecConfig):
  226. super().__init__()
  227. self.quantizers = nn.ModuleList([XcodecVectorQuantization(config) for _ in range(config.num_quantizers)])
  228. self.frame_rate = config.frame_rate
  229. self.codebook_size = config.codebook_size
  230. self.num_quantizers = config.num_quantizers
  231. def get_bandwidth_per_quantizer(self):
  232. """Return bandwidth per quantizer."""
  233. return math.log2(self.codebook_size) * self.frame_rate / 1000
  234. def get_num_quantizers_for_bandwidth(self, bandwidth=None) -> int:
  235. """Return num_quantizers based on specified target bandwidth."""
  236. bw_per_q = self.get_bandwidth_per_quantizer()
  237. num_quantizers = self.num_quantizers
  238. if bandwidth is not None and bandwidth > 0.0:
  239. num_quantizers = int(max(1, math.floor(bandwidth / bw_per_q)))
  240. return num_quantizers
  241. def encode(self, embeddings: torch.Tensor, bandwidth=None) -> torch.Tensor:
  242. """
  243. Encode the input tensor into discrete indices using RVQ, with the number of quantizers selected based on the given bandwidth.
  244. Each quantizer /codebook residually quantizes the input and returns the nearest indices in terms of Euclidian distance.
  245. """
  246. num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
  247. residual = embeddings
  248. all_indices = []
  249. for quantizer in self.quantizers[:num_quantizers]:
  250. indices = quantizer.encode(residual)
  251. quantized = quantizer.decode(indices)
  252. residual = residual - quantized
  253. all_indices.append(indices)
  254. out_indices = torch.stack(all_indices)
  255. return out_indices
  256. def decode(self, codes: torch.Tensor) -> torch.Tensor:
  257. """Decode the given codes to their quantized representation."""
  258. quantized_out = torch.tensor(0.0, device=codes.device)
  259. for i, indices in enumerate(codes):
  260. quantizer = self.quantizers[i]
  261. quantized = quantizer.decode(indices)
  262. quantized_out = quantized_out + quantized.to(codes.device)
  263. return quantized_out
  264. @auto_docstring
  265. class XcodecPreTrainedModel(PreTrainedAudioTokenizerBase):
  266. """
  267. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  268. models.
  269. """
  270. config_class = XcodecConfig
  271. base_model_prefix = "xcodec"
  272. main_input_name = "input_values"
  273. input_modalities = "audio"
  274. _no_split_modules = ["XcodecResidualVectorQuantization"]
  275. @torch.no_grad()
  276. def _init_weights(self, module):
  277. """Initialize the weights"""
  278. if isinstance(module, nn.Linear):
  279. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  280. if module.bias is not None:
  281. init.zeros_(module.bias)
  282. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  283. init.zeros_(module.bias)
  284. init.ones_(module.weight)
  285. elif isinstance(module, nn.Conv1d):
  286. init.kaiming_normal_(module.weight)
  287. if module.bias is not None:
  288. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  289. init.uniform_(module.bias, a=-k, b=k)
  290. elif module.__class__.__name__ == "Snake1d":
  291. init.ones_(module.alpha)
  292. elif isinstance(module, nn.ConvTranspose1d):
  293. module.reset_parameters()
  294. elif isinstance(module, nn.Embedding):
  295. init.normal_(module.weight, mean=0.0, std=0.02)
  296. elif isinstance(module, XcodecModel):
  297. # The conv1d are not handled correctly, as `self.acoustic_encoder/decoder` are initialized from a PreTrainedModel,
  298. # but then only the submodules are used (which are not PreTrainedModels...) -> here we reinit them as in DacModel
  299. for submodule in module.acoustic_encoder.modules():
  300. if isinstance(submodule, nn.Conv1d):
  301. init.trunc_normal_(submodule.weight, std=0.02)
  302. init.constant_(submodule.bias, 0)
  303. for submodule in module.acoustic_decoder.modules():
  304. if isinstance(submodule, nn.Conv1d):
  305. init.trunc_normal_(submodule.weight, std=0.02)
  306. init.constant_(submodule.bias, 0)
  307. elif isinstance(module, XcodecEuclideanCodebook):
  308. init.copy_(module.inited, torch.Tensor([True]))
  309. init.zeros_(module.cluster_size)
  310. init.zeros_(module.embed)
  311. init.zeros_(module.embed_avg)
  312. def apply_weight_norm(self):
  313. """Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied."""
  314. weight_norm = torch.nn.utils.parametrizations.weight_norm
  315. weight_norm(self.acoustic_encoder.conv1)
  316. weight_norm(self.acoustic_encoder.conv2)
  317. for block in self.acoustic_encoder.block:
  318. weight_norm(block.conv1)
  319. for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3):
  320. weight_norm(res_unit.conv1)
  321. weight_norm(res_unit.conv2)
  322. weight_norm(self.acoustic_decoder.conv1, name="weight")
  323. weight_norm(self.acoustic_decoder.conv2, name="weight")
  324. for block in self.acoustic_decoder.block:
  325. weight_norm(block.conv_t1, name="weight")
  326. for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3):
  327. weight_norm(res_unit.conv1, name="weight")
  328. weight_norm(res_unit.conv2, name="weight")
  329. def remove_weight_norm(self):
  330. """Remove the weight norm from the acoustic encoder and decoder."""
  331. for module in (self.acoustic_encoder, self.acoustic_decoder):
  332. for m in module.modules():
  333. try:
  334. torch.nn.utils.remove_weight_norm(m, name="weight")
  335. except (ValueError, AttributeError):
  336. pass
  337. if hasattr(m, "parametrizations") and "weight" in m.parametrizations:
  338. torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
  339. @lru_cache
  340. def _get_conv1d_layers(self, module):
  341. """
  342. Recursively iterate to fetch all Conv1d layers.
  343. """
  344. def get_conv1d_layers_recursive(module: nn.Module):
  345. params_list = []
  346. if isinstance(module, nn.Conv1d):
  347. params_list.append(module)
  348. # Recursively check all child modules
  349. for child in module.children():
  350. params_list.extend(get_conv1d_layers_recursive(child))
  351. return params_list
  352. return tuple(get_conv1d_layers_recursive(module))
  353. def _get_conv1d_output_lengths(self, input_length, module=None):
  354. """
  355. For a given module, compute the output length that would be obtained after all Conv1d layers.
  356. """
  357. if module is None:
  358. module = self
  359. conv1d_layers = self._get_conv1d_layers(module)
  360. for layer in conv1d_layers:
  361. input_length = conv1d_output_length(layer, input_length)
  362. return input_length
  363. @auto_docstring(custom_intro="""The Xcodec neural audio codec model.""")
  364. class XcodecModel(XcodecPreTrainedModel):
  365. def __init__(self, config):
  366. super().__init__(config)
  367. self.config = config
  368. self.pad = config.hop_length // 2
  369. acoustic_model = AutoModel.from_config(config.acoustic_model_config)
  370. self.acoustic_encoder = acoustic_model.encoder
  371. self.acoustic_decoder = acoustic_model.decoder
  372. self._adjust_dac_decoder(self.acoustic_decoder)
  373. self.encoder_semantic = SemanticEncoder(config)
  374. self.decoder_semantic = SemanticDecoder(config)
  375. self.semantic_model = AutoModel.from_config(config.semantic_model_config).eval()
  376. self.fc = nn.Linear(config.hidden_size, config.hidden_size)
  377. self.fc1 = nn.Linear(config.hidden_size, config.semantic_model_config.hidden_size)
  378. self.fc2 = nn.Linear(config.hidden_size, config.acoustic_model_config.hidden_size)
  379. self.quantizer = XcodecResidualVectorQuantization(config)
  380. # Initialize weights and apply final processing
  381. self.post_init()
  382. @staticmethod
  383. def _adjust_dac_decoder(decoder: nn.Module):
  384. r"""
  385. DAC implemented in Xcodec is slightly different from the HF version.
  386. DAC in Xcodec adjusts the output padding in every ConvTranspose1d in the decoder and removes
  387. the final `nn.Tanh` activation function.
  388. """
  389. for module in decoder.modules():
  390. if isinstance(module, nn.ConvTranspose1d):
  391. stride = module.stride[0] if isinstance(module.stride, tuple) else module.stride
  392. module.output_padding = (stride % 2,)
  393. if hasattr(decoder, "tanh") and isinstance(decoder.tanh, nn.Tanh):
  394. decoder.tanh = nn.Identity()
  395. def _extract_semantic_features(self, input_values: torch.FloatTensor) -> torch.FloatTensor:
  396. input_values = input_values[:, 0, :]
  397. input_values = F.pad(input_values, (self.pad, self.pad))
  398. with torch.no_grad():
  399. outputs = self.semantic_model(input_values, output_hidden_states=True)
  400. hidden_states = outputs.hidden_states
  401. stacked = torch.stack(hidden_states, dim=1)
  402. return stacked.mean(dim=1)
  403. @auto_docstring
  404. def encode(
  405. self,
  406. input_values: torch.Tensor,
  407. bandwidth: float | None = None,
  408. return_dict: bool | None = None,
  409. ) -> torch.Tensor | XcodecEncoderOutput:
  410. r"""
  411. input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
  412. Float values of the input audio waveform.
  413. bandwidth (`float`, *optional*):
  414. The target bandwidth in (kbps) supports only values in `config.target_bandwidths`.
  415. Defaults to the highest available bandwidth `4.0` kbps.
  416. return_dict (`bool`, *optional*):
  417. Whether or not to return a [`~utils.ModelOutput`].
  418. Returns:
  419. `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes.
  420. """
  421. return_dict = return_dict if return_dict is not None else self.config.return_dict
  422. channels = input_values.shape[1]
  423. if channels != 1:
  424. raise ValueError(f"Audio must be mono, but got {channels}")
  425. if bandwidth is None:
  426. bandwidth = self.config.target_bandwidths[-1]
  427. elif bandwidth not in self.config.target_bandwidths:
  428. raise ValueError(
  429. f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
  430. )
  431. e_semantic_input = self._extract_semantic_features(input_values).detach()
  432. e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
  433. # original codebase infer to get the output length, but we can directly infer it
  434. # from the model and know whether we should pad
  435. if self._get_conv1d_output_lengths(input_values.shape[2], self.acoustic_encoder) != e_semantic.shape[2]:
  436. e_acoustic = self.acoustic_encoder(F.pad(input_values, (self.pad, self.pad)))
  437. else:
  438. e_acoustic = self.acoustic_encoder(input_values)
  439. embeddings = torch.cat([e_acoustic.to(e_semantic.device), e_semantic], dim=1)
  440. embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2)
  441. audio_codes = self.quantizer.encode(embeddings, bandwidth)
  442. audio_codes = audio_codes.transpose(0, 1)
  443. if not return_dict:
  444. return audio_codes
  445. return XcodecEncoderOutput(audio_codes)
  446. @auto_docstring
  447. def decode(
  448. self,
  449. audio_codes: torch.Tensor,
  450. return_dict: bool | None = None,
  451. ) -> torch.Tensor | XcodecDecoderOutput:
  452. r"""
  453. audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
  454. Discrete code indices computed using `model.encode`.
  455. return_dict (`bool`, *optional*):
  456. Whether or not to return a [`~utils.ModelOutput`]
  457. Returns:
  458. Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of
  459. Xcodec.
  460. """
  461. return_dict = return_dict if return_dict is not None else self.config.return_dict
  462. audio_codes = audio_codes.transpose(0, 1)
  463. quantized = self.quantizer.decode(audio_codes)
  464. quantized_acoustic = self.fc2(quantized.transpose(1, 2)).transpose(1, 2)
  465. audio_values = self.acoustic_decoder(quantized_acoustic)
  466. if not return_dict:
  467. return audio_values
  468. return XcodecDecoderOutput(audio_values)
  469. @auto_docstring
  470. def forward(
  471. self,
  472. input_values: torch.Tensor,
  473. audio_codes: torch.Tensor | None = None,
  474. bandwidth: float | None = None,
  475. return_dict: bool | None = None,
  476. ) -> tuple[torch.Tensor, torch.Tensor] | XcodecOutput:
  477. r"""
  478. input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
  479. The raw float values of the input audio waveform.
  480. audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`:
  481. Discrete code indices computed using `model.encode`.
  482. bandwidth (`float`, *optional*):
  483. Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
  484. bandwidth (`float`, *optional*):
  485. Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
  486. return_dict (`bool`, *optional*):
  487. Whether to return a [`XcodecOutput`] instead of a plain tuple.
  488. Returns:
  489. `XcodecOutput` or tuple `(audio_codes, audio_values)`:
  490. - `audio_codes` of shape `(batch_size, num_quantizers, codes_length)`: the quantized discrete codes.
  491. - `audio_values` of shape `(batch_size, channels, num_samples)`: the reconstructed audio waveform given the codes.
  492. Example:
  493. ```python
  494. >>> from datasets import load_dataset
  495. >>> from transformers import AutoFeatureExtractor, XcodecModel
  496. >>> model_id = "hf-audio/xcodec-hubert-librispeech"
  497. >>> model = XcodecModel.from_pretrained(model_id)
  498. >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
  499. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  500. >>> dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
  501. >>> audio_sample = dataset[0]['audio']['array']
  502. >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")
  503. >>> outputs = model(**inputs)
  504. >>> audio_codes = outputs.audio_codes
  505. >>> audio_values = outputs.audio_values
  506. ```
  507. """
  508. return_dict = return_dict if return_dict is not None else self.config.return_dict
  509. length = input_values.shape[-1]
  510. if audio_codes is None:
  511. audio_codes = self.encode(input_values, bandwidth, return_dict=False)
  512. audio_values = self.decode(audio_codes, return_dict=return_dict)[0][..., :length]
  513. if not return_dict:
  514. return (audio_codes, audio_values)
  515. return XcodecOutput(audio_codes=audio_codes, audio_values=audio_values)
# Names exported by `from ... import *`; the helper modules and output dataclasses above are not exported.
__all__ = ["XcodecModel", "XcodecPreTrainedModel"]