| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153 |
- # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Wav2Vec2 model."""
- import math
- import warnings
- from collections.abc import Callable
- from dataclasses import dataclass
- import numpy as np
- import torch
- from safetensors.torch import load_file as safe_load_file
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...integrations.deepspeed import is_deepspeed_zero3_enabled
- from ...integrations.fsdp import is_fsdp_managed_module
- from ...masking_utils import create_bidirectional_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- CausalLMOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- Wav2Vec2BaseModelOutput,
- XVectorOutput,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, get_torch_context_manager_or_global_device
- from ...processing_utils import Unpack
- from ...utils import (
- ModelOutput,
- TransformersKwargs,
- auto_docstring,
- cached_file,
- check_torch_load_is_safe,
- is_peft_available,
- logging,
- )
- from .configuration_wav2vec2 import Wav2Vec2Config
# File-name templates for adapter weight files. `{}` is a `str.format` placeholder that is filled in
# by code outside this view (presumably an adapter/language identifier -- TODO confirm against the loader).
WAV2VEC2_ADAPTER_PT_FILE = "adapter.{}.bin"
WAV2VEC2_ADAPTER_SAFE_FILE = "adapter.{}.safetensors"

# Module-level logger obtained from the project's `logging` helper.
logger = logging.get_logger(__name__)

# NOTE(review): presumably the position at which hidden states begin in a model's output tuple;
# no use of this constant is visible in this part of the file -- confirm before relying on it.
_HIDDEN_STATES_START_POSITION = 2
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`Wav2Vec2ForPreTraining`], with potential hidden states and attentions.
    """
)
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
        paper](https://huggingface.co/papers/2006.11477).
    projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
        projected quantized states.
    projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
        target vectors for contrastive loss.
    codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
        The perplexity of the codevector distribution, used to measure the diversity of the codebook.
    contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
        The contrastive loss (L_m) as stated in the [official paper](https://huggingface.co/papers/2006.11477).
    diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
        The diversity loss (L_d) as stated in the [official paper](https://huggingface.co/papers/2006.11477).
    """

    # Every field defaults to None. The declaration order below is kept exactly as is because
    # dataclass field order is part of the class's positional interface.
    loss: torch.FloatTensor | None = None
    projected_states: torch.FloatTensor | None = None
    projected_quantized_states: torch.FloatTensor | None = None
    codevector_perplexity: torch.FloatTensor | None = None
    # NOTE(review): `hidden_states` and `attentions` are not described in the docstring above;
    # presumably `auto_docstring` supplies their descriptions -- confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    contrastive_loss: torch.FloatTensor | None = None
    diversity_loss: torch.FloatTensor | None = None
- def _compute_mask_indices(
- shape: tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: torch.LongTensor | None = None,
- min_masks: int = 0,
- ) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
- return num_masked_span
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
- max_num_masked_span = compute_num_masked_span(sequence_length)
- if max_num_masked_span == 0:
- return spec_aug_mask
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
- return spec_aug_mask
- def _sample_negative_indices(features_shape: tuple, num_negatives: int, mask_time_indices: np.ndarray | None = None):
- """
- Sample `num_negatives` vectors from feature vectors.
- """
- batch_size, sequence_length = features_shape
- # generate indices of the positive vectors themselves, repeat them `num_negatives` times
- sequence_length_range = np.arange(sequence_length)
- # get `num_negatives` random vector indices from the same utterance
- sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
- mask_time_indices = (
- mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
- )
- for batch_idx in range(batch_size):
- high = mask_time_indices[batch_idx].sum() - 1
- mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
- feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
- sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
- # avoid sampling the same positive vector, but keep the distribution uniform
- sampled_indices[sampled_indices >= feature_indices] += 1
- # remap to actual indices
- sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
- # correct for batch size
- sampled_negative_indices[batch_idx] += batch_idx * sequence_length
- return sampled_negative_indices
class Wav2Vec2NoLayerNormConvLayer(GradientCheckpointingLayer):
    """1D convolution followed directly by the configured activation (no normalization in between)."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        conv_dims = config.conv_dim
        # layer 0 consumes a single input channel; every later layer consumes its predecessor's output width
        self.in_conv_dim = 1 if layer_id <= 0 else conv_dims[layer_id - 1]
        self.out_conv_dim = conv_dims[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Apply the convolution, then the activation."""
        return self.activation(self.conv(hidden_states))
class Wav2Vec2LayerNormConvLayer(GradientCheckpointingLayer):
    """1D convolution, then layer normalization over the channel dimension, then the configured activation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        conv_dims = config.conv_dim
        # layer 0 consumes a single input channel; every later layer consumes its predecessor's output width
        self.in_conv_dim = 1 if layer_id <= 0 else conv_dims[layer_id - 1]
        self.out_conv_dim = conv_dims[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Convolve, normalize with the channel axis temporarily moved last, then activate."""
        conv_out = self.conv(hidden_states)

        # `nn.LayerNorm` normalizes the trailing dimension, so swap the last two axes around the call
        normed = self.layer_norm(conv_out.transpose(-2, -1)).transpose(-2, -1)

        return self.activation(normed)
