| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449 |
- # Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch SEW model."""
- import math
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...integrations.deepspeed import is_deepspeed_zero3_enabled
- from ...integrations.fsdp import is_fsdp_managed_module
- from ...modeling_outputs import BaseModelOutput
- from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring
- from ...utils.generic import is_flash_attention_requested
- from ..wav2vec2.modeling_wav2vec2 import (
- Wav2Vec2Attention,
- Wav2Vec2EncoderLayer,
- Wav2Vec2FeatureEncoder,
- Wav2Vec2FeedForward,
- Wav2Vec2ForCTC,
- Wav2Vec2ForSequenceClassification,
- Wav2Vec2GroupNormConvLayer,
- Wav2Vec2LayerNormConvLayer,
- Wav2Vec2NoLayerNormConvLayer,
- Wav2Vec2SamePadLayer,
- _compute_mask_indices,
- )
- from .configuration_sew import SEWConfig
# Module-level constant; nothing in the code visible in this file reads it.
# NOTE(review): presumably consumed by the heads inherited from wav2vec2 — confirm before removing.
_HIDDEN_STATES_START_POSITION = 1
class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer):
    # SEW-named subclass of Wav2Vec2NoLayerNormConvLayer; adds and overrides nothing.
    pass
class SEWLayerNormConvLayer(Wav2Vec2LayerNormConvLayer):
    # SEW-named subclass of Wav2Vec2LayerNormConvLayer; adds and overrides nothing.
    pass
class SEWGroupNormConvLayer(Wav2Vec2GroupNormConvLayer):
    # SEW-named subclass of Wav2Vec2GroupNormConvLayer; adds and overrides nothing.
    pass
class SEWPositionalConvEmbedding(nn.Module):
    """Weight-normalised, grouped, strided 1D convolution used as a positional embedding.

    The convolution maps `config.hidden_size` channels to `config.hidden_size` channels with
    kernel size `config.num_conv_pos_embeddings`, `config.num_conv_pos_embedding_groups`
    groups and stride `config.squeeze_factor`. Its output goes through a `SEWSamePadLayer`
    and the activation named by `config.feat_extract_activation`.
    """

    def __init__(self, config):
        super().__init__()
        kernel_size = config.num_conv_pos_embeddings
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=config.num_conv_pos_embedding_groups,
            stride=config.squeeze_factor,
        )

        # Prefer the parametrization-based weight norm when this torch build provides it;
        # otherwise fall back to the legacy `nn.utils.weight_norm`.
        weight_norm = getattr(nn.utils.parametrizations, "weight_norm", nn.utils.weight_norm)

        if not is_deepspeed_zero3_enabled():
            self.conv = weight_norm(self.conv, name="weight", dim=2)
        else:
            import deepspeed

            # Apply weight norm while the conv weight is gathered.
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)

            # The two weight-norm tensors live under different attribute names depending on
            # which weight_norm implementation was applied above.
            if hasattr(self.conv, "parametrizations"):
                parametrized_weight = self.conv.parametrizations.weight
                weight_g = parametrized_weight.original0
                weight_v = parametrized_weight.original1
            else:
                weight_g, weight_v = self.conv.weight_g, self.conv.weight_v
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)

        self.padding = SEWSamePadLayer(kernel_size)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Return `activation(padding(conv(hidden_states)))`."""
        return self.activation(self.padding(self.conv(hidden_states)))
class SEWSamePadLayer(Wav2Vec2SamePadLayer):
    # SEW-named subclass of Wav2Vec2SamePadLayer; adds and overrides nothing.
    # SEWPositionalConvEmbedding instantiates it with `config.num_conv_pos_embeddings`.
    pass
class SEWUpsampling(nn.Module):
    """Project, activate, then trade channels for time steps.

    The linear layer widens `config.hidden_size` to `config.hidden_size * config.squeeze_factor`.
    When `squeeze_factor > 1`, `forward` then reshapes the 3D result so its second
    dimension is multiplied, and its last dimension divided, by the squeeze factor.
    """

    def __init__(self, config):
        super().__init__()
        self.squeeze_factor = config.squeeze_factor
        self.projection = nn.Linear(config.hidden_size, config.hidden_size * self.squeeze_factor)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Return the projected and activated input, unfolded along dim 1 when squeezing is on."""
        projected = self.activation(self.projection(hidden_states))
        if self.squeeze_factor <= 1:
            return projected

        # transform embedding channels to sequence length: each channel vector is cut into
        # `squeeze_factor` consecutive chunks that become consecutive time steps.
        batch_size, seq_len, channels = projected.shape
        return projected.reshape(
            batch_size,
            seq_len * self.squeeze_factor,
            channels // self.squeeze_factor,
        )
