modeling_vits.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406
  1. # Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch VITS model."""
  15. import math
  16. from dataclasses import dataclass
  17. from typing import Any
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  24. from ...integrations.fsdp import is_fsdp_managed_module
  25. from ...masking_utils import create_bidirectional_mask
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutput, ModelOutput
  28. from ...modeling_utils import PreTrainedModel
  29. from ...utils import auto_docstring, logging, torch_compilable_check
  30. from .configuration_vits import VitsConfig
  31. logger = logging.get_logger(__name__)
  32. @dataclass
  33. @auto_docstring(
  34. custom_intro="""
  35. Describes the outputs for the VITS model, with potential hidden states and attentions.
  36. """
  37. )
  38. class VitsModelOutput(ModelOutput):
  39. r"""
  40. waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  41. The final audio waveform predicted by the model.
  42. sequence_lengths (`torch.FloatTensor` of shape `(batch_size,)`):
  43. The length in samples of each element in the `waveform` batch.
  44. spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
  45. The log-mel spectrogram predicted at the output of the flow model. This spectrogram is passed to the Hi-Fi
  46. GAN decoder model to obtain the final audio waveform.
  47. """
  48. waveform: torch.FloatTensor | None = None
  49. sequence_lengths: torch.FloatTensor | None = None
  50. spectrogram: tuple[torch.FloatTensor] | None = None
  51. hidden_states: tuple[torch.FloatTensor] | None = None
  52. attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Describes the outputs for the VITS text encoder model, with potential hidden states and attentions.
    """
)
class VitsTextEncoderOutput(ModelOutput):
    r"""
    prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The predicted mean values of the prior distribution for the latent text variables.
    prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The predicted log-variance values of the prior distribution for the latent text variables.
    """

    # NOTE(review): `ModelOutput` presumably exposes these fields positionally in declaration order
    # (tuple conversion / integer indexing) — keep the order stable; confirm against `ModelOutput`.
    # Final hidden states of the text encoder.
    last_hidden_state: torch.FloatTensor | None = None
    # Mean of the prior distribution over the latent text variables.
    prior_means: torch.FloatTensor | None = None
    # Log-variance of the prior distribution over the latent text variables.
    prior_log_variances: torch.FloatTensor | None = None
    # Optional per-layer hidden states and attention weights ("potential" outputs per the class intro).
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  71. @torch.jit.script
  72. def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
  73. in_act = input_a + input_b
  74. t_act = torch.tanh(in_act[:, :num_channels, :])
  75. s_act = torch.sigmoid(in_act[:, num_channels:, :])
  76. acts = t_act * s_act
  77. return acts
  78. def _unconstrained_rational_quadratic_spline(
  79. inputs,
  80. unnormalized_widths,
  81. unnormalized_heights,
  82. unnormalized_derivatives,
  83. reverse=False,
  84. tail_bound=5.0,
  85. min_bin_width=1e-3,
  86. min_bin_height=1e-3,
  87. min_derivative=1e-3,
  88. ):
  89. """
  90. This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
  91. `tail_bound`, the transform behaves as an identity function.
  92. Args:
  93. inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  94. Second half of the hidden-states input to the Vits convolutional flow module.
  95. unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  96. First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  97. layer in the convolutional flow module
  98. unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  99. Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  100. layer in the convolutional flow module
  101. unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  102. Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  103. layer in the convolutional flow module
  104. reverse (`bool`, *optional*, defaults to `False`):
  105. Whether the model is being run in reverse mode.
  106. tail_bound (`float`, *optional* defaults to 5):
  107. Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
  108. transform behaves as an identity function.
  109. min_bin_width (`float`, *optional*, defaults to 1e-3):
  110. Minimum bin value across the width dimension for the piecewise rational quadratic function.
  111. min_bin_height (`float`, *optional*, defaults to 1e-3):
  112. Minimum bin value across the height dimension for the piecewise rational quadratic function.
  113. min_derivative (`float`, *optional*, defaults to 1e-3):
  114. Minimum bin value across the derivatives for the piecewise rational quadratic function.
  115. Returns:
  116. outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  117. Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
  118. applied.
  119. log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  120. Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
  121. limits applied.
  122. """
  123. inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
  124. outside_interval_mask = ~inside_interval_mask
  125. outputs = torch.zeros_like(inputs)
  126. log_abs_det = torch.zeros_like(inputs)
  127. constant = np.log(np.exp(1 - min_derivative) - 1)
  128. unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
  129. unnormalized_derivatives[..., 0] = constant
  130. unnormalized_derivatives[..., -1] = constant
  131. outputs[outside_interval_mask] = inputs[outside_interval_mask]
  132. log_abs_det[outside_interval_mask] = 0.0
  133. outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
  134. inputs=inputs[inside_interval_mask],
  135. unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
  136. unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
  137. unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
  138. reverse=reverse,
  139. tail_bound=tail_bound,
  140. min_bin_width=min_bin_width,
  141. min_bin_height=min_bin_height,
  142. min_derivative=min_derivative,
  143. )
  144. return outputs, log_abs_det
  145. def _rational_quadratic_spline(
  146. inputs,
  147. unnormalized_widths,
  148. unnormalized_heights,
  149. unnormalized_derivatives,
  150. reverse,
  151. tail_bound,
  152. min_bin_width,
  153. min_bin_height,
  154. min_derivative,
  155. ):
  156. """
  157. This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
  158. function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`.
  159. Args:
  160. inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  161. Second half of the hidden-states input to the Vits convolutional flow module.
  162. unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  163. First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  164. layer in the convolutional flow module
  165. unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  166. Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  167. layer in the convolutional flow module
  168. unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  169. Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  170. layer in the convolutional flow module
  171. reverse (`bool`):
  172. Whether the model is being run in reverse mode.
  173. tail_bound (`float`):
  174. Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
  175. transform behaves as an identity function.
  176. min_bin_width (`float`):
  177. Minimum bin value across the width dimension for the piecewise rational quadratic function.
  178. min_bin_height (`float`):
  179. Minimum bin value across the height dimension for the piecewise rational quadratic function.
  180. min_derivative (`float`):
  181. Minimum bin value across the derivatives for the piecewise rational quadratic function.
  182. Returns:
  183. outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  184. Hidden-states as transformed by the piecewise rational quadratic function.
  185. log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  186. Logarithm of the absolute value of the determinants corresponding to the `outputs`.
  187. """
  188. upper_bound = tail_bound
  189. lower_bound = -tail_bound
  190. torch_compilable_check(
  191. (inputs.min() >= lower_bound) & (inputs.max() <= upper_bound),
  192. f"Inputs are outside the range [{lower_bound}, {upper_bound}]",
  193. )
  194. num_bins = unnormalized_widths.shape[-1]
  195. if min_bin_width * num_bins > 1.0:
  196. raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
  197. if min_bin_height * num_bins > 1.0:
  198. raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")
  199. widths = nn.functional.softmax(unnormalized_widths, dim=-1)
  200. widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
  201. cumwidths = torch.cumsum(widths, dim=-1)
  202. cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
  203. cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
  204. cumwidths[..., 0] = lower_bound
  205. cumwidths[..., -1] = upper_bound
  206. widths = cumwidths[..., 1:] - cumwidths[..., :-1]
  207. derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)
  208. heights = nn.functional.softmax(unnormalized_heights, dim=-1)
  209. heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
  210. cumheights = torch.cumsum(heights, dim=-1)
  211. cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
  212. cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
  213. cumheights[..., 0] = lower_bound
  214. cumheights[..., -1] = upper_bound
  215. heights = cumheights[..., 1:] - cumheights[..., :-1]
  216. bin_locations = cumheights if reverse else cumwidths
  217. bin_locations[..., -1] += 1e-6
  218. bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
  219. bin_idx = bin_idx[..., None]
  220. input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
  221. input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
  222. input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
  223. delta = heights / widths
  224. input_delta = delta.gather(-1, bin_idx)[..., 0]
  225. input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
  226. input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
  227. input_heights = heights.gather(-1, bin_idx)[..., 0]
  228. intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
  229. if not reverse:
  230. theta = (inputs - input_cumwidths) / input_bin_widths
  231. theta_one_minus_theta = theta * (1 - theta)
  232. numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
  233. denominator = input_delta + intermediate1 * theta_one_minus_theta
  234. outputs = input_cumheights + numerator / denominator
  235. derivative_numerator = input_delta.pow(2) * (
  236. input_derivatives_plus_one * theta.pow(2)
  237. + 2 * input_delta * theta_one_minus_theta
  238. + input_derivatives * (1 - theta).pow(2)
  239. )
  240. log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
  241. return outputs, log_abs_det
  242. else:
  243. # find the roots of a quadratic equation
  244. intermediate2 = inputs - input_cumheights
  245. intermediate3 = intermediate2 * intermediate1
  246. a = input_heights * (input_delta - input_derivatives) + intermediate3
  247. b = input_heights * input_derivatives - intermediate3
  248. c = -input_delta * intermediate2
  249. discriminant = b.pow(2) - 4 * a * c
  250. torch_compilable_check(
  251. torch.all(discriminant >= 0),
  252. f"Discriminant has negative values {discriminant}",
  253. )
  254. root = (2 * c) / (-b - torch.sqrt(discriminant))
  255. outputs = root * input_bin_widths + input_cumwidths
  256. theta_one_minus_theta = root * (1 - root)
  257. denominator = input_delta + intermediate1 * theta_one_minus_theta
  258. derivative_numerator = input_delta.pow(2) * (
  259. input_derivatives_plus_one * root.pow(2)
  260. + 2 * input_delta * theta_one_minus_theta
  261. + input_derivatives * (1 - root).pow(2)
  262. )
  263. log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
  264. return outputs, -log_abs_det
  265. class VitsWaveNet(torch.nn.Module):
  266. def __init__(self, config: VitsConfig, num_layers: int):
  267. super().__init__()
  268. self.hidden_size = config.hidden_size
  269. self.num_layers = num_layers
  270. self.in_layers = torch.nn.ModuleList()
  271. self.res_skip_layers = torch.nn.ModuleList()
  272. self.dropout = nn.Dropout(config.wavenet_dropout)
  273. if hasattr(nn.utils.parametrizations, "weight_norm"):
  274. weight_norm = nn.utils.parametrizations.weight_norm
  275. else:
  276. weight_norm = nn.utils.weight_norm
  277. if config.speaker_embedding_size != 0:
  278. cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
  279. self.cond_layer = weight_norm(cond_layer, name="weight")
  280. for i in range(num_layers):
  281. dilation = config.wavenet_dilation_rate**i
  282. padding = (config.wavenet_kernel_size * dilation - dilation) // 2
  283. in_layer = torch.nn.Conv1d(
  284. in_channels=config.hidden_size,
  285. out_channels=2 * config.hidden_size,
  286. kernel_size=config.wavenet_kernel_size,
  287. dilation=dilation,
  288. padding=padding,
  289. )
  290. in_layer = weight_norm(in_layer, name="weight")
  291. self.in_layers.append(in_layer)
  292. # last one is not necessary
  293. if i < num_layers - 1:
  294. res_skip_channels = 2 * config.hidden_size
  295. else:
  296. res_skip_channels = config.hidden_size
  297. res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
  298. res_skip_layer = weight_norm(res_skip_layer, name="weight")
  299. self.res_skip_layers.append(res_skip_layer)
  300. def forward(self, inputs, padding_mask, global_conditioning=None):
  301. outputs = torch.zeros_like(inputs)
  302. num_channels_tensor = torch.IntTensor([self.hidden_size])
  303. if global_conditioning is not None:
  304. global_conditioning = self.cond_layer(global_conditioning)
  305. for i in range(self.num_layers):
  306. hidden_states = self.in_layers[i](inputs)
  307. if global_conditioning is not None:
  308. cond_offset = i * 2 * self.hidden_size
  309. global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
  310. else:
  311. global_states = torch.zeros_like(hidden_states)
  312. acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
  313. acts = self.dropout(acts)
  314. res_skip_acts = self.res_skip_layers[i](acts)
  315. if i < self.num_layers - 1:
  316. res_acts = res_skip_acts[:, : self.hidden_size, :]
  317. inputs = (inputs + res_acts) * padding_mask
  318. outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
  319. else:
  320. outputs = outputs + res_skip_acts
  321. return outputs * padding_mask
  322. def remove_weight_norm(self):
  323. if self.speaker_embedding_size != 0:
  324. torch.nn.utils.remove_weight_norm(self.cond_layer)
  325. for layer in self.in_layers:
  326. torch.nn.utils.remove_weight_norm(layer)
  327. for layer in self.res_skip_layers:
  328. torch.nn.utils.remove_weight_norm(layer)
  329. class VitsPosteriorEncoder(nn.Module):
  330. def __init__(self, config: VitsConfig):
  331. super().__init__()
  332. self.out_channels = config.flow_size
  333. self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1)
  334. self.wavenet = VitsWaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers)
  335. self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1)
  336. def forward(self, inputs, padding_mask, global_conditioning=None):
  337. inputs = self.conv_pre(inputs) * padding_mask
  338. inputs = self.wavenet(inputs, padding_mask, global_conditioning)
  339. stats = self.conv_proj(inputs) * padding_mask
  340. mean, log_stddev = torch.split(stats, self.out_channels, dim=1)
  341. sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask
  342. return sampled, mean, log_stddev
  343. # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
  344. class HifiGanResidualBlock(nn.Module):
  345. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
  346. super().__init__()
  347. self.leaky_relu_slope = leaky_relu_slope
  348. self.convs1 = nn.ModuleList(
  349. [
  350. nn.Conv1d(
  351. channels,
  352. channels,
  353. kernel_size,
  354. stride=1,
  355. dilation=dilation[i],
  356. padding=self.get_padding(kernel_size, dilation[i]),
  357. )
  358. for i in range(len(dilation))
  359. ]
  360. )
  361. self.convs2 = nn.ModuleList(
  362. [
  363. nn.Conv1d(
  364. channels,
  365. channels,
  366. kernel_size,
  367. stride=1,
  368. dilation=1,
  369. padding=self.get_padding(kernel_size, 1),
  370. )
  371. for _ in range(len(dilation))
  372. ]
  373. )
  374. def get_padding(self, kernel_size, dilation=1):
  375. return (kernel_size * dilation - dilation) // 2
  376. def apply_weight_norm(self):
  377. weight_norm = nn.utils.weight_norm
  378. if hasattr(nn.utils.parametrizations, "weight_norm"):
  379. weight_norm = nn.utils.parametrizations.weight_norm
  380. for layer in self.convs1:
  381. weight_norm(layer)
  382. for layer in self.convs2:
  383. weight_norm(layer)
  384. def remove_weight_norm(self):
  385. for layer in self.convs1:
  386. nn.utils.remove_weight_norm(layer)
  387. for layer in self.convs2:
  388. nn.utils.remove_weight_norm(layer)
  389. def forward(self, hidden_states):
  390. for conv1, conv2 in zip(self.convs1, self.convs2):
  391. residual = hidden_states
  392. hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
  393. hidden_states = conv1(hidden_states)
  394. hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
  395. hidden_states = conv2(hidden_states)
  396. hidden_states = hidden_states + residual
  397. return hidden_states
  398. class VitsHifiGan(nn.Module):
  399. def __init__(self, config: VitsConfig):
  400. super().__init__()
  401. self.config = config
  402. self.num_kernels = len(config.resblock_kernel_sizes)
  403. self.num_upsamples = len(config.upsample_rates)
  404. self.conv_pre = nn.Conv1d(
  405. config.flow_size,
  406. config.upsample_initial_channel,
  407. kernel_size=7,
  408. stride=1,
  409. padding=3,
  410. )
  411. self.upsampler = nn.ModuleList()
  412. for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
  413. self.upsampler.append(
  414. nn.ConvTranspose1d(
  415. config.upsample_initial_channel // (2**i),
  416. config.upsample_initial_channel // (2 ** (i + 1)),
  417. kernel_size=kernel_size,
  418. stride=upsample_rate,
  419. padding=(kernel_size - upsample_rate) // 2,
  420. )
  421. )
  422. self.resblocks = nn.ModuleList()
  423. for i in range(len(self.upsampler)):
  424. channels = config.upsample_initial_channel // (2 ** (i + 1))
  425. for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
  426. self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
  427. self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
  428. if config.speaker_embedding_size != 0:
  429. self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
  430. def apply_weight_norm(self):
  431. weight_norm = nn.utils.weight_norm
  432. if hasattr(nn.utils.parametrizations, "weight_norm"):
  433. weight_norm = nn.utils.parametrizations.weight_norm
  434. for layer in self.upsampler:
  435. weight_norm(layer)
  436. for layer in self.resblocks:
  437. layer.apply_weight_norm()
  438. def remove_weight_norm(self):
  439. for layer in self.upsampler:
  440. nn.utils.remove_weight_norm(layer)
  441. for layer in self.resblocks:
  442. layer.remove_weight_norm()
  443. def forward(
  444. self, spectrogram: torch.FloatTensor, global_conditioning: torch.FloatTensor | None = None
  445. ) -> torch.FloatTensor:
  446. r"""
  447. Converts a spectrogram into a speech waveform.
  448. Args:
  449. spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`):
  450. Tensor containing the spectrograms.
  451. global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*):
  452. Tensor containing speaker embeddings, for multispeaker models.
  453. Returns:
  454. `torch.FloatTensor`: Tensor of shape shape `(batch_size, 1, num_frames)` containing the speech waveform.
  455. """
  456. hidden_states = self.conv_pre(spectrogram)
  457. if global_conditioning is not None:
  458. hidden_states = hidden_states + self.cond(global_conditioning)
  459. for i in range(self.num_upsamples):
  460. hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
  461. hidden_states = self.upsampler[i](hidden_states)
  462. res_state = self.resblocks[i * self.num_kernels](hidden_states)
  463. for j in range(1, self.num_kernels):
  464. res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
  465. hidden_states = res_state / self.num_kernels
  466. hidden_states = nn.functional.leaky_relu(hidden_states)
  467. hidden_states = self.conv_post(hidden_states)
  468. waveform = torch.tanh(hidden_states)
  469. return waveform
  470. class VitsResidualCouplingLayer(nn.Module):
  471. def __init__(self, config: VitsConfig):
  472. super().__init__()
  473. self.half_channels = config.flow_size // 2
  474. self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
  475. self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
  476. self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
  477. def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
  478. first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
  479. hidden_states = self.conv_pre(first_half) * padding_mask
  480. hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
  481. mean = self.conv_post(hidden_states) * padding_mask
  482. log_stddev = torch.zeros_like(mean)
  483. if not reverse:
  484. second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
  485. outputs = torch.cat([first_half, second_half], dim=1)
  486. log_determinant = torch.sum(log_stddev, [1, 2])
  487. return outputs, log_determinant
  488. else:
  489. second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
  490. outputs = torch.cat([first_half, second_half], dim=1)
  491. return outputs, None
  492. class VitsResidualCouplingBlock(nn.Module):
  493. def __init__(self, config: VitsConfig):
  494. super().__init__()
  495. self.flows = nn.ModuleList()
  496. for _ in range(config.prior_encoder_num_flows):
  497. self.flows.append(VitsResidualCouplingLayer(config))
  498. def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
  499. if not reverse:
  500. for flow in self.flows:
  501. inputs, _ = flow(inputs, padding_mask, global_conditioning)
  502. inputs = torch.flip(inputs, [1])
  503. else:
  504. for flow in reversed(self.flows):
  505. inputs = torch.flip(inputs, [1])
  506. inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
  507. return inputs
  508. class VitsDilatedDepthSeparableConv(nn.Module):
  509. def __init__(self, config: VitsConfig, dropout_rate=0.0):
  510. super().__init__()
  511. kernel_size = config.duration_predictor_kernel_size
  512. channels = config.hidden_size
  513. self.num_layers = config.depth_separable_num_layers
  514. self.dropout = nn.Dropout(dropout_rate)
  515. self.convs_dilated = nn.ModuleList()
  516. self.convs_pointwise = nn.ModuleList()
  517. self.norms_1 = nn.ModuleList()
  518. self.norms_2 = nn.ModuleList()
  519. for i in range(self.num_layers):
  520. dilation = kernel_size**i
  521. padding = (kernel_size * dilation - dilation) // 2
  522. self.convs_dilated.append(
  523. nn.Conv1d(
  524. in_channels=channels,
  525. out_channels=channels,
  526. kernel_size=kernel_size,
  527. groups=channels,
  528. dilation=dilation,
  529. padding=padding,
  530. )
  531. )
  532. self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
  533. self.norms_1.append(nn.LayerNorm(channels))
  534. self.norms_2.append(nn.LayerNorm(channels))
  535. def forward(self, inputs, padding_mask, global_conditioning=None):
  536. if global_conditioning is not None:
  537. inputs = inputs + global_conditioning
  538. for i in range(self.num_layers):
  539. hidden_states = self.convs_dilated[i](inputs * padding_mask)
  540. hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
  541. hidden_states = nn.functional.gelu(hidden_states)
  542. hidden_states = self.convs_pointwise[i](hidden_states)
  543. hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
  544. hidden_states = nn.functional.gelu(hidden_states)
  545. hidden_states = self.dropout(hidden_states)
  546. inputs = inputs + hidden_states
  547. return inputs * padding_mask
  548. class VitsConvFlow(nn.Module):
  549. def __init__(self, config: VitsConfig):
  550. super().__init__()
  551. self.filter_channels = config.hidden_size
  552. self.half_channels = config.depth_separable_channels // 2
  553. self.num_bins = config.duration_predictor_flow_bins
  554. self.tail_bound = config.duration_predictor_tail_bound
  555. self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
  556. self.conv_dds = VitsDilatedDepthSeparableConv(config)
  557. self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)
  558. def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
  559. first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
  560. hidden_states = self.conv_pre(first_half)
  561. hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
  562. hidden_states = self.conv_proj(hidden_states) * padding_mask
  563. batch_size, channels, length = first_half.shape
  564. hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)
  565. unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
  566. unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
  567. unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]
  568. second_half, log_abs_det = _unconstrained_rational_quadratic_spline(
  569. second_half,
  570. unnormalized_widths,
  571. unnormalized_heights,
  572. unnormalized_derivatives,
  573. reverse=reverse,
  574. tail_bound=self.tail_bound,
  575. )
  576. outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
  577. if not reverse:
  578. log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2])
  579. return outputs, log_determinant
  580. else:
  581. return outputs, None
  582. class VitsElementwiseAffine(nn.Module):
  583. def __init__(self, config: VitsConfig):
  584. super().__init__()
  585. self.channels = config.depth_separable_channels
  586. self.translate = nn.Parameter(torch.zeros(self.channels, 1))
  587. self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))
  588. def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
  589. if not reverse:
  590. outputs = self.translate + torch.exp(self.log_scale) * inputs
  591. outputs = outputs * padding_mask
  592. log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
  593. return outputs, log_determinant
  594. else:
  595. outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
  596. return outputs, None
  597. class VitsStochasticDurationPredictor(nn.Module):
  598. def __init__(self, config):
  599. super().__init__()
  600. embed_dim = config.speaker_embedding_size
  601. filter_channels = config.hidden_size
  602. self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
  603. self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
  604. self.conv_dds = VitsDilatedDepthSeparableConv(
  605. config,
  606. dropout_rate=config.duration_predictor_dropout,
  607. )
  608. if embed_dim != 0:
  609. self.cond = nn.Conv1d(embed_dim, filter_channels, 1)
  610. self.flows = nn.ModuleList()
  611. self.flows.append(VitsElementwiseAffine(config))
  612. for _ in range(config.duration_predictor_num_flows):
  613. self.flows.append(VitsConvFlow(config))
  614. self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
  615. self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
  616. self.post_conv_dds = VitsDilatedDepthSeparableConv(
  617. config,
  618. dropout_rate=config.duration_predictor_dropout,
  619. )
  620. self.post_flows = nn.ModuleList()
  621. self.post_flows.append(VitsElementwiseAffine(config))
  622. for _ in range(config.duration_predictor_num_flows):
  623. self.post_flows.append(VitsConvFlow(config))
  624. def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
  625. inputs = torch.detach(inputs)
  626. inputs = self.conv_pre(inputs)
  627. if global_conditioning is not None:
  628. global_conditioning = torch.detach(global_conditioning)
  629. inputs = inputs + self.cond(global_conditioning)
  630. inputs = self.conv_dds(inputs, padding_mask)
  631. inputs = self.conv_proj(inputs) * padding_mask
  632. if not reverse:
  633. hidden_states = self.post_conv_pre(durations)
  634. hidden_states = self.post_conv_dds(hidden_states, padding_mask)
  635. hidden_states = self.post_conv_proj(hidden_states) * padding_mask
  636. random_posterior = (
  637. torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
  638. * padding_mask
  639. )
  640. log_determinant_posterior_sum = 0
  641. latents_posterior = random_posterior
  642. for flow in self.post_flows:
  643. latents_posterior, log_determinant = flow(
  644. latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
  645. )
  646. latents_posterior = torch.flip(latents_posterior, [1])
  647. log_determinant_posterior_sum += log_determinant
  648. first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)
  649. log_determinant_posterior_sum += torch.sum(
  650. (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
  651. )
  652. logq = (
  653. torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
  654. - log_determinant_posterior_sum
  655. )
  656. first_half = (durations - torch.sigmoid(first_half)) * padding_mask
  657. first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
  658. log_determinant_sum = torch.sum(-first_half, [1, 2])
  659. latents = torch.cat([first_half, second_half], dim=1)
  660. for flow in self.flows:
  661. latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
  662. latents = torch.flip(latents, [1])
  663. log_determinant_sum += log_determinant
  664. nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
  665. return nll + logq
  666. else:
  667. flows = list(reversed(self.flows))
  668. flows = flows[:-2] + [flows[-1]] # remove a useless vflow
  669. latents = (
  670. torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
  671. * noise_scale
  672. )
  673. for flow in flows:
  674. latents = torch.flip(latents, [1])
  675. latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)
  676. log_duration, _ = torch.split(latents, [1, 1], dim=1)
  677. return log_duration
  678. class VitsDurationPredictor(nn.Module):
  679. def __init__(self, config):
  680. super().__init__()
  681. kernel_size = config.duration_predictor_kernel_size
  682. filter_channels = config.duration_predictor_filter_channels
  683. self.dropout = nn.Dropout(config.duration_predictor_dropout)
  684. self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2)
  685. self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
  686. self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
  687. self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
  688. self.proj = nn.Conv1d(filter_channels, 1, 1)
  689. if config.speaker_embedding_size != 0:
  690. self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)
  691. def forward(self, inputs, padding_mask, global_conditioning=None):
  692. inputs = torch.detach(inputs)
  693. if global_conditioning is not None:
  694. global_conditioning = torch.detach(global_conditioning)
  695. inputs = inputs + self.cond(global_conditioning)
  696. inputs = self.conv_1(inputs * padding_mask)
  697. inputs = torch.relu(inputs)
  698. inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1)
  699. inputs = self.dropout(inputs)
  700. inputs = self.conv_2(inputs * padding_mask)
  701. inputs = torch.relu(inputs)
  702. inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1)
  703. inputs = self.dropout(inputs)
  704. inputs = self.proj(inputs * padding_mask)
  705. return inputs * padding_mask
  706. class VitsAttention(nn.Module):
  707. """Multi-headed attention with relative positional representation."""
  708. def __init__(self, config: VitsConfig):
  709. super().__init__()
  710. self.embed_dim = config.hidden_size
  711. self.num_heads = config.num_attention_heads
  712. self.dropout = config.attention_dropout
  713. self.window_size = config.window_size
  714. self.head_dim = self.embed_dim // self.num_heads
  715. self.scaling = self.head_dim**-0.5
  716. if (self.head_dim * self.num_heads) != self.embed_dim:
  717. raise ValueError(
  718. f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}"
  719. f" and `num_attention_heads`: {self.num_heads})."
  720. )
  721. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
  722. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
  723. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
  724. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
  725. if self.window_size:
  726. self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
  727. self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
  728. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  729. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  730. def forward(
  731. self,
  732. hidden_states: torch.Tensor,
  733. key_value_states: torch.Tensor | None = None,
  734. attention_mask: torch.Tensor | None = None,
  735. output_attentions: bool = False,
  736. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  737. """Input shape: Batch x Time x Channel"""
  738. # if key_value_states are provided this layer is used as a cross-attention layer
  739. # for the decoder
  740. bsz, tgt_len, _ = hidden_states.size()
  741. # get query proj
  742. query_states = self.q_proj(hidden_states) * self.scaling
  743. # self_attention
  744. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  745. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  746. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  747. query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
  748. key_states = key_states.view(*proj_shape)
  749. value_states = value_states.view(*proj_shape)
  750. src_len = key_states.size(1)
  751. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  752. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  753. raise ValueError(
  754. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  755. f" {attn_weights.size()}"
  756. )
  757. if self.window_size is not None:
  758. key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)
  759. relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1))
  760. rel_pos_bias = self._relative_position_to_absolute_position(relative_logits)
  761. attn_weights += rel_pos_bias
  762. if attention_mask is not None:
  763. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  764. raise ValueError(
  765. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  766. )
  767. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  768. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  769. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  770. if output_attentions:
  771. # this operation is a bit awkward, but it's required to
  772. # make sure that attn_weights keeps its gradient.
  773. # In order to do so, attn_weights have to be reshaped
  774. # twice and have to be reused in the following
  775. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  776. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  777. else:
  778. attn_weights_reshaped = None
  779. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  780. attn_output = torch.bmm(attn_probs, value_states)
  781. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  782. raise ValueError(
  783. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  784. f" {attn_output.size()}"
  785. )
  786. if self.window_size is not None:
  787. value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
  788. relative_weights = self._absolute_position_to_relative_position(attn_probs)
  789. rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
  790. attn_output += rel_pos_bias
  791. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  792. attn_output = attn_output.transpose(1, 2)
  793. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  794. # partitioned across GPUs when using tensor-parallelism.
  795. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  796. attn_output = self.out_proj(attn_output)
  797. return attn_output, attn_weights_reshaped
  798. def _get_relative_embeddings(self, relative_embeddings, length):
  799. pad_length = max(length - (self.window_size + 1), 0)
  800. if pad_length > 0:
  801. relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
  802. slice_start_position = max((self.window_size + 1) - length, 0)
  803. slice_end_position = slice_start_position + 2 * length - 1
  804. return relative_embeddings[:, slice_start_position:slice_end_position]
  805. def _relative_position_to_absolute_position(self, x):
  806. batch_heads, length, _ = x.size()
  807. # Concat columns of pad to shift from relative to absolute indexing.
  808. x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])
  809. # Concat extra elements so to add up to shape (len+1, 2*len-1).
  810. x_flat = x.view([batch_heads, length * 2 * length])
  811. x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
  812. # Reshape and slice out the padded elements.
  813. x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
  814. x_final = x_final[:, :length, length - 1 :]
  815. return x_final
  816. def _absolute_position_to_relative_position(self, x):
  817. batch_heads, length, _ = x.size()
  818. # Pad along column
  819. x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
  820. x_flat = x.view([batch_heads, length * (2 * length - 1)])
  821. # Add 0's in the beginning that will skew the elements after reshape
  822. x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
  823. x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]
  824. return x_final
  825. class VitsFeedForward(nn.Module):
  826. def __init__(self, config):
  827. super().__init__()
  828. self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
  829. self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
  830. self.dropout = nn.Dropout(config.activation_dropout)
  831. if isinstance(config.hidden_act, str):
  832. self.act_fn = ACT2FN[config.hidden_act]
  833. else:
  834. self.act_fn = config.hidden_act
  835. if config.ffn_kernel_size > 1:
  836. pad_left = (config.ffn_kernel_size - 1) // 2
  837. pad_right = config.ffn_kernel_size // 2
  838. self.padding = [pad_left, pad_right, 0, 0, 0, 0]
  839. else:
  840. self.padding = None
  841. def forward(self, hidden_states, padding_mask):
  842. hidden_states = hidden_states.permute(0, 2, 1)
  843. padding_mask = padding_mask.permute(0, 2, 1)
  844. hidden_states = hidden_states * padding_mask
  845. if self.padding is not None:
  846. hidden_states = nn.functional.pad(hidden_states, self.padding)
  847. hidden_states = self.conv_1(hidden_states)
  848. hidden_states = self.act_fn(hidden_states)
  849. hidden_states = self.dropout(hidden_states)
  850. hidden_states = hidden_states * padding_mask
  851. if self.padding is not None:
  852. hidden_states = nn.functional.pad(hidden_states, self.padding)
  853. hidden_states = self.conv_2(hidden_states)
  854. hidden_states = hidden_states * padding_mask
  855. hidden_states = hidden_states.permute(0, 2, 1)
  856. return hidden_states
  857. class VitsEncoderLayer(GradientCheckpointingLayer):
  858. def __init__(self, config: VitsConfig):
  859. super().__init__()
  860. self.attention = VitsAttention(config)
  861. self.dropout = nn.Dropout(config.hidden_dropout)
  862. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  863. self.feed_forward = VitsFeedForward(config)
  864. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  865. def forward(
  866. self,
  867. hidden_states: torch.Tensor,
  868. padding_mask: torch.FloatTensor,
  869. attention_mask: torch.Tensor | None = None,
  870. output_attentions: bool = False,
  871. ):
  872. residual = hidden_states
  873. hidden_states, attn_weights = self.attention(
  874. hidden_states=hidden_states,
  875. attention_mask=attention_mask,
  876. output_attentions=output_attentions,
  877. )
  878. hidden_states = self.dropout(hidden_states)
  879. hidden_states = self.layer_norm(residual + hidden_states)
  880. residual = hidden_states
  881. hidden_states = self.feed_forward(hidden_states, padding_mask)
  882. hidden_states = self.dropout(hidden_states)
  883. hidden_states = self.final_layer_norm(residual + hidden_states)
  884. outputs = (hidden_states,)
  885. if output_attentions:
  886. outputs += (attn_weights,)
  887. return outputs
  888. class VitsEncoder(nn.Module):
  889. def __init__(self, config: VitsConfig):
  890. super().__init__()
  891. self.config = config
  892. self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  893. self.gradient_checkpointing = False
  894. self.layerdrop = config.layerdrop
  895. def forward(
  896. self,
  897. hidden_states: torch.FloatTensor,
  898. padding_mask: torch.FloatTensor,
  899. attention_mask: torch.Tensor | None = None,
  900. output_attentions: bool | None = None,
  901. output_hidden_states: bool | None = None,
  902. return_dict: bool | None = None,
  903. ) -> tuple | BaseModelOutput:
  904. all_hidden_states = () if output_hidden_states else None
  905. all_self_attentions = () if output_attentions else None
  906. attention_mask = create_bidirectional_mask(
  907. config=self.config,
  908. inputs_embeds=hidden_states,
  909. attention_mask=attention_mask,
  910. )
  911. hidden_states = hidden_states * padding_mask
  912. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  913. for encoder_layer in self.layers:
  914. if output_hidden_states:
  915. all_hidden_states = all_hidden_states + (hidden_states,)
  916. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  917. dropout_probability = np.random.uniform(0, 1)
  918. skip_the_layer = self.training and (dropout_probability < self.layerdrop)
  919. if not skip_the_layer or synced_gpus:
  920. # under fsdp or deepspeed zero3 all gpus must run in sync
  921. layer_outputs = encoder_layer(
  922. hidden_states,
  923. attention_mask=attention_mask,
  924. padding_mask=padding_mask,
  925. output_attentions=output_attentions,
  926. )
  927. hidden_states = layer_outputs[0]
  928. if skip_the_layer:
  929. layer_outputs = (None, None)
  930. if output_attentions:
  931. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  932. hidden_states = hidden_states * padding_mask
  933. if output_hidden_states:
  934. all_hidden_states = all_hidden_states + (hidden_states,)
  935. if not return_dict:
  936. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  937. return BaseModelOutput(
  938. last_hidden_state=hidden_states,
  939. hidden_states=all_hidden_states,
  940. attentions=all_self_attentions,
  941. )
  942. class VitsTextEncoder(nn.Module):
  943. """
  944. Transformer encoder that uses relative positional representation instead of absolute positional encoding.
  945. """
  946. def __init__(self, config: VitsConfig):
  947. super().__init__()
  948. self.config = config
  949. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
  950. self.encoder = VitsEncoder(config)
  951. self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
  952. def forward(
  953. self,
  954. input_ids: torch.Tensor,
  955. padding_mask: torch.FloatTensor,
  956. attention_mask: torch.Tensor | None = None,
  957. output_attentions: bool | None = None,
  958. output_hidden_states: bool | None = None,
  959. return_dict: bool | None = True,
  960. ) -> tuple[torch.Tensor] | VitsTextEncoderOutput:
  961. hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
  962. encoder_outputs = self.encoder(
  963. hidden_states=hidden_states,
  964. padding_mask=padding_mask,
  965. attention_mask=attention_mask,
  966. output_attentions=output_attentions,
  967. output_hidden_states=output_hidden_states,
  968. return_dict=return_dict,
  969. )
  970. last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state
  971. stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask
  972. prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
  973. if not return_dict:
  974. outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:]
  975. return outputs
  976. return VitsTextEncoderOutput(
  977. last_hidden_state=last_hidden_state,
  978. prior_means=prior_means,
  979. prior_log_variances=prior_log_variances,
  980. hidden_states=encoder_outputs.hidden_states,
  981. attentions=encoder_outputs.attentions,
  982. )
  983. @auto_docstring
  984. class VitsPreTrainedModel(PreTrainedModel):
  985. config: VitsConfig
  986. base_model_prefix = "vits"
  987. main_input_name = "input_ids"
  988. supports_gradient_checkpointing = True
  989. @torch.no_grad()
  990. def _init_weights(self, module: nn.Module):
  991. """Initialize the weights"""
  992. std = self.config.initializer_range
  993. if isinstance(module, nn.Linear):
  994. init.normal_(module.weight, mean=0.0, std=std)
  995. if module.bias is not None:
  996. init.zeros_(module.bias)
  997. elif isinstance(module, nn.LayerNorm):
  998. init.zeros_(module.bias)
  999. init.ones_(module.weight)
  1000. elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
  1001. init.kaiming_normal_(module.weight)
  1002. if module.bias is not None:
  1003. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  1004. init.uniform_(module.bias, a=-k, b=k)
  1005. elif isinstance(module, nn.Embedding):
  1006. init.normal_(module.weight, mean=0.0, std=std)
  1007. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  1008. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  1009. init.zeros_(module.weight[module.padding_idx])
  1010. elif isinstance(module, VitsAttention):
  1011. if self.config.window_size:
  1012. head_dim = self.config.hidden_size // self.config.num_attention_heads
  1013. init.normal_(module.emb_rel_k, std=head_dim**-0.5)
  1014. init.normal_(module.emb_rel_v, std=head_dim**-0.5)
  1015. elif isinstance(module, VitsElementwiseAffine):
  1016. init.zeros_(module.translate)
  1017. init.zeros_(module.log_scale)
  1018. @auto_docstring(
  1019. custom_intro="""
  1020. The complete VITS model, for text-to-speech synthesis.
  1021. """
  1022. )
  1023. class VitsModel(VitsPreTrainedModel):
  1024. def __init__(self, config: VitsConfig):
  1025. super().__init__(config)
  1026. self.config = config
  1027. self.text_encoder = VitsTextEncoder(config)
  1028. self.flow = VitsResidualCouplingBlock(config)
  1029. self.decoder = VitsHifiGan(config)
  1030. if config.use_stochastic_duration_prediction:
  1031. self.duration_predictor = VitsStochasticDurationPredictor(config)
  1032. else:
  1033. self.duration_predictor = VitsDurationPredictor(config)
  1034. if config.num_speakers > 1:
  1035. self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)
  1036. # This is used only for training.
  1037. self.posterior_encoder = VitsPosteriorEncoder(config)
  1038. # These parameters control the synthesised speech properties
  1039. self.speaking_rate = config.speaking_rate
  1040. self.noise_scale = config.noise_scale
  1041. self.noise_scale_duration = config.noise_scale_duration
  1042. # Initialize weights and apply final processing
  1043. self.post_init()
  1044. @auto_docstring
  1045. def forward(
  1046. self,
  1047. input_ids: torch.Tensor | None = None,
  1048. attention_mask: torch.Tensor | None = None,
  1049. speaker_id: int | None = None,
  1050. output_attentions: bool | None = None,
  1051. output_hidden_states: bool | None = None,
  1052. return_dict: bool | None = None,
  1053. labels: torch.FloatTensor | None = None,
  1054. speaking_rate: float | None = None,
  1055. **kwargs,
  1056. ) -> tuple[Any] | VitsModelOutput:
  1057. r"""
  1058. speaker_id (`int`, *optional*):
  1059. Which speaker embedding to use. Only used for multispeaker models.
  1060. labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
  1061. Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
  1062. computation.
  1063. speaking_rate (`float`, *optional*):
  1064. Speaking rate.
  1065. Example:
  1066. ```python
  1067. >>> from transformers import VitsTokenizer, VitsModel, set_seed
  1068. >>> import torch
  1069. >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
  1070. >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")
  1071. >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")
  1072. >>> set_seed(555) # make deterministic
  1073. >>> with torch.no_grad():
  1074. ... outputs = model(inputs["input_ids"])
  1075. >>> outputs.waveform.shape
  1076. torch.Size([1, 45824])
  1077. ```
  1078. """
  1079. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1080. output_hidden_states = (
  1081. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1082. )
  1083. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1084. if labels is not None:
  1085. raise NotImplementedError("Training of VITS is not supported yet.")
  1086. mask_dtype = self.text_encoder.embed_tokens.weight.dtype
  1087. if attention_mask is not None:
  1088. input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
  1089. else:
  1090. input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
  1091. if self.config.num_speakers > 1 and speaker_id is not None:
  1092. if not 0 <= speaker_id < self.config.num_speakers:
  1093. raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
  1094. if isinstance(speaker_id, int):
  1095. speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
  1096. speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
  1097. else:
  1098. speaker_embeddings = None
  1099. text_encoder_output = self.text_encoder(
  1100. input_ids=input_ids,
  1101. padding_mask=input_padding_mask,
  1102. attention_mask=attention_mask,
  1103. output_attentions=output_attentions,
  1104. output_hidden_states=output_hidden_states,
  1105. return_dict=return_dict,
  1106. )
  1107. hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
  1108. hidden_states = hidden_states.transpose(1, 2)
  1109. input_padding_mask = input_padding_mask.transpose(1, 2)
  1110. prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
  1111. prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
  1112. if self.config.use_stochastic_duration_prediction:
  1113. log_duration = self.duration_predictor(
  1114. hidden_states,
  1115. input_padding_mask,
  1116. speaker_embeddings,
  1117. reverse=True,
  1118. noise_scale=self.noise_scale_duration,
  1119. )
  1120. else:
  1121. log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
  1122. if speaking_rate is None:
  1123. speaking_rate = self.speaking_rate
  1124. length_scale = 1.0 / speaking_rate
  1125. duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
  1126. predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
  1127. # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
  1128. indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
  1129. output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
  1130. output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
  1131. # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
  1132. attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
  1133. batch_size, _, output_length, input_length = attn_mask.shape
  1134. cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
  1135. indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
  1136. valid_indices = indices.unsqueeze(0) < cum_duration
  1137. valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
  1138. padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
  1139. attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
  1140. # Expand prior distribution
  1141. prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
  1142. prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
  1143. prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
  1144. latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
  1145. spectrogram = latents * output_padding_mask
  1146. waveform = self.decoder(spectrogram, speaker_embeddings)
  1147. waveform = waveform.squeeze(1)
  1148. sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
  1149. if not return_dict:
  1150. outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
  1151. return outputs
  1152. return VitsModelOutput(
  1153. waveform=waveform,
  1154. sequence_lengths=sequence_lengths,
  1155. spectrogram=spectrogram,
  1156. hidden_states=text_encoder_output.hidden_states,
  1157. attentions=text_encoder_output.attentions,
  1158. )
# Names exported by `from <this module> import *`.
__all__ = ["VitsModel", "VitsPreTrainedModel"]