class Wav2Vec2GroupNormConvLayer(GradientCheckpointingLayer):
    """1D convolution, then group normalization with one group per output channel, then the activation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        conv_dims = config.conv_dim
        # layer 0 consumes a single input channel; every later layer consumes its predecessor's output width
        self.in_conv_dim = 1 if layer_id <= 0 else conv_dims[layer_id - 1]
        self.out_conv_dim = conv_dims[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

        # `num_groups == num_channels`, i.e. each channel forms its own normalization group
        channels = self.out_conv_dim
        self.layer_norm = nn.GroupNorm(num_groups=channels, num_channels=channels, affine=True)

    def forward(self, hidden_states):
        """Apply convolution, group normalization and activation, in that order."""
        conv_out = self.conv(hidden_states)
        normed = self.layer_norm(conv_out)
        return self.activation(normed)
class Wav2Vec2PositionalConvEmbedding(nn.Module):
    """Weight-normalized grouped 1D convolution over the sequence axis, followed by trimming and an activation."""

    def __init__(self, config):
        super().__init__()
        kernel_size = config.num_conv_pos_embeddings
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        # use the parametrization-based `weight_norm` when this torch build offers it, else the legacy helper
        weight_norm = getattr(nn.utils.parametrizations, "weight_norm", nn.utils.weight_norm)

        if not is_deepspeed_zero3_enabled():
            self.conv = weight_norm(self.conv, name="weight", dim=2)
        else:
            import deepspeed

            # wrap the convolution while its weight is gathered
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)

            # the two weight-norm tensors live in different places depending on which API wrapped the conv
            if hasattr(self.conv, "parametrizations"):
                conv_weight = self.conv.parametrizations.weight
                weight_g = conv_weight.original0
                weight_v = conv_weight.original1
            else:
                weight_g = self.conv.weight_g
                weight_v = self.conv.weight_v
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)

        self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Swap dims 1 and 2, run conv -> pad trimming -> activation, then swap the dims back."""
        conv_input = hidden_states.transpose(1, 2)

        conv_out = self.conv(conv_input)
        trimmed = self.padding(conv_out)
        activated = self.activation(trimmed)

        return activated.transpose(1, 2)
class Wav2Vec2SamePadLayer(nn.Module):
    """Drops the last element along the final dimension when the configured kernel size is even."""

    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        # an even kernel size means one trailing element is removed; an odd one means none
        self.num_pad_remove = int(num_conv_pos_embeddings % 2 == 0)

    def forward(self, hidden_states):
        """Return `hidden_states` without its last `num_pad_remove` entries on the final axis."""
        trim = self.num_pad_remove
        if trim <= 0:
            return hidden_states
        return hidden_states[:, :, :-trim]
class Wav2Vec2FeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    def __init__(self, config):
        super().__init__()

        num_layers = config.num_feat_extract_layers
        if config.feat_extract_norm == "group":
            # only the first layer is group-normalized; the remaining layers have no normalization
            conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)]
            conv_layers += [Wav2Vec2NoLayerNormConvLayer(config, layer_id=idx) for idx in range(1, num_layers)]
        elif config.feat_extract_norm == "layer":
            conv_layers = [Wav2Vec2LayerNormConvLayer(config, layer_id=idx) for idx in range(num_layers)]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

        self.conv_layers = nn.ModuleList(conv_layers)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        """Turn off gradients for every parameter and record that this encoder is frozen."""
        for parameter in self.parameters():
            parameter.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        """Insert a new axis at position 1 and pass the result through every conv layer in order."""
        hidden_states = input_values[:, None]

        # gradient checkpointing needs an input that requires grad (skipped once the encoder is frozen)
        if self.training and self._requires_grad:
            hidden_states.requires_grad = True

        for layer in self.conv_layers:
            hidden_states = layer(hidden_states)

        return hidden_states
class Wav2Vec2FeatureProjection(nn.Module):
    """Layer-normalizes features of width `config.conv_dim[-1]` and projects them to `config.hidden_size`."""

    def __init__(self, config):
        super().__init__()
        feature_dim = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(feature_dim, eps=config.layer_norm_eps)
        self.projection = nn.Linear(feature_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """
        Returns:
            `(projected, normalized)`: the projected states after dropout, and the normalized but not yet
            projected states, which are needed for quantization.
        """
        norm_hidden_states = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(norm_hidden_states))
        return projected, norm_hidden_states
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain scaled dot-product attention built from `matmul`, `softmax` and `dropout`.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query: Query tensor. Its last dimension sets the default scaling.
        key: Key tensor. Dims 2 and 3 are swapped before the product with `query`, so it must be at least 4-D
            (presumably `(batch, heads, seq, head_dim)` -- TODO confirm against callers).
        value: Value tensor, multiplied by the attention probabilities.
        attention_mask: Optional tensor added to the raw scores before the softmax.
        scaling: Factor applied to the raw scores. Defaults to `query.size(-1) ** -0.5`.
        dropout: Dropout probability applied to the attention probabilities.
        **kwargs: Accepted but not used in this function.

    Returns:
        `(attn_output, attn_weights)`. `attn_output` has dims 1 and 2 swapped relative to the
        `attn_weights @ value` product and is contiguous. `attn_weights` are the probabilities after dropout.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class Wav2Vec2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Wav2Vec2Config | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        # the heads must tile the embedding width exactly
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # creation order (k, v, q, out) is kept so module registration and initialization order are unchanged
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
        """Input shape: Batch x Time x Channel"""
        # all leading dims of the input; the channel dim is split into (heads, head_dim)
        leading_dims = hidden_states.shape[:-1]
        query_states = self.q_proj(hidden_states).view((*leading_dims, -1, self.head_dim)).transpose(1, 2)

        # cross-attention (keys/values come from `key_value_states`) whenever it is supplied, else self-attention
        kv_source = hidden_states if key_value_states is None else key_value_states
        kv_split = (*kv_source.shape[:-1], -1, self.head_dim)
        key_states = self.k_proj(kv_source).view(kv_split).transpose(1, 2)
        value_states = self.v_proj(kv_source).view(kv_split).transpose(1, 2)

        # look up the configured attention backend, with the eager implementation as the default
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.dropout if self.training else 0.0,
            scaling=self.scaling,
            output_attentions=output_attentions,
            **kwargs,
        )

        # merge the heads back into one channel dimension and apply the output projection
        merged = attn_output.reshape(*leading_dims, -1).contiguous()
        return self.out_proj(merged), attn_weights, None