class SEWFeatureEncoder(Wav2Vec2FeatureEncoder):
    # SEW-named subclass of Wav2Vec2FeatureEncoder; adds and overrides nothing.
    # Used by SEWModel as `feature_extractor`.
    pass
class SEWAttention(Wav2Vec2Attention):
    # SEW-named subclass of Wav2Vec2Attention; adds and overrides nothing.
    pass
class SEWFeedForward(Wav2Vec2FeedForward):
    # SEW-named subclass of Wav2Vec2FeedForward; adds and overrides nothing.
    pass
class SEWEncoderLayer(Wav2Vec2EncoderLayer):
    # SEW-named subclass of Wav2Vec2EncoderLayer; adds and overrides nothing.
    # SEWEncoder stacks `config.num_hidden_layers` of these.
    pass
class SEWEncoder(nn.Module):
    """Encoder stack that runs its layers on a time-pooled copy of the input.

    The input is average-pooled by `config.squeeze_factor`, summed with the strided
    positional-convolution output, passed through the encoder layers (with LayerDrop),
    upsampled by `SEWUpsampling`, and finally zero-padded back to the input length
    when the upsampled sequence is shorter.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = SEWPositionalConvEmbedding(config)
        self.pool = nn.AvgPool1d(kernel_size=config.squeeze_factor, stride=config.squeeze_factor)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(SEWEncoderLayer(config) for _ in range(config.num_hidden_layers))
        self.upsample = SEWUpsampling(config)
        self.gradient_checkpointing = False

    def _pooled_additive_mask(self, attention_mask, n_timesteps, dtype):
        """Turn a 2D per-frame mask into a 4D additive mask at the pooled resolution.

        Kept positions are 0 and dropped positions are `torch.finfo(dtype).min`; the result
        is expanded to `(batch, 1, pooled_len, pooled_len)`.
        """
        squeeze_factor = self.config.squeeze_factor
        # apply pooling formula to get real output_lengths
        pooled_lengths = attention_mask.long().sum(-1) // squeeze_factor
        max_pooled_length = n_timesteps // squeeze_factor
        positions = torch.arange(0, max_pooled_length, device=pooled_lengths.device)
        keep = (positions[None, :] < pooled_lengths[:, None]).long()

        # extend attention_mask
        additive = 1.0 - keep[:, None, None, :].to(dtype=dtype)
        additive = additive * torch.finfo(dtype).min
        pooled_len = additive.shape[-1]
        return additive.expand(additive.shape[0], 1, pooled_len, pooled_len)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Encode `hidden_states`.

        Returns a `BaseModelOutput` when `return_dict` is truthy, otherwise a tuple of the
        non-`None` values among (last hidden state, all hidden states, all attentions).
        Note that padded positions of the incoming `hidden_states` are zeroed in place.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            padded = ~attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[padded] = 0.0

            if is_flash_attention_requested(self.config):
                # 2d mask is passed through the layers, and dropped entirely when it
                # contains no zero.
                # NOTE(review): this mask keeps the un-pooled time length while the layers
                # below see the pooled sequence — confirm the flash-attention path copes.
                if 0 not in attention_mask:
                    attention_mask = None
            else:
                attention_mask = self._pooled_additive_mask(
                    attention_mask, hidden_states.shape[1], hidden_states.dtype
                )

        n_input_timesteps = hidden_states.shape[1]

        channels_first = hidden_states.transpose(1, 2)
        position_embeddings = self.pos_conv_embed(channels_first)
        pooled_hidden_states = self.pool(channels_first)
        # The two branches can differ in length; keep only the steps both provide.
        min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
        hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
        hidden_states = self.dropout(self.layer_norm(hidden_states.transpose(1, 2)))

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                collected_states += (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])
            skip_the_layer = self.training and dropout_probability < self.config.layerdrop

            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                layer_outputs = layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                collected_attentions += (layer_outputs[1],)

        if output_hidden_states:
            collected_states += (hidden_states,)

        hidden_states = self.upsample(hidden_states)
        missing_timesteps = n_input_timesteps - hidden_states.shape[1]
        if missing_timesteps > 0:
            # Zero-pad the end of the time axis back to the input length.
            hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, missing_timesteps))

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )
        return tuple(
            value for value in (hidden_states, collected_states, collected_attentions) if value is not None
        )
@auto_docstring
class SEWPreTrainedModel(PreTrainedModel):
    # Shared base for the SEW models: class-level configuration flags, weight
    # initialisation and helpers converting raw-input lengths/masks to feature-level ones.
    config: SEWConfig
    base_model_prefix = "sew"
    main_input_name = "input_values"
    input_modalities = "audio"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = False  # needs a proper look into the mask creation

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialise `module` according to its type; unrecognised modules are left untouched."""
        if isinstance(module, SEWPositionalConvEmbedding):
            conv = module.conv
            kernel_elements = conv.kernel_size[0] * conv.in_channels
            init.normal_(conv.weight, mean=0, std=2 * math.sqrt(1 / kernel_elements))
            init.constant_(conv.bias, 0)
            return

        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            return
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # Gather the weight-norm pair when the conv exposes it, else the plain weight.
                has_weight_norm_pair = hasattr(module, "weight_v") and hasattr(module, "weight_g")
                gathered = [module.weight_v, module.weight_g] if has_weight_norm_pair else module.weight
                with deepspeed.zero.GatheredParameters(gathered, modifier_rank=0):
                    init.kaiming_normal_(module.weight)
            else:
                init.kaiming_normal_(module.weight)
        else:
            return

        # Only nn.Linear and nn.Conv1d reach this point.
        if module.bias is not None:
            init.zeros_(module.bias)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
        """
        Computes the output length of the convolutional layers
        """
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            input_lengths = torch.div(input_lengths - kernel_size, stride, rounding_mode="floor") + 1
        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """Build a boolean `(batch, feature_vector_length)` mask from a raw-input mask.

        Each row is True up to and including index `output_length - 1`, where the output
        length is derived from the row sum of `attention_mask`.
        """
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]
        device = attention_mask.device

        feature_mask = torch.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype, device=device)
        # these two operations makes sure that all values before the output lengths idxs are attended to:
        # mark the last valid index of every row, then spread that mark backwards with a reversed cumsum.
        batch_indices = torch.arange(batch_size, device=device)
        feature_mask[batch_indices, output_lengths - 1] = 1
        return feature_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
