| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432 |
- # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch UniSpeech model."""
- import math
- from dataclasses import dataclass
- import torch
- import torch.nn as nn
- from ... import initialization as init
- from ...modeling_outputs import ModelOutput, Wav2Vec2BaseModelOutput
- from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring, logging
- from ..wav2vec2.modeling_wav2vec2 import (
- Wav2Vec2Encoder,
- Wav2Vec2EncoderStableLayerNorm,
- Wav2Vec2FeatureEncoder,
- Wav2Vec2FeatureProjection,
- Wav2Vec2ForCTC,
- Wav2Vec2ForSequenceClassification,
- Wav2Vec2GumbelVectorQuantizer,
- Wav2Vec2Model,
- Wav2Vec2PositionalConvEmbedding,
- )
- from .configuration_unispeech import UniSpeechConfig
# Module-level logger from the library's `logging` utilities (imported from `...utils`).
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`UniSpeechForPreTraining`], with potential hidden states and attentions.
    """
)
class UniSpeechForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, `torch.FloatTensor`):
        Total pre-training loss. [`UniSpeechForPreTraining`] does not compute a loss yet, so this field is
        currently always `None`.
    projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.hidden_size)`):
        Last hidden states of the UniSpeech encoder.
    projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.hidden_size)`):
        Quantized extracted feature vectors, projected first to *config.proj_codevector_dim* and then to
        *config.hidden_size*.
    codevector_perplexity (`torch.FloatTensor`, 0-dim):
        The perplexity of the codevector distribution, used to measure the diversity of the codebook.
    """

    loss: torch.FloatTensor | None = None
    projected_states: torch.FloatTensor | None = None
    projected_quantized_states: torch.FloatTensor | None = None
    codevector_perplexity: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
class UniSpeechPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
    """UniSpeech-named subclass of `Wav2Vec2PositionalConvEmbedding`; adds no behavior of its own."""

    pass
class UniSpeechFeatureEncoder(Wav2Vec2FeatureEncoder):
    """UniSpeech-named subclass of `Wav2Vec2FeatureEncoder`; adds no behavior of its own."""

    pass
class UniSpeechFeatureProjection(Wav2Vec2FeatureProjection):
    """UniSpeech-named subclass of `Wav2Vec2FeatureProjection`; adds no behavior of its own."""

    pass
class UniSpeechEncoder(Wav2Vec2Encoder):
    """UniSpeech-named subclass of `Wav2Vec2Encoder`; adds no behavior of its own."""

    pass
class UniSpeechEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
    """UniSpeech-named subclass of `Wav2Vec2EncoderStableLayerNorm`; adds no behavior of its own."""

    pass
class UniSpeechGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
    """
    Gumbel-softmax vector quantizer used for UniSpeech pre-training.

    Overrides the perplexity computation and the forward pass of `Wav2Vec2GumbelVectorQuantizer`; `forward` takes
    only `hidden_states`. It reads `weight_proj`, `codevectors`, `num_groups`, `num_vars` and `temperature`,
    which are expected to be set up by the parent class (not shown here).
    """

    @staticmethod
    def _compute_perplexity(probs):
        """
        Compute the perplexity of the frame-averaged code distribution, summed over groups.

        Args:
            probs (`torch.Tensor`): code probabilities laid out as `(num_frames, num_groups, num_vars)`.

        Returns:
            `torch.Tensor`: 0-dim tensor, `sum_g exp(H(mean_over_frames(probs)[g]))`.
        """
        marginal_probs = probs.mean(dim=0)
        # xlogy defines 0 * log(0) as 0, so codes that are never selected do not produce NaNs
        entropy = -torch.sum(torch.xlogy(marginal_probs, marginal_probs), dim=-1)
        return torch.exp(entropy).sum()

    def forward(self, hidden_states):
        """
        Quantize `hidden_states` into codevectors.

        In training mode codes are sampled with a hard Gumbel-softmax (differentiable); otherwise the argmax code
        of every group is selected.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Features to quantize.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: the selected codevectors, reshaped to
            `(batch_size, sequence_length, -1)` (the per-group selections are concatenated along the last axis),
            and the 0-dim codevector perplexity.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        num_frames = batch_size * sequence_length

        # project to codevector dim, then give every (frame, group) pair its own row of logits
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(num_frames * self.num_groups, -1)

        if self.training:
            # sample one-hot code vector probs via gumbel-softmax in a differentiable way
            codevector_probs = nn.functional.gumbel_softmax(
                hidden_states.float(), tau=self.temperature, hard=True
            ).type_as(hidden_states)

            # perplexity is measured on the soft (un-sampled) distribution
            codevector_soft_dist = torch.softmax(hidden_states.view(num_frames, self.num_groups, -1).float(), dim=-1)
            perplexity = self._compute_perplexity(codevector_soft_dist)
        else:
            # take the argmax in a non-differentiable way and build a hard (one-hot) code distribution
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
                -1, codevector_idx.view(-1, 1), 1.0
            )
            codevector_probs = codevector_probs.view(num_frames, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs)

        codevector_probs = codevector_probs.view(num_frames, -1)
        # weight every codevector by its (one-hot) probability, then sum over the variables of each group
        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
        codevectors = codevectors_per_group.view(num_frames, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity
@auto_docstring
class UniSpeechPreTrainedModel(PreTrainedModel):
    config: UniSpeechConfig
    base_model_prefix = "unispeech"
    main_input_name = "input_values"
    input_modalities = "audio"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the parameters of `module` according to its type.

        The checks run in a fixed order and only the first matching branch is applied.
        """
        # The Gumbel quantizer needs a dedicated scheme: standard-normal projection weights,
        # zero projection bias and a uniformly initialized codebook.
        if isinstance(module, UniSpeechGumbelVectorQuantizer):
            init.normal_(module.weight_proj.weight, mean=0.0, std=1)
            init.zeros_(module.weight_proj.bias)
            init.uniform_(module.codevectors)
            return

        if isinstance(module, UniSpeechPositionalConvEmbedding):
            fan_in = module.conv.kernel_size[0] * module.conv.in_channels
            init.normal_(module.conv.weight, mean=0, std=2 * math.sqrt(1 / fan_in))
            init.constant_(module.conv.bias, 0)
            return

        if isinstance(module, UniSpeechFeatureProjection):
            bound = math.sqrt(1 / module.projection.in_features)
            init.uniform_(module.projection.weight, a=-bound, b=bound)
            init.uniform_(module.projection.bias, a=-bound, b=bound)
            return

        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
            return

        if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            return

        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight)
            if module.bias is not None:
                bound = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                init.uniform_(module.bias, a=-bound, b=bound)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
        """
        Map raw input lengths to the number of frames produced by the stack of feature-extractor convolutions.
        """
        lengths = input_lengths
        for kernel, step in zip(self.config.conv_kernel, self.config.conv_stride):
            # Conv1d output-length formula, assuming no padding and dilation 1:
            # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            lengths = torch.div(lengths - kernel, step, rounding_mode="floor") + 1
        return lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """
        Convert a sample-level attention mask into a boolean mask with one entry per extracted feature vector.
        """
        # cumsum(...)[:, -1] equals sum(-1) but is not in-place, so it can also run under inference mode
        raw_lengths = attention_mask.cumsum(dim=-1)[:, -1]
        feature_lengths = self._get_feat_extract_output_lengths(raw_lengths).to(torch.long)

        num_rows = attention_mask.shape[0]
        device = attention_mask.device
        reduced_mask = torch.zeros((num_rows, feature_vector_length), dtype=attention_mask.dtype, device=device)

        # mark the last valid position of each row; the reversed cumulative sum then spreads
        # that mark to every earlier position, so all frames before the output length are attended to
        reduced_mask[(torch.arange(num_rows, device=device), feature_lengths - 1)] = 1
        return reduced_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# UniSpeech reuses the Wav2Vec2 base-model output container unchanged; the alias gives it a
