modeling_dac.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689
  1. # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Transformers DAC model."""
  15. import math
  16. from dataclasses import dataclass
  17. import numpy as np
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from ... import initialization as init
  22. from ...modeling_utils import PreTrainedAudioTokenizerBase
  23. from ...utils import ModelOutput, auto_docstring
  24. from .configuration_dac import DacConfig
  25. @dataclass
  26. @auto_docstring
  27. class DacOutput(ModelOutput):
  28. r"""
  29. loss (`torch.Tensor`):
  30. Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
  31. audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
  32. Reconstructed audio data.
  33. quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  34. Quantized continuous representation of input.
  35. audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
  36. Codebook indices for each codebook (quantized discrete representation of input).
  37. projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
  38. Projected latents (continuous representation of input before quantization).
  39. """
  40. loss: torch.FloatTensor | None = None
  41. audio_values: torch.FloatTensor | None = None
  42. quantized_representation: torch.FloatTensor | None = None
  43. audio_codes: torch.LongTensor | None = None
  44. projected_latents: torch.FloatTensor | None = None
  45. @dataclass
  46. @auto_docstring
  47. class DacEncoderOutput(ModelOutput):
  48. r"""
  49. loss (`torch.Tensor`):
  50. Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
  51. quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
  52. Quantized continuous representation of input.
  53. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
  54. Codebook indices for each codebook (quantized discrete representation of input).
  55. projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
  56. Projected latents (continuous representation of input before quantization).
  57. """
  58. loss: torch.FloatTensor | None = None
  59. quantized_representation: torch.FloatTensor | None = None
  60. audio_codes: torch.FloatTensor | None = None
  61. projected_latents: torch.FloatTensor | None = None