class Wav2Vec2FeedForward(nn.Module):
    """Two-layer position-wise feed-forward block: dense -> activation -> dropout -> dense -> dropout."""

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `config.hidden_act` is either the name of a registered activation or the callable itself
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        """Expand to `intermediate_size`, activate, then project back to `hidden_size`."""
        intermediate = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        intermediate = self.intermediate_dropout(intermediate)

        return self.output_dropout(self.output_dense(intermediate))
class Wav2Vec2EncoderLayer(GradientCheckpointingLayer):
    """Encoder layer that applies layer norm *after* each residual connection (attention, then feed-forward)."""

    def __init__(self, config):
        super().__init__()
        self.attention = Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
            config=config,
        )

        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """
        Returns:
            `(hidden_states,)`, extended with the attention weights when `output_attentions` is truthy.
        """
        # attention sub-block: residual add, then normalize
        attn_output, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = hidden_states + self.dropout(attn_output)
        hidden_states = self.layer_norm(hidden_states)

        # feed-forward sub-block: residual add, then normalize
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class Wav2Vec2EncoderLayerStableLayerNorm(GradientCheckpointingLayer):
    """
    Encoder layer that applies layer norm *before* the attention and feed-forward sub-blocks, with an optional
    adapter created when `config.adapter_attn_dim` is set.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
            config=config,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # the adapter exists only when the config carries a non-None `adapter_attn_dim`
        has_adapter = getattr(config, "adapter_attn_dim", None) is not None
        self.adapter_layer = Wav2Vec2AttnAdapterLayer(config) if has_adapter else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool = False,
    ):
        """
        Returns:
            `(hidden_states,)`, extended with the attention weights when `output_attentions` is truthy.
        """
        # attention sub-block: normalize first, then add the result onto the un-normalized input
        attn_output, attn_weights, _ = self.attention(
            self.layer_norm(hidden_states), attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = hidden_states + self.dropout(attn_output)

        # feed-forward sub-block, again normalized on the way in only
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        if self.adapter_layer is not None:
            hidden_states = hidden_states + self.adapter_layer(hidden_states)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class Wav2Vec2Encoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` `Wav2Vec2EncoderLayer`s, preceded by a convolutional positional
    embedding, a layer norm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run the encoder stack.

        Args:
            hidden_states: Input features. NOTE: when `attention_mask` is given, positions where the mask is
                false are zeroed **in place** on this tensor.
            attention_mask: Optional mask with one entry per position; it is inverted with `~`, so it is
                expected to be boolean. It is then converted by `create_bidirectional_mask` before being handed
                to the layers.
            output_attentions: Collect the attention weights of every layer (`None` for a dropped layer).
            output_hidden_states: Collect the hidden states before every layer and after the last one.
            return_dict: Return a `BaseModelOutput` instead of a plain tuple.

        Returns:
            `BaseModelOutput`, or, when `return_dict` is false, a tuple of the non-`None` values among
            `(last_hidden_state, all_hidden_states, all_self_attentions)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
        )

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings.to(hidden_states.device)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            # the random draw happens for every layer, also in eval mode, so RNG consumption is unchanged
            dropout_probability = torch.rand([])

            skip_the_layer = self.training and dropout_probability < self.config.layerdrop
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                layer_outputs = layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # a dropped layer reports no attention weights
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class Wav2Vec2EncoderStableLayerNorm(nn.Module):
    """
    Stack of `Wav2Vec2EncoderLayerStableLayerNorm` modules.

    The output of `Wav2Vec2PositionalConvEmbedding` is added to the input and dropout is applied before the stack;
    the encoder-level layer norm is applied once, after the last layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        encoder_layers = [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(encoder_layers)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Run the encoder stack and return either a `BaseModelOutput` or, when `return_dict` is false, a tuple of the
        non-`None` values among the final hidden states, the collected hidden states and the collected attentions.

        When `attention_mask` is given, `hidden_states` is modified in place (masked-out positions are zeroed).
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        if attention_mask is not None:
            # Zero the features of padded positions before anything else sees them.
            padded_positions = ~attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[padded_positions] = 0
            attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
            )

        hidden_states = self.dropout(hidden_states + self.pos_conv_embed(hidden_states))

        # Under FSDP / DeepSpeed ZeRO-3 every rank has to execute every layer.
        run_every_layer = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for encoder_layer in self.layers:
            if output_hidden_states:
                collected_states += (hidden_states,)

            # LayerDrop (https://huggingface.co/papers/1909.11556): the draw is made for every layer,
            # whether or not the module is in training mode.
            layerdrop_draw = torch.rand([])
            drop_layer = self.training and layerdrop_draw < self.config.layerdrop

            if run_every_layer or not drop_layer:
                layer_result = encoder_layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_result[0]

            if output_attentions:
                # A dropped layer contributes `None` in place of its attention weights.
                layer_attention = None if drop_layer else layer_result[1]
                collected_attentions += (layer_attention,)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            collected_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )
        return tuple(
            item for item in (hidden_states, collected_states, collected_attentions) if item is not None
        )