# model-specific name, used as the return type of `UniSpeechModel.forward`.
UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model):
    def __init__(self, config: UniSpeechConfig):
        # Only the UniSpeechPreTrainedModel initializer is called (not Wav2Vec2Model.__init__);
        # every sub-module below is built from the UniSpeech classes.
        UniSpeechPreTrainedModel.__init__(self, config)
        self.config = config
        self.feature_extractor = UniSpeechFeatureEncoder(config)
        self.feature_projection = UniSpeechFeatureProjection(config)

        # learned embedding only allocated when some form of masking is configured
        # (presumably consumed by the inherited `_mask_hidden_states`, not shown here)
        uses_masking = config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0
        if uses_masking:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        encoder_class = UniSpeechEncoderStableLayerNorm if config.do_stable_layer_norm else UniSpeechEncoder
        self.encoder = encoder_class(config)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        # Overrides the inherited Wav2Vec2Model method so it is not usable on UniSpeechModel.
        raise AttributeError("Not needed for UniSpeech")

    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        mask_time_indices: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | UniSpeechBaseModelOutput:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        """
        config = self.config
        # fall back to the config value for every flag the caller left unset
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.return_dict

        # swap the last two axes: afterwards dim 1 is the feature-vector (frame) length used below
        conv_features = self.feature_extractor(input_values).transpose(1, 2)

        if attention_mask is not None:
            # reduce the sample-level mask to one entry per feature vector
            attention_mask = self._get_feature_vector_attention_mask(conv_features.shape[1], attention_mask)

        hidden_states, extract_features = self.feature_projection(conv_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]

        if return_dict:
            return UniSpeechBaseModelOutput(
                last_hidden_state=last_hidden_state,
                extract_features=extract_features,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, extract_features) + encoder_outputs[1:]
@auto_docstring(
    custom_intro="""
    UniSpeech Model with a vector-quantization module and a CTC projection head for pre-training.
    """
)
class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
    def __init__(self, config: UniSpeechConfig):
        super().__init__(config)
        self.unispeech = UniSpeechModel(config)
        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)

        self.quantizer = UniSpeechGumbelVectorQuantizer(config)
        # quantized features are projected codevector_dim -> proj_codevector_dim -> hidden_size
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
        self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)

        self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
        self.dropout = nn.Dropout(config.final_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def set_gumbel_temperature(self, temperature: float):
        """
        Set the Gumbel softmax temperature to a given value. Only necessary for training.

        Args:
            temperature (`float`):
                New value for `self.quantizer.temperature`, which the quantizer passes as `tau` to
                `nn.functional.gumbel_softmax` in training mode.
        """
        self.quantizer.temperature = temperature

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.unispeech.feature_extractor._freeze_parameters()

    @staticmethod
    def compute_contrastive_logits(
        target_features: torch.FloatTensor,
        negative_features: torch.FloatTensor,
        predicted_features: torch.FloatTensor,
        temperature: float = 1,
    ):
        """
        Compute logits for contrastive loss using cosine similarity as the distance measure between
        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.

        `target_features` and `negative_features` are concatenated along dim 0 (positives first), and the cosine
        similarity is taken over the last axis in float32 before being cast back to the input dtype.

        Args:
            target_features (`torch.FloatTensor`): positive targets.
            negative_features (`torch.FloatTensor`): negative targets, concatenated after the positives on dim 0.
            predicted_features (`torch.FloatTensor`): predictions, broadcast against the concatenated targets.
            temperature (`float`, *optional*, defaults to 1): divisor applied to the similarities.

        Returns:
            `torch.FloatTensor`: the temperature-scaled cosine similarities.
        """
        target_features = torch.cat([target_features, negative_features], dim=0)

        logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
        logits = logits.type_as(target_features)

        # apply temperature
        logits = logits / temperature
        return logits

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | UniSpeechForPreTrainingOutput:
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> # TODO: Add full pretraining example
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.unispeech(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        transformer_features = outputs[0]

        # quantize all (unmasked) extracted features and project to final vq dim
        extract_features = self.dropout_features(outputs[1])
        quantized_features, codevector_perplexity = self.quantizer(extract_features)

        # project quantized features twice
        quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
        quantized_features = self.project_hid(quantized_features)

        # Per (batch, time) position, draw a Bernoulli(replace_prob) sample deciding whether the
        # transformer feature is replaced by the quantized feature. The matrix is sampled on CPU
        # and then moved to the features' device.
        prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_(
            self.config.replace_prob
        )
        prob_replace_matrix = prob_replace_matrix.transpose(0, 1)
        sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device)
        sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1)
        sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1)
        logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + (
            quantized_features.masked_fill(~sampled_replace_matrix, 0.0)
        )

        # project to ctc units
        # NOTE: `logits` is not part of the returned outputs yet (see TODO below); the sampling,
        # dropout and projection above still run and consume RNG state.
        logits = self.dropout(logits)
        logits = self.ctc_proj(logits)

        # TODO(PVP) - add negative sampling & loss computation
        loss = None
        if not return_dict:
            if loss is not None:
                return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
            return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]

        return UniSpeechForPreTrainingOutput(
            loss=loss,
            projected_states=transformer_features,
            projected_quantized_states=quantized_features,
            codevector_perplexity=codevector_perplexity,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class UniSpeechForCTC(Wav2Vec2ForCTC):
    """UniSpeech-named subclass of `Wav2Vec2ForCTC`; adds no behavior of its own."""

    pass
class UniSpeechForSequenceClassification(Wav2Vec2ForSequenceClassification):
    """UniSpeech-named subclass of `Wav2Vec2ForSequenceClassification`; adds no behavior of its own."""

    pass
# Public API of this module (the names exported by `from ... import *`).
__all__ = [
    "UniSpeechForCTC",
    "UniSpeechForPreTraining",
    "UniSpeechForSequenceClassification",
    "UniSpeechModel",
    "UniSpeechPreTrainedModel",
]
|