# Output container returned by `DacModel.decode`; it wraps only the reconstructed waveform.
# NOTE: the "Copied from" marker below indicates this class body mirrors
# `EncodecDecoderOutput` (with the listed renames), so its body is intentionally left untouched.
@dataclass
@auto_docstring
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
    r"""
    audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
        Decoded audio values, obtained using the decoder part of Dac.
    """

    audio_values: torch.FloatTensor | None = None
  71. class Snake1d(nn.Module):
  72. """
  73. A 1-dimensional Snake activation function module.
  74. """
  75. def __init__(self, hidden_dim):
  76. super().__init__()
  77. self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))
  78. def forward(self, hidden_states):
  79. shape = hidden_states.shape
  80. hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
  81. hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
  82. hidden_states = hidden_states.reshape(shape)
  83. return hidden_states
  84. class DacVectorQuantize(nn.Module):
  85. """
  86. Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)
  87. Additionally uses following tricks from improved VQGAN
  88. (https://huggingface.co/papers/2110.04627):
  89. 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
  90. for improved codebook usage
  91. 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
  92. improves training stability
  93. """
  94. def __init__(self, config: DacConfig):
  95. super().__init__()
  96. self.codebook_dim = config.codebook_dim
  97. self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
  98. self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
  99. self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)
  100. def forward(self, hidden_state):
  101. """
  102. Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.
  103. Args:
  104. hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
  105. Input tensor.
  106. Returns:
  107. quantized_representation (`torch.Tensor`of shape `(batch_size, dimension, time_steps)`):
  108. Quantized continuous representation of input.
  109. commitment_loss (`torch.FloatTensor`of shape `(1)`):
  110. Commitment loss to train encoder to predict vectors closer to codebook entries.
  111. codebook_loss (`torch.FloatTensor`of shape `(1)`):
  112. Codebook loss to update the codebook.
  113. audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
  114. Codebook indices for each codebook, quantized discrete representation of input.
  115. projected_latents (torch.FloatTensor of shape `(batch_size, num_codebooks * dimension, time_steps)`):
  116. Projected latents (continuous representation of input before quantization).
  117. """
  118. projected_latents = self.in_proj(hidden_state)
  119. quantized_representation, audio_codes = self.decode_latents(projected_latents)
  120. commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
  121. codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
  122. # noop in forward pass, straight-through gradient estimator in backward pass
  123. quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
  124. quantized_representation = self.out_proj(quantized_representation)
  125. return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents
  126. def decode_latents(self, hidden_states):
  127. batch_size, hidden_dim, sequence_length = hidden_states.shape
  128. encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
  129. codebook = self.codebook.weight # codebook: (N x D)
  130. # L2 normalize encodings and codebook (ViT-VQGAN)
  131. encodings = F.normalize(encodings)
  132. codebook = F.normalize(codebook)
  133. # Compute euclidean distance with codebook
  134. l2_norm = encodings.pow(2).sum(1, keepdim=True)
  135. dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()
  136. indices = dist.max(1)[1]
  137. indices = indices.reshape(hidden_states.size(0), -1)
  138. quantized_representation = self.codebook(indices).transpose(1, 2)
  139. return quantized_representation, indices
  140. class DacResidualUnit(nn.Module):
  141. """
  142. A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
  143. """
  144. def __init__(self, dimension: int = 16, dilation: int = 1):
  145. super().__init__()
  146. pad = ((7 - 1) * dilation) // 2
  147. self.snake1 = Snake1d(dimension)
  148. self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
  149. self.snake2 = Snake1d(dimension)
  150. self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
  151. def forward(self, hidden_state):
  152. """
  153. Forward pass through the residual unit.
  154. Args:
  155. hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
  156. Input tensor .
  157. Returns:
  158. output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
  159. Input tensor after passing through the residual unit.
  160. """
  161. output_tensor = hidden_state
  162. output_tensor = self.conv1(self.snake1(output_tensor))
  163. output_tensor = self.conv2(self.snake2(output_tensor))
  164. padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
  165. if padding > 0:
  166. hidden_state = hidden_state[..., padding:-padding]
  167. output_tensor = hidden_state + output_tensor
  168. return output_tensor
  169. class DacEncoderBlock(nn.Module):
  170. """Encoder block used in DAC encoder."""
  171. def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
  172. super().__init__()
  173. dimension = config.encoder_hidden_size * 2**stride_index
  174. self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
  175. self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
  176. self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
  177. self.snake1 = Snake1d(dimension // 2)
  178. self.conv1 = nn.Conv1d(
  179. dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
  180. )
  181. def forward(self, hidden_state):
  182. hidden_state = self.res_unit1(hidden_state)
  183. hidden_state = self.res_unit2(hidden_state)
  184. hidden_state = self.snake1(self.res_unit3(hidden_state))
  185. hidden_state = self.conv1(hidden_state)
  186. return hidden_state
  187. class DacDecoderBlock(nn.Module):
  188. """Decoder block used in DAC decoder."""
  189. def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
  190. super().__init__()
  191. input_dim = config.decoder_hidden_size // 2**stride_index
  192. output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
  193. self.snake1 = Snake1d(input_dim)
  194. self.conv_t1 = nn.ConvTranspose1d(
  195. input_dim,
  196. output_dim,
  197. kernel_size=2 * stride,
  198. stride=stride,
  199. padding=math.ceil(stride / 2),
  200. )
  201. self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
  202. self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
  203. self.res_unit3 = DacResidualUnit(output_dim, dilation=9)
  204. def forward(self, hidden_state):
  205. hidden_state = self.snake1(hidden_state)
  206. hidden_state = self.conv_t1(hidden_state)
  207. hidden_state = self.res_unit1(hidden_state)
  208. hidden_state = self.res_unit2(hidden_state)
  209. hidden_state = self.res_unit3(hidden_state)
  210. return hidden_state
  211. class DacResidualVectorQuantizer(nn.Module):
  212. """
  213. ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://huggingface.co/papers/2107.03312)
  214. """
  215. def __init__(self, config: DacConfig):
  216. super().__init__()
  217. n_codebooks = config.n_codebooks
  218. quantizer_dropout = config.quantizer_dropout
  219. self.n_codebooks = n_codebooks
  220. self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
  221. self.quantizer_dropout = quantizer_dropout
  222. def forward(self, hidden_state, n_quantizers: int | None = None):
  223. """
  224. Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
  225. Args:
  226. hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  227. Input tensor to be quantized.
  228. n_quantizers (`int`, *optional*):
  229. Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
  230. this argument is ignored during training, and a random number of quantizers is used.
  231. Returns:
  232. quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  233. Quantized continuous representation of input.
  234. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
  235. Codebook indices for each codebook (quantized discrete representation of input).
  236. projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
  237. Projected latents (continuous representation of input before quantization).
  238. commitment_loss (`torch.Tensor` of shape `(1)`):
  239. Commitment loss to train the encoder to predict vectors closer to codebook entries.
  240. codebook_loss (`torch.Tensor` of shape `(1)`):
  241. Codebook loss to update the codebook.
  242. """
  243. quantized_representation = 0
  244. residual = hidden_state
  245. commitment_loss = 0
  246. codebook_loss = 0
  247. audio_codes = []
  248. projected_latents = []
  249. n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
  250. if self.training:
  251. n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
  252. dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
  253. n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
  254. n_quantizers[:n_dropout] = dropout[:n_dropout]
  255. n_quantizers = n_quantizers.to(hidden_state.device)
  256. for i, quantizer in enumerate(self.quantizers):
  257. if self.training is False and i >= n_quantizers:
  258. break
  259. quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
  260. residual
  261. )
  262. # Create mask to apply quantizer dropout
  263. mask = torch.full((hidden_state.shape[0],), i, device=hidden_state.device, dtype=torch.long) < n_quantizers
  264. quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
  265. residual = residual - quantized_representation_i
  266. # Sum losses
  267. commitment_loss += commitment_loss_i * mask
  268. codebook_loss += codebook_loss_i * mask
  269. audio_codes.append(indices_i)
  270. projected_latents.append(projected_latents_i)
  271. audio_codes = torch.stack(audio_codes, dim=1)
  272. projected_latents = torch.cat(projected_latents, dim=1)
  273. return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss
  274. def from_codes(self, audio_codes: torch.Tensor):
  275. """
  276. Reconstructs the continuous representation from quantized codes.
  277. Args:
  278. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
  279. Quantized discrete representation of input.
  280. Returns:
  281. quantized_representation (`torch.Tensor`):
  282. Quantized continuous representation of input.
  283. projected_latents (`torch.Tensor`):
  284. List of projected latents (continuous representations of input before quantization)
  285. for each codebook.
  286. audio_codes (`torch.Tensor`):
  287. Codebook indices for each codebook.
  288. """
  289. quantized_representation = 0.0
  290. projected_latents = []
  291. n_codebooks = audio_codes.shape[1]
  292. for i in range(n_codebooks):
  293. projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
  294. projected_latents.append(projected_latents_i)
  295. quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
  296. return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes
  297. def from_latents(self, latents: torch.Tensor):
  298. """Reconstructs the quantized representation from unquantized latents.
  299. Args:
  300. latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
  301. Continuous representation of input after projection.
  302. Returns:
  303. quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  304. Quantized representation of the full-projected space.
  305. quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  306. Quantized representation of the latent space (continuous representation before quantization).
  307. """
  308. quantized_representation = 0
  309. quantized_latents = []
  310. codes = []
  311. codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
  312. dims = torch.cumsum(codebook_dims_tensor, dim=0)
  313. n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
  314. for i in range(n_codebooks):
  315. hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
  316. latent_chunk = latents[:, hidden_dim_j:hidden_dim_k, :]
  317. quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latent_chunk)
  318. quantized_latents.append(quantized_latents_i)
  319. codes.append(codes_i)
  320. quantized_with_ste = latent_chunk + (quantized_latents_i - latent_chunk)
  321. quantized_representation_i = self.quantizers[i].out_proj(quantized_with_ste)
  322. quantized_representation = quantized_representation + quantized_representation_i
  323. return quantized_representation, torch.cat(quantized_latents, dim=1)
  324. class DacDecoder(nn.Module):
  325. """DAC Decoder"""
  326. def __init__(self, config: DacConfig):
  327. super().__init__()
  328. input_channel = config.hidden_size
  329. channels = config.decoder_hidden_size
  330. strides = config.upsampling_ratios
  331. # Add first conv layer
  332. self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)
  333. # Add upsampling + MRF blocks
  334. block = []
  335. for stride_index, stride in enumerate(strides):
  336. block += [DacDecoderBlock(config, stride, stride_index)]
  337. self.block = nn.ModuleList(block)
  338. output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
  339. self.snake1 = Snake1d(output_dim)
  340. self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
  341. self.tanh = nn.Tanh()
  342. def forward(self, hidden_state):
  343. hidden_state = self.conv1(hidden_state)
  344. for layer in self.block:
  345. hidden_state = layer(hidden_state)
  346. hidden_state = self.snake1(hidden_state)
  347. hidden_state = self.conv2(hidden_state)
  348. hidden_state = self.tanh(hidden_state)
  349. return hidden_state
  350. class DacEncoder(nn.Module):
  351. """DAC Encoder"""
  352. def __init__(self, config: DacConfig):
  353. super().__init__()
  354. strides = config.downsampling_ratios
  355. # Create first convolution
  356. self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)
  357. self.block = []
  358. # Create EncoderBlocks that double channels as they downsample by `stride`
  359. for stride_index, stride in enumerate(strides):
  360. stride_index = stride_index + 1
  361. self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]
  362. self.block = nn.ModuleList(self.block)
  363. d_model = config.encoder_hidden_size * 2**stride_index
  364. self.snake1 = Snake1d(d_model)
  365. self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)
  366. def forward(self, hidden_state):
  367. hidden_state = self.conv1(hidden_state)
  368. for module in self.block:
  369. hidden_state = module(hidden_state)
  370. hidden_state = self.snake1(hidden_state)
  371. hidden_state = self.conv2(hidden_state)
  372. return hidden_state
  373. @auto_docstring
  374. class DacPreTrainedModel(PreTrainedAudioTokenizerBase):
  375. config: DacConfig
  376. base_model_prefix = "dac"
  377. main_input_name = "input_values"
  378. @torch.no_grad()
  379. def _init_weights(self, module):
  380. if isinstance(module, nn.Conv1d):
  381. init.trunc_normal_(module.weight, std=0.02)
  382. init.constant_(module.bias, 0)
  383. elif isinstance(module, Snake1d):
  384. init.ones_(module.alpha)
  385. elif isinstance(module, nn.ConvTranspose1d):
  386. module.reset_parameters()
  387. elif isinstance(module, nn.Embedding):
  388. init.normal_(module.weight, mean=0.0, std=0.02)
  389. def apply_weight_norm(self):
  390. weight_norm = nn.utils.weight_norm
  391. if hasattr(nn.utils.parametrizations, "weight_norm"):
  392. weight_norm = nn.utils.parametrizations.weight_norm
  393. for layer in self.quantizer.quantizers:
  394. weight_norm(layer.in_proj)
  395. weight_norm(layer.out_proj)
  396. weight_norm(self.encoder.conv1)
  397. weight_norm(self.encoder.conv2)
  398. for layer in self.encoder.block:
  399. weight_norm(layer.conv1)
  400. weight_norm(layer.res_unit1.conv1)
  401. weight_norm(layer.res_unit1.conv2)
  402. weight_norm(layer.res_unit2.conv1)
  403. weight_norm(layer.res_unit2.conv2)
  404. weight_norm(layer.res_unit3.conv1)
  405. weight_norm(layer.res_unit3.conv2)
  406. weight_norm(self.decoder.conv1)
  407. weight_norm(self.decoder.conv2)
  408. for layer in self.decoder.block:
  409. weight_norm(layer.conv_t1)
  410. weight_norm(layer.res_unit1.conv1)
  411. weight_norm(layer.res_unit1.conv2)
  412. weight_norm(layer.res_unit2.conv1)
  413. weight_norm(layer.res_unit2.conv2)
  414. weight_norm(layer.res_unit3.conv1)
  415. weight_norm(layer.res_unit3.conv2)
  416. def remove_weight_norm(self):
  417. for layer in self.quantizer.quantizers:
  418. nn.utils.remove_weight_norm(layer.in_proj)
  419. nn.utils.remove_weight_norm(layer.out_proj)
  420. nn.utils.remove_weight_norm(self.encoder.conv1)
  421. nn.utils.remove_weight_norm(self.encoder.conv2)
  422. for layer in self.encoder.block:
  423. nn.utils.remove_weight_norm(layer.conv1)
  424. nn.utils.remove_weight_norm(layer.res_unit1.conv1)
  425. nn.utils.remove_weight_norm(layer.res_unit1.conv2)
  426. nn.utils.remove_weight_norm(layer.res_unit2.conv1)
  427. nn.utils.remove_weight_norm(layer.res_unit2.conv2)
  428. nn.utils.remove_weight_norm(layer.res_unit3.conv1)
  429. nn.utils.remove_weight_norm(layer.res_unit3.conv2)
  430. nn.utils.remove_weight_norm(self.decoder.conv1)
  431. nn.utils.remove_weight_norm(self.decoder.conv2)
  432. for layer in self.decoder.block:
  433. nn.utils.remove_weight_norm(layer.conv_t1)
  434. nn.utils.remove_weight_norm(layer.res_unit1.conv1)
  435. nn.utils.remove_weight_norm(layer.res_unit1.conv2)
  436. nn.utils.remove_weight_norm(layer.res_unit2.conv1)
  437. nn.utils.remove_weight_norm(layer.res_unit2.conv2)
  438. nn.utils.remove_weight_norm(layer.res_unit3.conv1)
  439. nn.utils.remove_weight_norm(layer.res_unit3.conv2)
  440. @auto_docstring(
  441. custom_intro="""
  442. The DAC (Descript Audio Codec) model.
  443. """
  444. )
  445. class DacModel(DacPreTrainedModel):
  446. input_modalities = "audio"
  447. def __init__(self, config: DacConfig):
  448. super().__init__(config)
  449. self.config = config
  450. self.encoder = DacEncoder(config)
  451. self.decoder = DacDecoder(config)
  452. self.quantizer = DacResidualVectorQuantizer(config)
  453. self.bits_per_codebook = int(math.log2(self.config.codebook_size))
  454. if 2**self.bits_per_codebook != self.config.codebook_size:
  455. raise ValueError("The codebook_size must be a power of 2.")
  456. # Initialize weights and apply final processing
  457. self.post_init()
  458. @auto_docstring
  459. def encode(
  460. self,
  461. input_values: torch.Tensor,
  462. n_quantizers: int | None = None,
  463. return_dict: bool | None = None,
  464. ) -> tuple | DacEncoderOutput:
  465. r"""
  466. input_values (`torch.Tensor of shape `(batch_size, 1, time_steps)`):
  467. Input audio data to encode,
  468. n_quantizers (int, *optional*):
  469. Number of quantizers to use. If None, all quantizers are used. Default is None.
  470. """
  471. return_dict = return_dict if return_dict is not None else self.config.return_dict
  472. quantized_representation = self.encoder(input_values)
  473. quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
  474. quantized_representation, n_quantizers
  475. )
  476. loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss
  477. if not return_dict:
  478. return (loss, quantized_representation, audio_codes, projected_latents)
  479. return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)
  480. @auto_docstring
  481. def decode(
  482. self,
  483. quantized_representation: torch.Tensor | None = None,
  484. audio_codes: torch.Tensor | None = None,
  485. return_dict: bool | None = None,
  486. ) -> tuple | DacDecoderOutput:
  487. r"""
  488. quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
  489. Quantized continuous representation of input.
  490. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
  491. The codebook indices for each codebook, representing the quantized discrete
  492. representation of the input. This parameter should be provided if you want
  493. to decode directly from the audio codes (it will overwrite quantized_representation).
  494. return_dict (`bool`, *optional*, defaults to `True`):
  495. Whether to return a [`DacDecoderOutput`] instead of a plain tuple.
  496. """
  497. if quantized_representation is None and audio_codes is None:
  498. raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")
  499. return_dict = return_dict if return_dict is not None else self.config.return_dict
  500. if audio_codes is not None:
  501. quantized_representation = self.quantizer.from_codes(audio_codes)[0]
  502. audio_values = self.decoder(quantized_representation).squeeze(1)
  503. if not return_dict:
  504. return (audio_values,)
  505. return DacDecoderOutput(audio_values)
  506. @auto_docstring
  507. def forward(
  508. self,
  509. input_values: torch.Tensor,
  510. n_quantizers: int | None = None,
  511. return_dict: bool | None = None,
  512. ) -> tuple | DacOutput:
  513. r"""
  514. input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
  515. Audio data to encode.
  516. n_quantizers (`int`, *optional*):
  517. Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
  518. Examples:
  519. ```python
  520. >>> from datasets import load_dataset, Audio
  521. >>> from transformers import DacModel, AutoProcessor
  522. >>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  523. >>> model = DacModel.from_pretrained("descript/dac_16khz")
  524. >>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
  525. >>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
  526. >>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
  527. >>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
  528. >>> encoder_outputs = model.encode(inputs["input_values"])
  529. >>> # Get the intermediate audio codes
  530. >>> audio_codes = encoder_outputs.audio_codes
  531. >>> # Reconstruct the audio from its quantized representation
  532. >>> audio_values = model.decode(encoder_outputs.quantized_representation)
  533. >>> # or the equivalent with a forward pass
  534. >>> audio_values = model(inputs["input_values"]).audio_values
  535. ```"""
  536. return_dict = return_dict if return_dict is not None else self.config.return_dict
  537. length = input_values.shape[-1]
  538. loss, quantized_representation, audio_codes, projected_latents = self.encode(
  539. input_values, n_quantizers, return_dict=False
  540. )
  541. audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]
  542. if not return_dict:
  543. return (loss, audio_values, quantized_representation, audio_codes, projected_latents)
  544. return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)
  545. __all__ = ["DacModel", "DacPreTrainedModel"]