modeling_encodec.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822
  1. # Copyright 2023 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch EnCodec model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...modeling_utils import PreTrainedAudioTokenizerBase
  21. from ...utils import (
  22. ModelOutput,
  23. auto_docstring,
  24. logging,
  25. )
  26. from .configuration_encodec import EncodecConfig
  27. logger = logging.get_logger(__name__)
  28. # General docstring
# Container pairing the discrete codes with the decoded waveform; it appears in the return annotation of
# `EncodecModel.forward`. Both fields are optional and default to `None`.
@dataclass
@auto_docstring
class EncodecOutput(ModelOutput):
    r"""
    audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
        Discrete code embeddings computed using `model.encode`.
    audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
        Decoded audio values, obtained using the decoder part of Encodec.
    """

    # Field order is part of the interface: positional construction relies on it.
    audio_codes: torch.LongTensor | None = None
    audio_values: torch.FloatTensor | None = None
# Container for the result of `EncodecModel.encode`, which builds it positionally as
# `(audio_codes, audio_scales, last_frame_pad_length)`.
@dataclass
@auto_docstring
class EncodecEncoderOutput(ModelOutput):
    r"""
    audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
        Discrete code embeddings computed using `model.encode`.
    audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
        Scaling factor for each `audio_codes` input. This is used to unscale each chunk of audio when decoding.
    last_frame_pad_length (`int`, *optional*):
        The length of the padding in the last frame, if any. This is used to ensure that the encoded frames can be
        outputted as a tensor. This value should be passed during decoding to ensure padding is removed from the
        encoded frames.
    """

    audio_codes: torch.LongTensor | None = None
    # NOTE(review): annotated as a tensor, but `EncodecModel.encode` stores a Python list of per-frame scales
    # (entries are `None` when `config.normalize` is off) -- confirm which one callers should rely on.
    audio_scales: torch.FloatTensor | None = None
    last_frame_pad_length: int | None = None
