modular_unispeech.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch UniSpeech model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. import torch.nn as nn
  19. from ... import initialization as init
  20. from ...modeling_outputs import ModelOutput, Wav2Vec2BaseModelOutput
  21. from ...modeling_utils import PreTrainedModel
  22. from ...utils import auto_docstring, logging
  23. from ..wav2vec2.modeling_wav2vec2 import (
  24. Wav2Vec2Encoder,
  25. Wav2Vec2EncoderStableLayerNorm,
  26. Wav2Vec2FeatureEncoder,
  27. Wav2Vec2FeatureProjection,
  28. Wav2Vec2ForCTC,
  29. Wav2Vec2ForSequenceClassification,
  30. Wav2Vec2GumbelVectorQuantizer,
  31. Wav2Vec2Model,
  32. Wav2Vec2PositionalConvEmbedding,
  33. )
  34. from .configuration_unispeech import UniSpeechConfig
  35. logger = logging.get_logger(__name__)
  36. @dataclass
  37. @auto_docstring(
  38. custom_intro="""
  39. Output type of [`UniSpeechForPreTrainingOutput`], with potential hidden states and attentions.
  40. """
  41. )
  42. class UniSpeechForPreTrainingOutput(ModelOutput):
  43. r"""
  44. loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
  45. Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
  46. paper](https://huggingface.co/papers/2006.11477).
  47. projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  48. Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
  49. projected quantized states.
  50. projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  51. Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
  52. target vectors for contrastive loss.
  53. codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
  54. The perplexity of the codevector distribution, used to measure the diversity of the codebook.
  55. """
  56. loss: torch.FloatTensor | None = None
  57. projected_states: torch.FloatTensor | None = None
  58. projected_quantized_states: torch.FloatTensor | None = None
  59. codevector_perplexity: torch.FloatTensor | None = None
  60. hidden_states: tuple[torch.FloatTensor] | None = None
  61. attentions: tuple[torch.FloatTensor] | None = None
  62. class UniSpeechPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
  63. pass
  64. class UniSpeechFeatureEncoder(Wav2Vec2FeatureEncoder):
  65. pass
  66. class UniSpeechFeatureProjection(Wav2Vec2FeatureProjection):
  67. pass
  68. class UniSpeechEncoder(Wav2Vec2Encoder):
  69. pass
  70. class UniSpeechEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
  71. pass
  72. class UniSpeechGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
  73. @staticmethod
  74. def _compute_perplexity(probs):
  75. marginal_probs = probs.mean(dim=0)
  76. perplexity = torch.exp(-torch.sum(torch.xlogy(marginal_probs, marginal_probs), dim=-1)).sum()
  77. return perplexity
  78. def forward(self, hidden_states):
  79. batch_size, sequence_length, hidden_size = hidden_states.shape
  80. # project to codevector dim
  81. hidden_states = self.weight_proj(hidden_states)
  82. hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
  83. if self.training:
  84. # sample code vector probs via gumbel in differentiateable way
  85. codevector_probs = nn.functional.gumbel_softmax(
  86. hidden_states.float(), tau=self.temperature, hard=True
  87. ).type_as(hidden_states)
  88. # compute perplexity
  89. codevector_soft_dist = torch.softmax(
  90. hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
  91. )
  92. perplexity = self._compute_perplexity(codevector_soft_dist)
  93. else:
  94. # take argmax in non-differentiable way
  95. # comptute hard codevector distribution (one hot)
  96. codevector_idx = hidden_states.argmax(dim=-1)
  97. codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
  98. -1, codevector_idx.view(-1, 1), 1.0
  99. )
  100. codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
  101. perplexity = self._compute_perplexity(codevector_probs)
  102. codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
  103. # use probs to retrieve codevectors
  104. codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
  105. codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
  106. codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
  107. return codevectors, perplexity
  108. @auto_docstring
  109. class UniSpeechPreTrainedModel(PreTrainedModel):
  110. config: UniSpeechConfig
  111. base_model_prefix = "unispeech"
  112. main_input_name = "input_values"
  113. input_modalities = "audio"
  114. supports_gradient_checkpointing = True
  115. _supports_flash_attn = True
  116. _supports_sdpa = True
  117. _supports_flex_attn = True
  118. @torch.no_grad()
  119. def _init_weights(self, module):
  120. """Initialize the weights"""
  121. # gumbel softmax requires special init
  122. if isinstance(module, UniSpeechGumbelVectorQuantizer):
  123. init.normal_(module.weight_proj.weight, mean=0.0, std=1)
  124. init.zeros_(module.weight_proj.bias)
  125. init.uniform_(module.codevectors)
  126. elif isinstance(module, UniSpeechPositionalConvEmbedding):
  127. init.normal_(
  128. module.conv.weight,
  129. mean=0,
  130. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  131. )
  132. init.constant_(module.conv.bias, 0)
  133. elif isinstance(module, UniSpeechFeatureProjection):
  134. k = math.sqrt(1 / module.projection.in_features)
  135. init.uniform_(module.projection.weight, a=-k, b=k)
  136. init.uniform_(module.projection.bias, a=-k, b=k)
  137. elif isinstance(module, nn.Linear):
  138. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  139. if module.bias is not None:
  140. init.zeros_(module.bias)
  141. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  142. init.zeros_(module.bias)
  143. init.ones_(module.weight)
  144. elif isinstance(module, nn.Conv1d):
  145. init.kaiming_normal_(module.weight)
  146. if module.bias is not None:
  147. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  148. init.uniform_(module.bias, a=-k, b=k)
  149. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
  150. """
  151. Computes the output length of the convolutional layers
  152. """
  153. def _conv_out_length(input_length, kernel_size, stride):
  154. # 1D convolutional layer output length formula taken
  155. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  156. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  157. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  158. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  159. return input_lengths
  160. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  161. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  162. # on inference mode.
  163. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  164. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
  165. batch_size = attention_mask.shape[0]
  166. attention_mask = torch.zeros(
  167. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  168. )
  169. # these two operations makes sure that all values before the output lengths idxs are attended to
  170. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  171. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  172. return attention_mask
  173. UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
  174. class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model):
  175. def __init__(self, config: UniSpeechConfig):
  176. UniSpeechPreTrainedModel.__init__(self, config)
  177. self.config = config
  178. self.feature_extractor = UniSpeechFeatureEncoder(config)
  179. self.feature_projection = UniSpeechFeatureProjection(config)
  180. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  181. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  182. if config.do_stable_layer_norm:
  183. self.encoder = UniSpeechEncoderStableLayerNorm(config)
  184. else:
  185. self.encoder = UniSpeechEncoder(config)
  186. # Initialize weights and apply final processing
  187. self.post_init()
  188. def freeze_feature_encoder(self):
  189. raise AttributeError("Not needed for UniSpeech")
  190. def forward(
  191. self,
  192. input_values: torch.Tensor | None,
  193. attention_mask: torch.Tensor | None = None,
  194. mask_time_indices: torch.FloatTensor | None = None,
  195. output_attentions: bool | None = None,
  196. output_hidden_states: bool | None = None,
  197. return_dict: bool | None = None,
  198. **kwargs,
  199. ) -> tuple | UniSpeechBaseModelOutput:
  200. r"""
  201. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  202. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  203. masked extracted features in *config.proj_codevector_dim* space.
  204. """
  205. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  206. output_hidden_states = (
  207. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  208. )
  209. return_dict = return_dict if return_dict is not None else self.config.return_dict
  210. extract_features = self.feature_extractor(input_values)
  211. extract_features = extract_features.transpose(1, 2)
  212. if attention_mask is not None:
  213. # compute reduced attention_mask corresponding to feature vectors
  214. attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
  215. hidden_states, extract_features = self.feature_projection(extract_features)
  216. hidden_states = self._mask_hidden_states(
  217. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  218. )
  219. encoder_outputs = self.encoder(
  220. hidden_states,
  221. attention_mask=attention_mask,
  222. output_attentions=output_attentions,
  223. output_hidden_states=output_hidden_states,
  224. return_dict=return_dict,
  225. )
  226. hidden_states = encoder_outputs[0]
  227. if not return_dict:
  228. return (hidden_states, extract_features) + encoder_outputs[1:]
  229. return UniSpeechBaseModelOutput(
  230. last_hidden_state=hidden_states,
  231. extract_features=extract_features,
  232. hidden_states=encoder_outputs.hidden_states,
  233. attentions=encoder_outputs.attentions,
  234. )
  235. @auto_docstring(
  236. custom_intro="""
  237. UniSpeech Model with a vector-quantization module and ctc loss for pre-training.
  238. """
  239. )
  240. class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
  241. def __init__(self, config: UniSpeechConfig):
  242. super().__init__(config)
  243. self.unispeech = UniSpeechModel(config)
  244. self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
  245. self.quantizer = UniSpeechGumbelVectorQuantizer(config)
  246. self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
  247. self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)
  248. self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
  249. self.dropout = nn.Dropout(config.final_dropout)
  250. # Initialize weights and apply final processing
  251. self.post_init()
  252. def set_gumbel_temperature(self, temperature: int):
  253. """
  254. Set the Gumbel softmax temperature to a given value. Only necessary for training
  255. """
  256. self.quantizer.temperature = temperature
  257. def freeze_feature_encoder(self):
  258. """
  259. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  260. not be updated during training.
  261. """
  262. self.unispeech.feature_extractor._freeze_parameters()
  263. @staticmethod
  264. def compute_contrastive_logits(
  265. target_features: torch.FloatTensor,
  266. negative_features: torch.FloatTensor,
  267. predicted_features: torch.FloatTensor,
  268. temperature: int = 1,
  269. ):
  270. """
  271. Compute logits for contrastive loss based using cosine similarity as the distance measure between
  272. `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
  273. """
  274. target_features = torch.cat([target_features, negative_features], dim=0)
  275. logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
  276. logits = logits.type_as(target_features)
  277. # apply temperature
  278. logits = logits / temperature
  279. return logits
  280. @auto_docstring
  281. def forward(
  282. self,
  283. input_values: torch.Tensor | None,
  284. attention_mask: torch.Tensor | None = None,
  285. output_attentions: bool | None = None,
  286. output_hidden_states: bool | None = None,
  287. return_dict: bool | None = None,
  288. **kwargs,
  289. ) -> tuple | UniSpeechForPreTrainingOutput:
  290. r"""
  291. Example:
  292. ```python
  293. >>> import torch
  294. >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining
  295. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
  296. >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
  297. >>> # TODO: Add full pretraining example
  298. ```"""
  299. return_dict = return_dict if return_dict is not None else self.config.return_dict
  300. outputs = self.unispeech(
  301. input_values,
  302. attention_mask=attention_mask,
  303. output_attentions=output_attentions,
  304. output_hidden_states=output_hidden_states,
  305. return_dict=return_dict,
  306. )
  307. transformer_features = outputs[0]
  308. # quantize all (unmasked) extracted features and project to final vq dim
  309. extract_features = self.dropout_features(outputs[1])
  310. quantized_features, codevector_perplexity = self.quantizer(extract_features)
  311. # project quantized features twice
  312. quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
  313. quantized_features = self.project_hid(quantized_features)
  314. prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_(
  315. self.config.replace_prob
  316. )
  317. prob_replace_matrix = prob_replace_matrix.transpose(0, 1)
  318. sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device)
  319. sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1)
  320. sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1)
  321. logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + (
  322. quantized_features.masked_fill(~sampled_replace_matrix, 0.0)
  323. )
  324. # project to ctc units
  325. logits = self.dropout(logits)
  326. logits = self.ctc_proj(logits)
  327. # TODO(PVP) - add negative sampling & loss computation
  328. loss = None
  329. if not return_dict:
  330. if loss is not None:
  331. return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  332. return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  333. return UniSpeechForPreTrainingOutput(
  334. loss=loss,
  335. projected_states=transformer_features,
  336. projected_quantized_states=quantized_features,
  337. codevector_perplexity=codevector_perplexity,
  338. hidden_states=outputs.hidden_states,
  339. attentions=outputs.attentions,
  340. )
  341. class UniSpeechForCTC(Wav2Vec2ForCTC):
  342. pass
  343. class UniSpeechForSequenceClassification(Wav2Vec2ForSequenceClassification):
  344. pass
  345. __all__ = [
  346. "UniSpeechForCTC",
  347. "UniSpeechForPreTraining",
  348. "UniSpeechForSequenceClassification",
  349. "UniSpeechModel",
  350. "UniSpeechPreTrainedModel",
  351. ]