class Wav2Vec2GumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
    """

    def __init__(self, config):
        super().__init__()
        self.num_groups = config.num_codevector_groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim` {config.codevector_dim} must be divisible "
                f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
            )

        # storage for codebook variables (codewords); allocated uninitialized here
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs, mask=None):
        """
        Return the summed per-group perplexity of the code distribution averaged over the first dimension.

        Args:
            probs: Probabilities whose first dimension is averaged over and whose last dimension is the code axis.
            mask: Optional boolean mask; after flattening it selects which rows of `probs` take part in the
                average. The average divides by `mask.sum()`.

        Returns:
            A scalar tensor: `exp(entropy)` of the marginal distribution along the last axis, summed over the rest.
        """
        if mask is not None:
            mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
            probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
            marginal_probs = probs.sum(dim=0) / mask.sum()
        else:
            marginal_probs = probs.mean(dim=0)

        # xlogy treats 0 * log(0) as 0, so unused codes do not produce NaNs
        perplexity = torch.exp(-torch.sum(torch.xlogy(marginal_probs, marginal_probs), dim=-1)).sum()
        return perplexity

    def forward(self, hidden_states, mask_time_indices=None):
        """
        Quantize `hidden_states` of shape `(batch_size, sequence_length, feature_dim)`.

        In training mode codes are sampled with a hard gumbel softmax; otherwise the arg-max code is taken.

        Returns:
            A tuple `(codevectors, perplexity)` where `codevectors` has shape
            `(batch_size, sequence_length, -1)` (the selected code vector of every group, concatenated) and
            `perplexity` is the scalar returned by `_compute_perplexity`.
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)

        if self.training:
            # sample code vector probs via gumbel in a differentiable way
            codevector_probs = nn.functional.gumbel_softmax(
                hidden_states.float(), tau=self.temperature, hard=True
            ).type_as(hidden_states)

            # compute perplexity from the soft (non-sampled) distribution
            codevector_soft_dist = torch.softmax(
                hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
        else:
            # take argmax in a non-differentiable way
            # compute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
                -1, codevector_idx.view(-1, 1), 1.0
            )
            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)

        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
        codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity
class Wav2Vec2Adapter(nn.Module):
    """
    Optional projection to `config.output_hidden_size` followed by `config.num_adapter_layers`
    `Wav2Vec2AdapterLayer` modules, each of which may be skipped at random while training.
    """

    def __init__(self, config):
        super().__init__()

        # A projection (plus layer norm) is only created when the two hidden sizes differ.
        needs_projection = config.output_hidden_size != config.hidden_size
        if needs_projection:
            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
            self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
        else:
            self.proj = None
            self.proj_layer_norm = None

        self.layers = nn.ModuleList([Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers)])
        self.layerdrop = config.layerdrop

    def forward(self, hidden_states):
        """Project if configured, then run the adapter layers on the tensor with dims 1 and 2 swapped."""
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))

        # The adapter layers operate on the transposed layout; it is swapped back before returning.
        hidden_states = hidden_states.transpose(1, 2)

        for adapter_layer in self.layers:
            # One draw per layer, made regardless of training mode.
            draw = np.random.random()
            keep_layer = (not self.training) or (draw > self.layerdrop)
            if keep_layer:
                hidden_states = adapter_layer(hidden_states)

        return hidden_states.transpose(1, 2)
class Wav2Vec2AdapterLayer(nn.Module):
    """A strided 1D convolution that doubles the channel count, followed by a GLU that halves it again."""

    def __init__(self, config):
        super().__init__()
        channels = config.output_hidden_size
        self.conv = nn.Conv1d(
            channels,
            2 * channels,
            config.adapter_kernel_size,
            stride=config.adapter_stride,
            padding=1,
        )

    def forward(self, hidden_states):
        """Apply the convolution, then gate along the channel dimension (dim 1)."""
        conv_output = self.conv(hidden_states)
        return nn.functional.glu(conv_output, dim=1)
class Wav2Vec2AttnAdapterLayer(nn.Module):
    """
    Bottleneck adapter: layer norm, a linear map from `config.hidden_size` down to `config.adapter_attn_dim`,
    ReLU, and a linear map back up to `config.hidden_size`.
    """

    def __init__(self, config):
        super().__init__()
        self.input_dim = config.adapter_attn_dim
        self.hidden_dim = config.hidden_size

        self.norm = nn.LayerNorm(self.hidden_dim)
        self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
        self.act_fn = nn.ReLU()
        self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)

    def forward(self, hidden_states: torch.FloatTensor):
        """Return the adapter output; the residual connection is left to the caller."""
        bottleneck = self.linear_1(self.norm(hidden_states))
        return self.linear_2(self.act_fn(bottleneck))