@auto_docstring
class SEWModel(SEWPreTrainedModel):
    # Bare SEW model: convolutional feature encoder -> LayerNorm -> optional linear
    # projection to `hidden_size` -> dropout -> optional SpecAugment masking -> SEWEncoder.
    def __init__(self, config: SEWConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = SEWFeatureEncoder(config)
        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)

        # A projection is only needed when the last conv width differs from the hidden size.
        self.project_features = config.conv_dim[-1] != config.hidden_size
        if self.project_features:
            self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
        self.feature_dropout = nn.Dropout(config.feat_proj_dropout)

        # The mask embedding only exists when the config enables time or feature masking.
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        self.encoder = SEWEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: torch.FloatTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        mask_time_indices: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutput:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        """
        # Unset flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # The feature encoder output is transposed so the channel axis comes last
        # before LayerNorm and the optional projection.
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = self.layer_norm(extract_features)

        if self.project_features:
            extract_features = self.feature_projection(extract_features)
        hidden_states = self.feature_dropout(extract_features)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)

        # Forward the reduced mask as well: `_mask_hidden_states` hands it on to
        # `_compute_mask_indices`, which otherwise always receives `attention_mask=None`
        # and has no padding information when sampling the time mask.
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class SEWForCTC(Wav2Vec2ForCTC):
    # SEW-named subclass of Wav2Vec2ForCTC; adds and overrides nothing. Exported via `__all__`.
    pass
class SEWForSequenceClassification(Wav2Vec2ForSequenceClassification):
    # SEW-named subclass of Wav2Vec2ForSequenceClassification; adds and overrides nothing.
    # Exported via `__all__`.
    pass
# Names exported by `from <this module> import *`.
__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]
|