# Container for the result of `EncodecModel.decode`, which builds it positionally from the decoded waveform.
@dataclass
@auto_docstring
class EncodecDecoderOutput(ModelOutput):
    r"""
    audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
        Decoded audio values, obtained using the decoder part of Encodec.
    """

    audio_values: torch.FloatTensor | None = None
  64. class EncodecConv1d(nn.Module):
  65. """Conv1d with asymmetric or causal padding and normalization."""
  66. def __init__(
  67. self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1
  68. ):
  69. super().__init__()
  70. self.causal = config.use_causal_conv
  71. self.pad_mode = config.pad_mode
  72. self.norm_type = config.norm_type
  73. if self.norm_type not in ["weight_norm", "time_group_norm"]:
  74. raise ValueError(
  75. f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}'
  76. )
  77. # warn user on unusual setup between dilation and stride
  78. if stride > 1 and dilation > 1:
  79. logger.warning(
  80. "EncodecConv1d has been initialized with stride > 1 and dilation > 1"
  81. f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
  82. )
  83. self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)
  84. weight_norm = nn.utils.weight_norm
  85. if hasattr(nn.utils.parametrizations, "weight_norm"):
  86. weight_norm = nn.utils.parametrizations.weight_norm
  87. if self.norm_type == "weight_norm":
  88. self.conv = weight_norm(self.conv)
  89. elif self.norm_type == "time_group_norm":
  90. self.norm = nn.GroupNorm(1, out_channels)
  91. kernel_size = self.conv.kernel_size[0]
  92. stride = torch.tensor(self.conv.stride[0], dtype=torch.int64)
  93. dilation = self.conv.dilation[0]
  94. # Effective kernel size with dilations.
  95. kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)
  96. self.register_buffer("stride", stride, persistent=False)
  97. self.register_buffer("kernel_size", kernel_size, persistent=False)
  98. self.register_buffer("padding_total", kernel_size - stride, persistent=False)
  99. def _get_extra_padding_for_conv1d(
  100. self,
  101. hidden_states: torch.Tensor,
  102. ) -> torch.Tensor:
  103. """See `pad_for_conv1d`."""
  104. length = hidden_states.shape[-1]
  105. n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
  106. n_frames = torch.ceil(n_frames).to(torch.int64) - 1
  107. ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
  108. return ideal_length - length
  109. @staticmethod
  110. def _pad1d(hidden_states: torch.Tensor, paddings: tuple[int, int], mode: str = "zero", value: float = 0.0):
  111. """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input.
  112. If this is the case, we insert extra 0 padding to the right before the reflection happens.
  113. """
  114. length = hidden_states.shape[-1]
  115. padding_left, padding_right = paddings
  116. if mode != "reflect":
  117. return nn.functional.pad(hidden_states, paddings, mode, value)
  118. max_pad = max(padding_left, padding_right)
  119. extra_pad = 0
  120. if length <= max_pad:
  121. extra_pad = max_pad - length + 1
  122. hidden_states = nn.functional.pad(hidden_states, (0, extra_pad))
  123. padded = nn.functional.pad(hidden_states, paddings, mode, value)
  124. end = padded.shape[-1] - extra_pad
  125. return padded[..., :end]
  126. def forward(self, hidden_states):
  127. extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
  128. if self.causal:
  129. # Left padding for causal
  130. hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
  131. else:
  132. # Asymmetric padding required for odd strides
  133. padding_right = self.padding_total // 2
  134. padding_left = self.padding_total - padding_right
  135. hidden_states = self._pad1d(
  136. hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode
  137. )
  138. hidden_states = self.conv(hidden_states)
  139. if self.norm_type == "time_group_norm":
  140. hidden_states = self.norm(hidden_states)
  141. return hidden_states
  142. class EncodecConvTranspose1d(nn.Module):
  143. """ConvTranspose1d with asymmetric or causal padding and normalization."""
  144. def __init__(self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1):
  145. super().__init__()
  146. self.causal = config.use_causal_conv
  147. self.trim_right_ratio = config.trim_right_ratio
  148. self.norm_type = config.norm_type
  149. if self.norm_type not in ["weight_norm", "time_group_norm"]:
  150. raise ValueError(
  151. f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}'
  152. )
  153. self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
  154. weight_norm = nn.utils.weight_norm
  155. if hasattr(nn.utils.parametrizations, "weight_norm"):
  156. weight_norm = nn.utils.parametrizations.weight_norm
  157. if config.norm_type == "weight_norm":
  158. self.conv = weight_norm(self.conv)
  159. elif config.norm_type == "time_group_norm":
  160. self.norm = nn.GroupNorm(1, out_channels)
  161. if not (self.causal or self.trim_right_ratio == 1.0):
  162. raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions")
  163. def forward(self, hidden_states):
  164. kernel_size = self.conv.kernel_size[0]
  165. stride = self.conv.stride[0]
  166. padding_total = kernel_size - stride
  167. hidden_states = self.conv(hidden_states)
  168. if self.norm_type == "time_group_norm":
  169. hidden_states = self.norm(hidden_states)
  170. # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
  171. # removed at the very end, when keeping only the right length for the output,
  172. # as removing it here would require also passing the length at the matching layer
  173. # in the encoder.
  174. if self.causal:
  175. # Trim the padding on the right according to the specified ratio
  176. # if trim_right_ratio = 1.0, trim everything from right
  177. padding_right = math.ceil(padding_total * self.trim_right_ratio)
  178. else:
  179. # Asymmetric padding required for odd strides
  180. padding_right = padding_total // 2
  181. padding_left = padding_total - padding_right
  182. # unpad
  183. end = hidden_states.shape[-1] - padding_right
  184. hidden_states = hidden_states[..., padding_left:end]
  185. return hidden_states
  186. class EncodecLSTM(nn.Module):
  187. """
  188. LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout.
  189. """
  190. def __init__(self, config: EncodecConfig, dimension: int):
  191. super().__init__()
  192. self.lstm = nn.LSTM(dimension, dimension, config.num_lstm_layers)
  193. def forward(self, hidden_states):
  194. hidden_states = hidden_states.permute(2, 0, 1)
  195. hidden_states = self.lstm(hidden_states)[0] + hidden_states
  196. hidden_states = hidden_states.permute(1, 2, 0)
  197. return hidden_states
  198. class EncodecResnetBlock(nn.Module):
  199. """
  200. Residual block from SEANet model as used by EnCodec.
  201. """
  202. def __init__(self, config: EncodecConfig, dim: int, dilations: list[int]):
  203. super().__init__()
  204. kernel_sizes = (config.residual_kernel_size, 1)
  205. if len(kernel_sizes) != len(dilations):
  206. raise ValueError("Number of kernel sizes should match number of dilations")
  207. hidden = dim // config.compress
  208. block = []
  209. for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
  210. in_chs = dim if i == 0 else hidden
  211. out_chs = dim if i == len(kernel_sizes) - 1 else hidden
  212. block += [nn.ELU()]
  213. block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
  214. self.block = nn.ModuleList(block)
  215. if config.use_conv_shortcut:
  216. self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
  217. else:
  218. self.shortcut = nn.Identity()
  219. def forward(self, hidden_states):
  220. residual = hidden_states
  221. for layer in self.block:
  222. hidden_states = layer(hidden_states)
  223. return self.shortcut(residual) + hidden_states
  224. class EncodecEncoder(nn.Module):
  225. """SEANet encoder as used by EnCodec."""
  226. def __init__(self, config: EncodecConfig):
  227. super().__init__()
  228. model = [EncodecConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
  229. scaling = 1
  230. # Downsample to raw audio scale
  231. for ratio in reversed(config.upsampling_ratios):
  232. current_scale = scaling * config.num_filters
  233. # Add residual layers
  234. for j in range(config.num_residual_layers):
  235. model += [EncodecResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
  236. # Add downsampling layers
  237. model += [nn.ELU()]
  238. model += [EncodecConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
  239. scaling *= 2
  240. model += [EncodecLSTM(config, scaling * config.num_filters)]
  241. model += [nn.ELU()]
  242. model += [EncodecConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]
  243. self.layers = nn.ModuleList(model)
  244. def forward(self, hidden_states):
  245. for layer in self.layers:
  246. hidden_states = layer(hidden_states)
  247. return hidden_states
  248. class EncodecDecoder(nn.Module):
  249. """SEANet decoder as used by EnCodec."""
  250. def __init__(self, config: EncodecConfig):
  251. super().__init__()
  252. scaling = int(2 ** len(config.upsampling_ratios))
  253. model = [EncodecConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)]
  254. model += [EncodecLSTM(config, scaling * config.num_filters)]
  255. # Upsample to raw audio scale
  256. for ratio in config.upsampling_ratios:
  257. current_scale = scaling * config.num_filters
  258. # Add upsampling layers
  259. model += [nn.ELU()]
  260. model += [
  261. EncodecConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio)
  262. ]
  263. # Add residual layers
  264. for j in range(config.num_residual_layers):
  265. model += [EncodecResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))]
  266. scaling //= 2
  267. # Add final layers
  268. model += [nn.ELU()]
  269. model += [EncodecConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)]
  270. self.layers = nn.ModuleList(model)
  271. def forward(self, hidden_states):
  272. for layer in self.layers:
  273. hidden_states = layer(hidden_states)
  274. return hidden_states
  275. class EncodecEuclideanCodebook(nn.Module):
  276. """Codebook with Euclidean distance."""
  277. def __init__(self, config: EncodecConfig):
  278. super().__init__()
  279. embed = torch.zeros(config.codebook_size, config.codebook_dim)
  280. self.codebook_size = config.codebook_size
  281. self.register_buffer("inited", torch.Tensor([True]))
  282. self.register_buffer("cluster_size", torch.zeros(config.codebook_size))
  283. self.register_buffer("embed", embed)
  284. self.register_buffer("embed_avg", embed.clone())
  285. def quantize(self, hidden_states):
  286. embed = self.embed.t()
  287. scaled_states = hidden_states.pow(2).sum(1, keepdim=True)
  288. dist = -(scaled_states - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True))
  289. embed_ind = dist.max(dim=-1).indices
  290. return embed_ind
  291. def encode(self, hidden_states):
  292. shape = hidden_states.shape
  293. # pre-process
  294. hidden_states = hidden_states.reshape((-1, shape[-1]))
  295. # quantize
  296. embed_ind = self.quantize(hidden_states)
  297. # post-process
  298. embed_ind = embed_ind.view(*shape[:-1])
  299. return embed_ind
  300. def decode(self, embed_ind):
  301. quantize = nn.functional.embedding(embed_ind, self.embed)
  302. return quantize
  303. class EncodecVectorQuantization(nn.Module):
  304. """
  305. Vector quantization implementation. Currently supports only euclidean distance.
  306. """
  307. def __init__(self, config: EncodecConfig):
  308. super().__init__()
  309. self.codebook = EncodecEuclideanCodebook(config)
  310. def encode(self, hidden_states):
  311. hidden_states = hidden_states.permute(0, 2, 1)
  312. embed_in = self.codebook.encode(hidden_states)
  313. return embed_in
  314. def decode(self, embed_ind):
  315. quantize = self.codebook.decode(embed_ind)
  316. quantize = quantize.permute(0, 2, 1)
  317. return quantize
  318. class EncodecResidualVectorQuantizer(nn.Module):
  319. """Residual Vector Quantizer."""
  320. def __init__(self, config: EncodecConfig):
  321. super().__init__()
  322. self.codebook_size = config.codebook_size
  323. self.frame_rate = config.frame_rate
  324. self.num_quantizers = config.num_quantizers
  325. self.layers = nn.ModuleList([EncodecVectorQuantization(config) for _ in range(config.num_quantizers)])
  326. def get_num_quantizers_for_bandwidth(self, bandwidth: float | None = None) -> int:
  327. """Return num_quantizers based on specified target bandwidth."""
  328. bw_per_q = math.log2(self.codebook_size) * self.frame_rate
  329. num_quantizers = self.num_quantizers
  330. if bandwidth is not None and bandwidth > 0.0:
  331. num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
  332. return num_quantizers
  333. def encode(self, embeddings: torch.Tensor, bandwidth: float | None = None) -> torch.Tensor:
  334. """
  335. Encode a given input tensor with the specified frame rate at the given bandwidth. The RVQ encode method sets
  336. the appropriate number of quantizers to use and returns indices for each quantizer.
  337. """
  338. num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
  339. residual = embeddings
  340. all_indices = []
  341. for layer in self.layers[:num_quantizers]:
  342. indices = layer.encode(residual)
  343. quantized = layer.decode(indices)
  344. residual = residual - quantized
  345. all_indices.append(indices)
  346. out_indices = torch.stack(all_indices)
  347. return out_indices
  348. def decode(self, codes: torch.Tensor) -> torch.Tensor:
  349. """Decode the given codes to the quantized representation."""
  350. quantized_out = torch.tensor(0.0, device=codes.device)
  351. for i, indices in enumerate(codes):
  352. layer = self.layers[i]
  353. quantized = layer.decode(indices)
  354. quantized_out = quantized_out + quantized
  355. return quantized_out
  356. @auto_docstring
  357. class EncodecPreTrainedModel(PreTrainedAudioTokenizerBase):
  358. config: EncodecConfig
  359. base_model_prefix = "encodec"
  360. main_input_name = "input_values"
  361. @torch.no_grad()
  362. def _init_weights(self, module):
  363. """Initialize the weights"""
  364. if isinstance(module, nn.GroupNorm):
  365. init.zeros_(module.bias)
  366. init.ones_(module.weight)
  367. elif isinstance(module, nn.Conv1d):
  368. init.kaiming_normal_(module.weight)
  369. if module.bias is not None:
  370. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  371. init.uniform_(module.bias, a=-k, b=k)
  372. elif isinstance(module, nn.ConvTranspose1d):
  373. module.reset_parameters()
  374. elif isinstance(module, nn.LSTM):
  375. for name, param in module.named_parameters():
  376. if "weight" in name:
  377. init.xavier_uniform_(param)
  378. elif "bias" in name:
  379. init.constant_(param, 0.0)
  380. elif isinstance(module, EncodecConv1d):
  381. kernel_size = module.conv.kernel_size[0]
  382. stride = torch.tensor(module.conv.stride[0], dtype=torch.int64)
  383. dilation = module.conv.dilation[0]
  384. # Effective kernel size with dilations.
  385. kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)
  386. init.copy_(module.stride, stride)
  387. init.copy_(module.kernel_size, kernel_size)
  388. init.copy_(module.padding_total, kernel_size - stride)
  389. elif isinstance(module, EncodecEuclideanCodebook):
  390. init.copy_(module.inited, torch.Tensor([True]))
  391. init.zeros_(module.cluster_size)
  392. init.zeros_(module.embed)
  393. init.zeros_(module.embed_avg)
  394. @auto_docstring(
  395. custom_intro="""
  396. The EnCodec neural audio codec model.
  397. """
  398. )
  399. class EncodecModel(EncodecPreTrainedModel):
  400. def __init__(self, config: EncodecConfig):
  401. super().__init__(config)
  402. self.config = config
  403. self.encoder = EncodecEncoder(config)
  404. self.decoder = EncodecDecoder(config)
  405. self.quantizer = EncodecResidualVectorQuantizer(config)
  406. self.bits_per_codebook = int(math.log2(self.config.codebook_size))
  407. if 2**self.bits_per_codebook != self.config.codebook_size:
  408. raise ValueError("The codebook_size must be a power of 2.")
  409. # Initialize weights and apply final processing
  410. self.post_init()
  411. def _encode_frame(self, input_values: torch.Tensor, bandwidth: float) -> tuple[torch.Tensor, torch.Tensor | None]:
  412. """
  413. Encodes the given input using the underlying VQVAE. If `config.normalize` is set to `True` the input is first
  414. normalized. The padding mask is required to compute the correct scale.
  415. """
  416. length = input_values.shape[-1]
  417. duration = length / self.config.sampling_rate
  418. if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
  419. raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")
  420. scale = None
  421. if self.config.normalize:
  422. mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
  423. scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
  424. input_values = input_values / scale
  425. scale = scale.view(-1, 1)
  426. embeddings = self.encoder(input_values)
  427. codes = self.quantizer.encode(embeddings, bandwidth)
  428. codes = codes.transpose(0, 1)
  429. return codes, scale
  430. def encode(
  431. self,
  432. input_values: torch.Tensor,
  433. padding_mask: torch.Tensor | None = None,
  434. bandwidth: float | None = None,
  435. return_dict: bool | None = None,
  436. ) -> tuple[torch.Tensor, torch.Tensor | None, int] | EncodecEncoderOutput:
  437. """
  438. Encodes the input audio waveform into discrete codes of shape
  439. `(nb_frames, batch_size, nb_quantizers, frame_len)`.
  440. - `nb_frames=1` if `self.config.chunk_length=None` (as the encoder is applied on the full audio), which is the
  441. case for the 24kHz model. Otherwise, `nb_frames=ceil(input_length/self.config.chunk_stride)`, which is the case
  442. for the 48kHz model.
  443. - `frame_len` is the length of each frame, which is equal to `ceil(input_length/self.config.hop_length)` if
  444. `self.config.chunk_length=None` (e.g., for the 24kHz model). Otherwise, if `self.config.chunk_length` is
  445. defined, `frame_len=self.config.chunk_length/self.config.hop_length`, e.g., the case for the 48kHz model with
  446. `frame_len=150`.
  447. Args:
  448. input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
  449. Float values of the input audio waveform.
  450. padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
  451. Padding mask used to pad the `input_values`.
  452. bandwidth (`float`, *optional*):
  453. The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
  454. bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented
  455. as bandwidth == 6.0
  456. Returns:
  457. EncodecEncoderOutput dict or a tuple containing:
  458. - audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*),
  459. - audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*),
  460. - last_frame_pad_length (`int`, *optional*).
  461. """
  462. return_dict = return_dict if return_dict is not None else self.config.return_dict
  463. if bandwidth is None:
  464. bandwidth = self.config.target_bandwidths[0]
  465. if bandwidth not in self.config.target_bandwidths:
  466. raise ValueError(
  467. f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
  468. )
  469. _, channels, input_length = input_values.shape
  470. if channels < 1 or channels > 2:
  471. raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
  472. chunk_length = self.config.chunk_length
  473. if chunk_length is None:
  474. chunk_length = input_length
  475. stride = input_length
  476. else:
  477. stride = self.config.chunk_stride
  478. if padding_mask is None:
  479. padding_mask = torch.ones_like(input_values).bool()
  480. else:
  481. padding_mask = padding_mask.view(padding_mask.shape[0], -1, padding_mask.shape[-1])
  482. encoded_frames = []
  483. scales = []
  484. for offset in range(0, input_length, stride):
  485. mask = padding_mask[..., offset : offset + chunk_length].bool()
  486. frame = mask * input_values[..., offset : offset + chunk_length]
  487. encoded_frame, scale = self._encode_frame(frame, bandwidth)
  488. encoded_frames.append(encoded_frame)
  489. scales.append(scale)
  490. # pad last frame (if necessary) to be able to apply `torch.stack`
  491. last_frame_pad_length = encoded_frames[0].shape[-1] - encoded_frames[-1].shape[-1]
  492. if last_frame_pad_length > 0:
  493. last_frame = nn.functional.pad(encoded_frames[-1], (0, last_frame_pad_length), value=0)
  494. encoded_frames[-1] = last_frame
  495. encoded_frames = torch.stack(encoded_frames)
  496. if not return_dict:
  497. return (encoded_frames, scales, last_frame_pad_length)
  498. return EncodecEncoderOutput(encoded_frames, scales, last_frame_pad_length)
  499. @staticmethod
  500. def _linear_overlap_add(frames: list[torch.Tensor], stride: int):
  501. # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
  502. # e.g., more than 2 frames per position.
  503. # The core idea is to use a weight function that is a triangle,
  504. # with a maximum value at the middle of the chunk.
  505. # We use this weighting when summing the frames, and divide by the sum of weights
  506. # for each positions at the end. Thus:
  507. # - if a frame is the only one to cover a position, the weighting is a no-op.
  508. # - if 2 frames cover a position:
  509. # ... ...
  510. # / \/ \
  511. # / /\ \
  512. # S T , i.e. S offset of second frame starts, T end of first frame.
  513. # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
  514. # After the final normalization, the weight of the second frame at position `t` is
  515. # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
  516. #
  517. # - if more than 2 frames overlap at a given point, we hope that by induction
  518. # something sensible happens.
  519. if len(frames) == 0:
  520. raise ValueError("`frames` cannot be an empty list.")
  521. device = frames[0].device
  522. dtype = frames[0].dtype
  523. shape = frames[0].shape[:-1]
  524. total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]
  525. frame_length = frames[0].shape[-1]
  526. time_vec = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1:-1]
  527. weight = 0.5 - (time_vec - 0.5).abs()
  528. sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
  529. out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
  530. offset: int = 0
  531. for frame in frames:
  532. frame_length = frame.shape[-1]
  533. out[..., offset : offset + frame_length] += weight[:frame_length] * frame
  534. sum_weight[offset : offset + frame_length] += weight[:frame_length]
  535. offset += stride
  536. if sum_weight.min() == 0:
  537. raise ValueError(f"`sum_weight` minimum element must be bigger than zero: {sum_weight}`")
  538. return out / sum_weight
  539. def _decode_frame(self, codes: torch.Tensor, scale: torch.Tensor | None = None) -> torch.Tensor:
  540. codes = codes.transpose(0, 1)
  541. embeddings = self.quantizer.decode(codes)
  542. outputs = self.decoder(embeddings)
  543. if scale is not None:
  544. outputs = outputs * scale.view(-1, 1, 1)
  545. return outputs
  546. def decode(
  547. self,
  548. audio_codes: torch.LongTensor,
  549. audio_scales: torch.Tensor,
  550. padding_mask: torch.Tensor | None = None,
  551. return_dict: bool | None = None,
  552. last_frame_pad_length: int | None = 0,
  553. ) -> tuple[torch.Tensor, torch.Tensor] | EncodecDecoderOutput:
  554. """
  555. Decodes the given frames into an output audio waveform.
  556. Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
  557. trimmed.
  558. Args:
  559. audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
  560. Discrete code embeddings computed using `model.encode`.
  561. audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
  562. Scaling factor for each `audio_codes` input.
  563. padding_mask (`torch.Tensor` of shape `(channels, sequence_length)`):
  564. Padding mask used to pad the `input_values`.
  565. return_dict (`bool`, *optional*):
  566. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  567. last_frame_pad_length (`int`, *optional*):
  568. Integer representing the length of the padding in the last frame, which is removed during decoding.
  569. """
  570. return_dict = return_dict if return_dict is not None else self.config.return_dict
  571. chunk_length = self.config.chunk_length
  572. if chunk_length is None:
  573. if len(audio_codes) != 1:
  574. raise ValueError(f"Expected one frame, got {len(audio_codes)}")
  575. frame = audio_codes[0]
  576. if last_frame_pad_length > 0:
  577. frame = frame[..., :-last_frame_pad_length]
  578. audio_values = self._decode_frame(frame, audio_scales[0])
  579. else:
  580. decoded_frames = []
  581. for i, (frame, scale) in enumerate(zip(audio_codes, audio_scales)):
  582. if i == len(audio_codes) - 1 and last_frame_pad_length > 0:
  583. frame = frame[..., :-last_frame_pad_length]
  584. frames = self._decode_frame(frame, scale)
  585. decoded_frames.append(frames)
  586. audio_values = self._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1)
  587. # truncate based on padding mask
  588. if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
  589. audio_values = audio_values[..., : padding_mask.shape[-1]]
  590. if not return_dict:
  591. return (audio_values,)
  592. return EncodecDecoderOutput(audio_values)
  593. @auto_docstring
  594. def forward(
  595. self,
  596. input_values: torch.FloatTensor,
  597. padding_mask: torch.BoolTensor | None = None,
  598. bandwidth: float | None = None,
  599. audio_codes: torch.LongTensor | None = None,
  600. audio_scales: torch.Tensor | None = None,
  601. return_dict: bool | None = None,
  602. last_frame_pad_length: int | None = 0,
  603. ) -> tuple[torch.Tensor, torch.Tensor] | EncodecOutput:
  604. r"""
  605. input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
  606. Raw audio input converted to Float and padded to the appropriate length in order to be encoded using chunks
  607. of length self.chunk_length and a stride of `config.chunk_stride`.
  608. padding_mask (`torch.BoolTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
  609. Mask to avoid computing scaling factors on padding token indices (can we avoid computing conv on these+).
  610. Mask values selected in `[0, 1]`:
  611. - 1 for tokens that are **not masked**,
  612. - 0 for tokens that are **masked**.
  613. <Tip warning={true}>
  614. `padding_mask` should always be passed, unless the input was truncated or not padded. This is because in
  615. order to process tensors effectively, the input audio should be padded so that `input_length % stride =
  616. step` with `step = chunk_length-stride`. This ensures that all chunks are of the same shape
  617. </Tip>
  618. bandwidth (`float`, *optional*):
  619. The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
  620. bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
  621. `bandwidth == 6.0`
  622. audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
  623. Discrete code embeddings computed using `model.encode`.
  624. audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
  625. Scaling factor for each `audio_codes` input.
  626. return_dict (`bool`, *optional*):
  627. Whether to return outputs as a dict.
  628. last_frame_pad_length (`int`, *optional*):
  629. The length of the padding in the last frame, if any. This is used to ensure that the encoded frames can be
  630. outputted as a tensor. This value should be passed during decoding to ensure padding is removed from the
  631. encoded frames.
  632. Examples:
  633. ```python
  634. >>> from datasets import load_dataset
  635. >>> from transformers import AutoProcessor, EncodecModel
  636. >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
  637. >>> audio_sample = dataset["train"]["audio"][0]["array"]
  638. >>> model_id = "facebook/encodec_24khz"
  639. >>> model = EncodecModel.from_pretrained(model_id)
  640. >>> processor = AutoProcessor.from_pretrained(model_id)
  641. >>> inputs = processor(raw_audio=audio_sample, return_tensors="pt")
  642. >>> outputs = model(**inputs)
  643. >>> audio_codes = outputs.audio_codes
  644. >>> audio_values = outputs.audio_values
  645. ```"""
  646. return_dict = return_dict if return_dict is not None else self.config.return_dict
  647. if padding_mask is None:
  648. padding_mask = torch.ones_like(input_values).bool()
  649. else:
  650. # ensure that channel dimension is present
  651. padding_mask = padding_mask.view(padding_mask.shape[0], -1, padding_mask.shape[-1])
  652. if audio_codes is not None and audio_scales is None:
  653. raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`")
  654. if audio_scales is not None and audio_codes is None:
  655. raise ValueError("You specified `audio_scales` but did not specify the `audio_codes`")
  656. if audio_scales is None and audio_codes is None:
  657. audio_codes, audio_scales, last_frame_pad_length = self.encode(
  658. input_values, padding_mask, bandwidth, False
  659. )
  660. audio_values = self.decode(
  661. audio_codes,
  662. audio_scales,
  663. padding_mask,
  664. return_dict=return_dict,
  665. last_frame_pad_length=last_frame_pad_length,
  666. )[0]
  667. if not return_dict:
  668. return (audio_codes, audio_values)
  669. return EncodecOutput(audio_codes=audio_codes, audio_values=audio_values)
  670. __all__ = ["EncodecModel", "EncodecPreTrainedModel"]