@auto_docstring
class Wav2Vec2PreTrainedModel(PreTrainedModel):
    # Shared base class of the Wav2Vec2 models in this file: weight initialization, feature-length helpers and
    # attention-adapter utilities.
    config: Wav2Vec2Config
    base_model_prefix = "wav2vec2"
    main_input_name = "input_values"
    input_modalities = "audio"
    supports_gradient_checkpointing = True
    # Flags read by the `PreTrainedModel` machinery; presumably they advertise the supported attention
    # backends -- confirm against the base class.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights of `module`; the first matching `isinstance` branch below decides the scheme."""
        # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
        if isinstance(module, Wav2Vec2ForPreTraining):
            module.project_hid.reset_parameters()
            module.project_q.reset_parameters()
        # gumbel softmax requires special init
        elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
            init.normal_(module.weight_proj.weight, mean=0.0, std=1)
            init.zeros_(module.weight_proj.bias)
            init.uniform_(module.codevectors)
        elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
            init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            init.constant_(module.conv.bias, 0)
        elif isinstance(module, Wav2Vec2FeatureProjection):
            # uniform in [-k, k] with k = sqrt(1 / in_features)
            k = math.sqrt(1 / module.projection.in_features)
            init.uniform_(module.projection.weight, a=-k, b=k)
            init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            # NOTE(review): assumes the norm has affine parameters (`bias` / `weight` are not None).
            init.zeros_(module.bias)
            init.ones_(module.weight)
        elif isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int, add_adapter: bool | None = None):
        """
        Computes the output length of the convolutional layers

        Applies the no-padding Conv1d length formula once per `(config.conv_kernel, config.conv_stride)` pair and,
        when `add_adapter` is true (defaults to `config.add_adapter`), once more per adapter layer using kernel
        size 1 and `config.adapter_stride`.
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        """
        Build a boolean mask of shape `(batch_size, feature_vector_length)` that is `True` for the first
        `output_length` positions of every row, where `output_length` is derived from the number of non-padded
        input positions via `_get_feat_extract_output_lengths`.
        """
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        # (mark the last valid index of every row, then propagate the mark to all earlier positions)
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

    def _get_adapters(self):
        """
        Collect the parameters of every `Wav2Vec2AttnAdapterLayer` (and of `lm_head` for `Wav2Vec2ForCTC`) into a
        dict keyed by dotted parameter name.

        Raises:
            ValueError: if `config.adapter_attn_dim` is `None`.
        """
        if self.config.adapter_attn_dim is None:
            raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")

        adapter_weights = {}
        for name, module in self.named_modules():
            if isinstance(module, Wav2Vec2AttnAdapterLayer):
                for param_name, param in module.named_parameters():
                    adapter_weights[".".join([name, param_name])] = param

        if isinstance(self, Wav2Vec2ForCTC):
            for name, param in self.lm_head.named_parameters():
                adapter_weights[".".join(["lm_head", name])] = param

        return adapter_weights

    def init_adapter_layers(self):
        """
        (Re-)initialize attention adapter layers and lm head for adapter-only fine-tuning
        """
        # init attention adapters
        # NOTE(review): `_init_weights` has no branch that matches `Wav2Vec2AttnAdapterLayer` itself and is called
        # here on the adapter module only (not on its sub-layers) -- confirm this re-initializes what is intended.
        for module in self.modules():
            if isinstance(module, Wav2Vec2AttnAdapterLayer):
                self._init_weights(module)

        # init lm head
        if isinstance(self, Wav2Vec2ForCTC):
            self._init_weights(self.lm_head)

    def load_adapter(self, target_lang: str, force_load=True, **kwargs):
        r"""
        Load a language adapter model from a pre-trained adapter model.

        Parameters:
            target_lang (`str`):
                Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
                adapter.<lang>.safetensors or adapter.<lang>.bin
            force_load (`bool`, defaults to `True`):
                Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                the token generated when running `hf auth login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

                <Tip>

                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

                </Tip>

            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
        use this method in a firewalled environment.

        </Tip>

        Examples:

        ```python
        >>> from transformers import Wav2Vec2ForCTC, AutoProcessor

        >>> ckpt = "facebook/mms-1b-all"
        >>> processor = AutoProcessor.from_pretrained(ckpt)
        >>> model = Wav2Vec2ForCTC.from_pretrained(ckpt, target_lang="eng")
        >>> # set specific language
        >>> processor.tokenizer.set_target_lang("spa")
        >>> model.load_adapter("spa")
        ```
        """
        if self.config.adapter_attn_dim is None:
            raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")

        # Nothing to do when the requested adapter is already active and a reload was not forced.
        if target_lang == self.target_lang and not force_load:
            logger.warning(f"Adapter weights are already set to {target_lang}.")
            return

        # Only the keyword arguments popped here are used; any other entries of `kwargs` are ignored.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        model_path_or_id = self.config._name_or_path

        state_dict = None

        # 1. Let's first try loading a safetensors adapter weight
        # `use_safetensors` is tri-state: False skips this step, True makes any failure fatal, and None (the
        # default) swallows failures so that step 2 can fall back to the PyTorch file.
        if use_safetensors is not False:
            filepath = WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang)

            try:
                weight_path = cached_file(
                    model_path_or_id,
                    filename=filepath,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    cache_dir=cache_dir,
                )

                state_dict = safe_load_file(weight_path)

            except OSError:
                if use_safetensors:
                    # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
                    # to the original exception.
                    raise

            except Exception:
                # For any other exception, we throw a generic error.
                if use_safetensors:
                    raise OSError(
                        f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
                        " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                        f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
                        f" directory containing a file named {filepath}."
                    )

        # 2. If this didn't work let's try loading a PyTorch adapter weight
        if state_dict is None:
            filepath = WAV2VEC2_ADAPTER_PT_FILE.format(target_lang)

            try:
                weight_path = cached_file(
                    model_path_or_id,
                    filename=filepath,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    cache_dir=cache_dir,
                )

                check_torch_load_is_safe()
                state_dict = torch.load(
                    weight_path,
                    map_location="cpu",
                    weights_only=True,
                )

            except OSError:
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
                # to the original exception.
                raise

            except ValueError:
                # Propagated untouched; presumably this is what `check_torch_load_is_safe` raises -- confirm.
                raise

            except Exception:
                # For any other exception, we throw a generic error.
                raise OSError(
                    f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
                    f" directory containing a file named {filepath}."
                )

        # The checkpoint must contain exactly the adapter (and lm head) parameters of this model.
        adapter_weights = self._get_adapters()
        unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
        missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())

        if len(unexpected_keys) > 0:
            raise ValueError(f"The adapter weights {weight_path} has unexpected keys: {', '.join(unexpected_keys)}.")
        elif len(missing_keys) > 0:
            raise ValueError(f"The adapter weights {weight_path} has missing keys: {', '.join(missing_keys)}.")

        # make sure now vocab size is correct
        # (a differently sized head is replaced by a fresh `nn.Linear` before the weights are loaded)
        target_vocab_size = state_dict["lm_head.weight"].shape[0]
        if target_vocab_size != self.config.vocab_size:
            self.lm_head = nn.Linear(
                self.config.output_hidden_size, target_vocab_size, device=self.device, dtype=self.dtype
            )
            self.config.vocab_size = target_vocab_size

        # make sure that adapter weights are put in exactly the same precision and device placement and overwritten adapter weights
        state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
        self.load_state_dict(state_dict, strict=False)

        # set target language correctly
        self.target_lang = target_lang
@auto_docstring
class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
    # Bare Wav2Vec2 model: feature encoder -> feature projection -> optional SpecAugment masking -> transformer
    # encoder -> optional adapter.

    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)
        self.config = config
        self.feature_extractor = Wav2Vec2FeatureEncoder(config)
        self.feature_projection = Wav2Vec2FeatureProjection(config)

        # model only needs masking vector if mask prob is > 0.0
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
        else:
            self.encoder = Wav2Vec2Encoder(config)

        self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: torch.BoolTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).

        `hidden_states` is modified in place and also returned. `mask_time_indices` is used as a boolean index
        into `hidden_states`; when it is `None`, a time mask is only sampled in training mode and when
        `config.mask_time_prob > 0`. Feature-axis masking is likewise only applied in training mode.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            # the same feature mask is applied at every time step
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        mask_time_indices: torch.BoolTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | Wav2Vec2BaseModelOutput:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        """
        # Unset flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        # the adapter, when present, only affects the returned last hidden state
        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    Wav2Vec2 Model with a quantizer and `VQ` head on top.
    """
)
class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)

        self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)

        # both projections map into the shared `config.proj_codevector_dim` space
        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)

        # Initialize weights and apply final processing
        self.post_init()

    def set_gumbel_temperature(self, temperature: float):
        """
        Set the Gumbel softmax temperature to a given value. Only necessary for training
        """
        self.quantizer.temperature = temperature

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    @staticmethod
    def compute_contrastive_logits(
        target_features: torch.FloatTensor,
        negative_features: torch.FloatTensor,
        predicted_features: torch.FloatTensor,
        temperature: float = 0.1,
    ):
        """
        Compute logits for contrastive loss based using cosine similarity as the distance measure between
        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.

        The positive and negative features are concatenated along dim 0 (positive first), so index 0 of the
        returned logits corresponds to the positive target.
        """
        target_features = torch.cat([target_features, negative_features], dim=0)

        # similarity is computed in float32 and cast back to the dtype of the targets
        logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
            target_features
        )

        # apply temperature
        logits = logits / temperature
        return logits

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        mask_time_indices: torch.BoolTensor | None = None,
        sampled_negative_indices: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | Wav2Vec2ForPreTrainingOutput:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        sampled_negative_indices (`torch.LongTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
            Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
            Required input for pre-training.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
        >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
        >>> from datasets import load_dataset

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
        >>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1

        >>> # compute masked indices
        >>> batch_size, raw_sequence_length = input_values.shape
        >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
        >>> mask_time_indices = _compute_mask_indices(
        ...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
        ... )
        >>> sampled_negative_indices = _sample_negative_indices(
        ...     features_shape=(batch_size, sequence_length),
        ...     num_negatives=model.config.num_negatives,
        ...     mask_time_indices=mask_time_indices,
        ... )
        >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
        >>> sampled_negative_indices = torch.tensor(
        ...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
        ... )

        >>> with torch.no_grad():
        ...     outputs = model(input_values, mask_time_indices=mask_time_indices)

        >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
        >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)

        >>> # show that cosine similarity is much higher than random
        >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
        tensor(True)

        >>> # for contrastive loss training model should be put into train mode
        >>> model = model.train()
        >>> loss = model(
        ...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
        ... ).loss
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if mask_time_indices is not None:
            mask_time_indices = mask_time_indices.to(torch.bool)

        # The inner model reduces `attention_mask` to feature-vector resolution itself; nothing below needs it.
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            mask_time_indices=mask_time_indices,
            return_dict=return_dict,
        )

        # 1. project all transformed features (including masked) to final vq dim
        transformer_features = self.project_hid(outputs[0])

        # 2. quantize all (unmasked) extracted features and project to final vq dim
        extract_features = self.dropout_features(outputs[1])

        quantized_features, codevector_perplexity = self.quantizer(
            extract_features, mask_time_indices=mask_time_indices
        )

        quantized_features = quantized_features.to(self.project_q.weight.dtype)
        quantized_features = self.project_q(quantized_features)

        loss = contrastive_loss = diversity_loss = None
        if sampled_negative_indices is not None:
            batch_size, sequence_length, hidden_size = quantized_features.shape

            # for training, we sample negatives
            # 3. sample K negatives (distractors) quantized states for contrastive loss
            # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
            # sample negative quantized vectors BTC => (BxT)C
            negative_quantized_features = quantized_features.view(-1, hidden_size)[
                sampled_negative_indices.long().view(-1)
            ]
            negative_quantized_features = negative_quantized_features.view(
                batch_size, sequence_length, -1, hidden_size
            ).permute(2, 0, 1, 3)

            # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
            # of equation (3) in https://huggingface.co/papers/2006.11477
            logits = self.compute_contrastive_logits(
                quantized_features[None, :],
                negative_quantized_features,
                transformer_features,
                self.config.contrastive_logits_temperature,
            )

            # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
            # its cosine similarity will be masked
            neg_is_pos = (quantized_features == negative_quantized_features).all(-1)

            if neg_is_pos.any():
                logits[1:][neg_is_pos] = float("-inf")

            # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
            # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
            logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
            # target class is 0 (the positive) at masked positions and -100 (ignored) everywhere else
            target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()

            contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
            # 7. compute diversity loss: \mathbf{L}_d
            num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
            diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()

            # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
            loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss

        if not return_dict:
            if loss is not None:
                return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
            return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]

        return Wav2Vec2ForPreTrainingOutput(
            loss=loss,
            projected_states=transformer_features,
            projected_quantized_states=quantized_features,
            codevector_perplexity=codevector_perplexity,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            contrastive_loss=contrastive_loss,
            diversity_loss=diversity_loss,
        )
@auto_docstring(
    custom_intro="""
    Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    """
)
class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
    def __init__(self, config, target_lang: str | None = None):
        r"""
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`Wav2Vec2ForCTC`] with adapters. Uses 'eng' by
            default.
        """
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout = nn.Dropout(config.final_dropout)

        # Remembered here and consumed later by `tie_weights`, which loads the matching adapter.
        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # With an adapter enabled the head reads `output_hidden_size` features, otherwise `hidden_size`.
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self, **kwargs):
        """
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        """
        # Nothing is loaded while the model is being built on the meta device.
        if get_torch_context_manager_or_global_device() == torch.device("meta"):
            return

        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
        # correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to
        # [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is
        # ok to repurpose this function here.
        target_lang = self.target_lang

        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
            logger.info("By default `target_lang` is set to 'eng'.")
        elif target_lang is not None:
            self.load_adapter(target_lang, force_load=True)

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | CausalLMOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Valid label ids are strictly below `vocab_size`; the message now states the same bound as the check.
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            # (without a mask every input sample is treated as attended)
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            # `ctc_loss` wants the time axis first, hence the transpose of the first two dimensions
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            # cuDNN is switched off for the duration of the CTC loss computation
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            # tuple layout: (loss?, logits, *remaining encoder outputs)
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
@auto_docstring(
    custom_intro="""
    Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    """
)
class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Adapter-equipped configurations are rejected up front.
        if getattr(config, "add_adapter", False):
            raise ValueError(
                "Sequence classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
            )

        self.wav2vec2 = Wav2Vec2Model(config)

        # One mixing weight per transformer layer plus one for the input embeddings.
        layer_count = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(layer_count) / layer_count)

        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        """
        Turn off gradient computation for the feature encoder so that its parameters are not updated during
        training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Turn off gradient computation for every parameter of the base model, so that only the classification head is
        updated during training.
        """
        for base_param in self.wav2vec2.parameters():
            base_param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | SequenceClassifierOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.return_dict
        # The weighted layer sum needs every layer's hidden states, whatever the caller asked for.
        if self.config.use_weighted_layer_sum:
            output_hidden_states = True

        encoder_outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer states on a new axis 1 and mix them with softmax-normalized weights.
            stacked_states = torch.stack(encoder_outputs[_HIDDEN_STATES_START_POSITION], dim=1)
            mixing_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            features = (stacked_states * mixing_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            features = encoder_outputs[0]

        features = self.projector(features)

        if attention_mask is None:
            pooled_output = features.mean(dim=1)
        else:
            # Zero the padded frames, then average over the number of valid frames only.
            frame_mask = self._get_feature_vector_attention_mask(features.shape[1], attention_mask)
            full_mask = frame_mask.unsqueeze(-1).repeat(1, 1, features.shape[2])
            features[~full_mask] = 0.0
            valid_frame_counts = frame_mask.sum(dim=1).view(-1, 1)
            pooled_output = features.sum(dim=1) / valid_frame_counts

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.config.num_labels), labels.view(-1))

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # tuple layout: (loss?, logits, *remaining encoder outputs)
        tuple_output = (logits,) + encoder_outputs[_HIDDEN_STATES_START_POSITION:]
        return tuple_output if loss is None else ((loss,) + tuple_output)
@auto_docstring
class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Adapter-equipped configurations are rejected up front.
        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Audio frame classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
            )
        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        # The classifier is applied to every frame independently (no pooling in `forward`).
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.num_labels = config.num_labels

        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | TokenClassifierOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        labels (`torch.Tensor` of shape `(batch_size, num_frames, config.num_labels)`, *optional*):
            Per-frame class targets (e.g. one-hot vectors) for computing the frame classification loss. They are
            reshaped to `(-1, config.num_labels)` and reduced with `argmax` over the last dimension to obtain one class
            index per frame; a classification loss is then computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        # The weighted layer sum needs every layer's hidden states, whatever the caller asked for.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer states on a new axis 1 and mix them with softmax-normalized weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Collapse the per-frame target vectors to class indices before the cross-entropy.
            frame_targets = torch.argmax(labels.view(-1, self.num_labels), dim=1)
            loss = loss_fct(logits.view(-1, self.num_labels), frame_targets)

        if not return_dict:
            # tuple layout: (loss?, logits, *remaining encoder outputs). The loss used to be silently dropped
            # on this path; it is now prepended, as the sibling heads in this file do.
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
- class AMSoftmaxLoss(nn.Module):
- def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
- super().__init__()
- self.scale = scale
- self.margin = margin
- self.num_labels = num_labels
- self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
- self.loss = nn.CrossEntropyLoss()
- def forward(self, hidden_states, labels):
- labels = labels.flatten()
- weight = nn.functional.normalize(self.weight, dim=0)
- hidden_states = nn.functional.normalize(hidden_states, dim=1)
- cos_theta = torch.mm(hidden_states, weight)
- psi = cos_theta - self.margin
- onehot = nn.functional.one_hot(labels, self.num_labels)
- logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
- loss = self.loss(logits, labels)
- return loss
class TDNNLayer(nn.Module):
    """
    Single TDNN layer: a dilated 1D convolution followed by a ReLU.

    The parameters live in an `nn.Linear` (kept for backward compatibility), but `forward` reshapes them into a
    convolution kernel and calls `conv1d`.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        # The first layer reads `tdnn_dim[0]` channels; later layers read the previous layer's output width.
        source_index = layer_id - 1 if layer_id > 0 else layer_id
        self.in_conv_dim = config.tdnn_dim[source_index]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if is_peft_available():
            from peft.tuners.lora import LoraLayer

            # The raw `weight`/`bias` are used below, so a LoRA wrapper around `kernel` would be bypassed.
            if isinstance(self.kernel, LoraLayer):
                warnings.warn(
                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
                    "You should exclude TDNNLayer from LoRA's target modules.",
                )

        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
        conv_weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)

        # conv1d wants the channel axis before the time axis; swap in, convolve, swap back.
        channels_first = hidden_states.transpose(1, 2)
        convolved = nn.functional.conv1d(channels_first, conv_weight, self.kernel.bias, dilation=self.dilation)
        return self.activation(convolved.transpose(1, 2))
@auto_docstring(
    custom_intro="""
    Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """
)
class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        # One TDNN layer per entry of `config.tdnn_dim`.
        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # Statistic pooling concatenates mean and std, hence twice the last TDNN width as input.
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: torch.LongTensor | int):
        """
        Computes the output length of the TDNN layers, taking each layer's kernel size and dilation into account.
        """

        def _conv_out_length(input_length, kernel_size, stride, dilation):
            # 1D convolutional layer output length formula (no padding) taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - dilation * (kernel_size - 1) - 1) // stride + 1

        # The TDNN layers convolve with `dilation=config.tdnn_dilation[i]` and stride 1, so a dilated kernel
        # consumes `dilation * (kernel_size - 1)` frames. Ignoring the dilation over-estimates the valid length
        # and lets padded frames leak into the pooled statistics.
        for kernel_size, dilation in zip(self.config.tdnn_kernel, self.config.tdnn_dilation):
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1, dilation)

        return input_lengths

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | XVectorOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the additive-margin softmax classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        # The weighted layer sum needs every layer's hidden states, whatever the caller asked for.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer states on a new axis 1 and mix them with softmax-normalized weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            # Pool each sample over its own valid frames only, so padding does not distort the statistics.
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            # tuple layout: (loss?, logits, embeddings, *remaining encoder outputs)
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Names exported by `from <this module> import *`.
__all__ = [
    "Wav2Vec2ForAudioFrameClassification",
    "Wav2Vec2ForCTC",
    "Wav2Vec2ForPreTraining",
    "Wav2Vec2ForSequenceClassification",
    "Wav2Vec2ForXVector",
    "Wav2Vec2Model",
    "Wav2Vec2PreTrainedModel",
]
|