| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/wavlm/modular_wavlm.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_wavlm.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- import math
- import warnings
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...integrations.deepspeed import is_deepspeed_zero3_enabled
- from ...integrations.fsdp import is_fsdp_managed_module
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- CausalLMOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- Wav2Vec2BaseModelOutput,
- XVectorOutput,
- )
- from ...modeling_utils import PreTrainedModel, get_torch_context_manager_or_global_device
- from ...utils import auto_docstring, is_peft_available, logging
- from .configuration_wavlm import WavLMConfig
- logger = logging.get_logger(__name__)
class WavLMSamePadLayer(nn.Module):
    """Drop surplus trailing positions from the last axis of a 3-D tensor.

    `num_pad_remove` is 1 when `num_conv_pos_embeddings` is even and 0 when it is odd.
    """

    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        # Parity of the kernel size decides how many trailing positions `forward` removes.
        self.num_pad_remove = 0 if num_conv_pos_embeddings % 2 else 1

    def forward(self, hidden_states):
        """Return `hidden_states` without its last `num_pad_remove` entries along the third axis."""
        if self.num_pad_remove <= 0:
            return hidden_states
        return hidden_states[:, :, : -self.num_pad_remove]
class WavLMPositionalConvEmbedding(nn.Module):
    """Weight-normalised grouped 1-D convolution applied along the sequence axis.

    The input is transposed so that axis 1 becomes the channel axis for the convolution,
    then passed through the same-pad trimming layer and the configured activation, and
    finally transposed back.
    """

    def __init__(self, config):
        super().__init__()
        kernel_size = config.num_conv_pos_embeddings
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        # Use the parametrization-based weight norm when this torch build exposes it,
        # otherwise fall back to the legacy `nn.utils.weight_norm`.
        apply_weight_norm = getattr(nn.utils.parametrizations, "weight_norm", nn.utils.weight_norm)

        if not is_deepspeed_zero3_enabled():
            self.conv = apply_weight_norm(self.conv, name="weight", dim=2)
        else:
            import deepspeed

            # Under ZeRO-3 the conv weight is gathered before weight norm is applied.
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = apply_weight_norm(self.conv, name="weight", dim=2)

            # The two weight-norm tensors live under different names depending on
            # which weight-norm implementation was used above.
            if hasattr(self.conv, "parametrizations"):
                norm_params = self.conv.parametrizations.weight
                weight_g, weight_v = norm_params.original0, norm_params.original1
            else:
                weight_g, weight_v = self.conv.weight_g, self.conv.weight_v
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)

        self.padding = WavLMSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Convolve, trim and activate `hidden_states`; axes 1 and 2 are swapped in and back out."""
        channels_first = hidden_states.transpose(1, 2)
        convolved = self.activation(self.padding(self.conv(channels_first)))
        return convolved.transpose(1, 2)
class WavLMFeatureProjection(nn.Module):
    """Layer-normalise features of width `config.conv_dim[-1]` and project them to `config.hidden_size`."""

    def __init__(self, config):
        super().__init__()
        feature_dim = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(feature_dim, eps=config.layer_norm_eps)
        self.projection = nn.Linear(feature_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """Return `(projected, normalised)`.

        The normalised but un-projected features are returned as well because
        quantization needs them.
        """
        normalised = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(normalised))
        return projected, normalised
- class WavLMAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(
- self,
- embed_dim: int,
- num_heads: int,
- dropout: float | int = 0.0,
- num_buckets: int = 320,
- max_distance: int = 800,
- has_relative_position_bias: bool = True,
- ):
- super().__init__()
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.dropout = dropout
- self.head_dim = embed_dim // num_heads
- if (self.head_dim * num_heads) != self.embed_dim:
- raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
- f" and `num_heads`: {num_heads})."
- )
- self.scaling = self.head_dim**-0.5
- self.k_proj = nn.Linear(embed_dim, embed_dim)
- self.v_proj = nn.Linear(embed_dim, embed_dim)
- self.q_proj = nn.Linear(embed_dim, embed_dim)
- self.out_proj = nn.Linear(embed_dim, embed_dim)
- self.num_buckets = num_buckets
- self.max_distance = max_distance
- self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
- self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
- if has_relative_position_bias:
- self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- position_bias: torch.Tensor | None = None,
- output_attentions: bool = False,
- index=0,
- ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
- """Attention layer with relative attention"""
- bsz, tgt_len, _ = hidden_states.size()
- # first pass of attention layer creates position bias
- if position_bias is None:
- position_bias = self.compute_bias(tgt_len, tgt_len)
- position_bias = (
- position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
- )
- # Compute relative position bias:
- # 1) get reshape hidden_states
- gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
- gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)
- # 2) project hidden states
- relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
- relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)
- # 3) compute gate for position bias from projected hidden states
- gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
- gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
- # 4) apply gate to position bias to compute gated position_bias
- gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
- gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))
- attn_output, attn_weights = self.torch_multi_head_self_attention(
- hidden_states, attention_mask, gated_position_bias, output_attentions
- )
- return attn_output, attn_weights, position_bias
- def torch_multi_head_self_attention(
- self,
- hidden_states: torch.FloatTensor,
- attention_mask: torch.LongTensor | torch.BoolTensor,
- gated_position_bias: torch.FloatTensor,
- output_attentions: bool,
- ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
- """simple wrapper around torch's multi_head_attention_forward function"""
- # self-attention assumes q = k = v
- query = key = value = hidden_states.transpose(0, 1)
- key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None
- # disable bias and add_zero_attn
- bias_k = bias_v = None
- add_zero_attn = False
- # PyTorch 1.3.0 has F.multi_head_attention_forward defined
- # so no problem with backwards compatibility
- attn_output, attn_weights = F.multi_head_attention_forward(
- query,
- key,
- value,
- self.embed_dim,
- self.num_heads,
- torch.empty([0]),
- torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
- bias_k,
- bias_v,
- add_zero_attn,
- self.dropout,
- self.out_proj.weight,
- self.out_proj.bias,
- self.training,
- key_padding_mask,
- output_attentions,
- gated_position_bias,
- use_separate_proj_weight=True,
- q_proj_weight=self.q_proj.weight,
- k_proj_weight=self.k_proj.weight,
- v_proj_weight=self.v_proj.weight,
- )
- # [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...]
- attn_output = attn_output.transpose(0, 1)
- if attn_weights is not None:
- # IMPORTANT: Attention weights are averaged weights
- # here which should not be the case. This is an open issue
- # on PyTorch: https://github.com/pytorch/pytorch/issues/32590
- attn_weights = attn_weights[:, None].broadcast_to(
- attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
- )
- return attn_output, attn_weights
- def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
- context_position = torch.arange(query_length, dtype=torch.long)[:, None]
- memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
- relative_position = memory_position - context_position
- relative_position_bucket = self._relative_positions_bucket(relative_position)
- relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
- values = self.rel_attn_embed(relative_position_bucket)
- values = values.permute([2, 0, 1])
- return values
- def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor:
- num_buckets = self.num_buckets // 2
- relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
- relative_positions = torch.abs(relative_positions)
- max_exact = num_buckets // 2
- is_small = relative_positions < max_exact
- relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
- relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
- relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
- relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
- relative_position_if_large = torch.min(
- relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
- )
- relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
- return relative_buckets
class WavLMFeedForward(nn.Module):
    """Two-layer feed-forward block: dense -> activation -> dropout -> dense -> dropout."""

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)
        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # `hidden_act` is either the name of a registered activation or something used directly.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        """Expand to `intermediate_size`, activate, and project back to `hidden_size`."""
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded)
        return self.output_dropout(self.output_dense(expanded))
class WavLMEncoderLayer(GradientCheckpointingLayer):
    """Encoder block in which each LayerNorm is applied after its residual addition."""

    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        """Return `(hidden_states, position_bias)`, plus the attention weights when requested."""
        attended, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )

        # Attention sub-block: residual add, then normalise.
        hidden_states = hidden_states + self.dropout(attended)
        hidden_states = self.layer_norm(hidden_states)

        # Feed-forward sub-block: residual add, then normalise.
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        if output_attentions:
            return (hidden_states, position_bias, attn_weights)
        return (hidden_states, position_bias)
class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
    """Encoder block in which each LayerNorm is applied before its sub-block, inside the residual branch."""

    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        """Return `(hidden_states, position_bias)`, plus the attention weights when requested."""
        residual = hidden_states

        # Attention sub-block operates on the normalised input; the residual stays un-normalised.
        attended, attn_weights, position_bias = self.attention(
            self.layer_norm(hidden_states),
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = residual + self.dropout(attended)

        # Feed-forward sub-block, again normalising only the branch input.
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        if output_attentions:
            return (hidden_states, position_bias, attn_weights)
        return (hidden_states, position_bias)
class WavLMEncoder(nn.Module):
    """Stack of `WavLMEncoderLayer` blocks preceded by a convolutional position embedding.

    Only the first layer is built with a relative position bias embedding; the bias it
    produces is passed on to every later layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [
                WavLMEncoderLayer(config, has_relative_position_bias=(layer_idx == 0))
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Encode `hidden_states`.

        Returns a `BaseModelOutput` when `return_dict` is truthy, otherwise a tuple of the
        non-`None` values among (last hidden state, all hidden states, all attentions).
        Note that `hidden_states` is modified in place when `attention_mask` is given.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        if attention_mask is not None:
            # Make sure padded tokens output 0 (in-place on the incoming tensor).
            # `~` is applied to the mask, so it is presumably boolean — TODO confirm with callers.
            padded_positions = ~attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[padded_positions] = 0

        hidden_states = hidden_states + self.pos_conv_embed(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Under FSDP or DeepSpeed ZeRO-3 all GPUs must run in sync, so no layer call may be skipped.
        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        position_bias = None
        for layer_idx, layer in enumerate(self.layers):
            if output_hidden_states:
                collected_states += (hidden_states,)

            # LayerDrop (see https://huggingface.co/papers/1909.11556); the first layer is never dropped.
            dropout_probability = torch.rand([])
            skip_the_layer = self.training and layer_idx > 0 and (dropout_probability < self.config.layerdrop)

            if synced_gpus or not skip_the_layer:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_bias=position_bias,
                    output_attentions=output_attentions,
                    index=layer_idx,
                )
                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                # A dropped layer reports no attention weights.
                layer_outputs = (None, None, None)

            if output_attentions:
                collected_attentions += (layer_outputs[2],)

        if output_hidden_states:
            collected_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )
        return tuple(v for v in (hidden_states, collected_states, collected_attentions) if v is not None)
class WavLMEncoderStableLayerNorm(nn.Module):
    """Stack of `WavLMEncoderLayerStableLayerNorm` blocks with one LayerNorm applied after the stack.

    Only the first layer is built with a relative position bias embedding; the bias it
    produces is passed on to every later layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [
                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(layer_idx == 0))
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Encode `hidden_states`.

        Returns a `BaseModelOutput` when `return_dict` is truthy, otherwise a tuple of the
        non-`None` values among (last hidden state, all hidden states, all attentions).
        Note that `hidden_states` is modified in place when `attention_mask` is given.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        if attention_mask is not None:
            # Zero the hidden states at masked-out positions (in-place on the incoming tensor).
            # `~` is applied to the mask, so it is presumably boolean — TODO confirm with callers.
            padded_positions = ~attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[padded_positions] = 0

        hidden_states = hidden_states + self.pos_conv_embed(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Under FSDP or DeepSpeed ZeRO-3 all GPUs must run in sync, so no layer call may be skipped.
        # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth
        # the code complication.
        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        position_bias = None
        for layer_idx, layer in enumerate(self.layers):
            if output_hidden_states:
                collected_states += (hidden_states,)

            # LayerDrop (see https://huggingface.co/papers/1909.11556); the first layer is never dropped.
            dropout_probability = torch.rand([])
            skip_the_layer = self.training and layer_idx > 0 and (dropout_probability < self.config.layerdrop)

            if synced_gpus or not skip_the_layer:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    position_bias=position_bias,
                )
                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                # A dropped layer reports no attention weights.
                layer_outputs = (None, None, None)

            if output_attentions:
                collected_attentions += (layer_outputs[2],)

        # The final normalisation is applied once, after all layers.
        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            collected_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=collected_states, attentions=collected_attentions
            )
        return tuple(v for v in (hidden_states, collected_states, collected_attentions) if v is not None)
class WavLMGumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
    """

    def __init__(self, config):
        super().__init__()
        self.num_groups = config.num_codevector_groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim {config.codevector_dim} must be divisible"
                f" by `config.num_codevector_groups` {self.num_groups} "
                "for concatenation."
            )

        # Codebook storage: one row per (group, entry) pair; allocated uninitialised here.
        codebook_rows = self.num_groups * self.num_vars
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, codebook_rows, config.codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(config.conv_dim[-1], codebook_rows)

        # Gumbel-softmax temperature; can be decayed for training.
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs):
        """Sum over groups of `exp(entropy)` of the distribution averaged over the first axis."""
        marginal_probs = probs.mean(dim=0)
        summed_xlogx = torch.sum(torch.xlogy(marginal_probs, marginal_probs), dim=-1)
        return torch.exp(-summed_xlogx).sum()

    def forward(self, hidden_states):
        """Quantize a 3-D input; returns `(codevectors, perplexity)`.

        In training mode one codebook entry per group is sampled with a hard gumbel
        softmax; otherwise the arg-max entry is selected.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        num_frames = batch_size * sequence_length

        # Project to one logit per codebook entry, then lay out one row per (frame, group).
        logits = self.weight_proj(hidden_states)
        logits = logits.view(num_frames * self.num_groups, -1)

        if self.training:
            # Sample code vector probabilities via gumbel softmax in a differentiable way.
            codevector_probs = nn.functional.gumbel_softmax(logits.float(), tau=self.temperature, hard=True)
            codevector_probs = codevector_probs.type_as(logits)

            # Perplexity is measured on the soft distribution.
            soft_dist = torch.softmax(logits.view(num_frames, self.num_groups, -1).float(), dim=-1)
            perplexity = self._compute_perplexity(soft_dist)
        else:
            # Take the arg-max in a non-differentiable way and build the hard (one-hot) distribution.
            best_idx = logits.argmax(dim=-1)
            codevector_probs = logits.new_zeros(logits.shape).scatter_(-1, best_idx.view(-1, 1), 1.0)
            codevector_probs = codevector_probs.view(num_frames, self.num_groups, -1)
            perplexity = self._compute_perplexity(codevector_probs)

        # Use the probabilities to retrieve the code vectors and concatenate them across groups.
        codevector_probs = codevector_probs.view(num_frames, -1)
        weighted = codevector_probs.unsqueeze(-1) * self.codevectors
        weighted = weighted.view(num_frames, self.num_groups, self.num_vars, -1)
        codevectors = weighted.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity
# Base class of the WavLM models in this file: it declares the model-level flags and provides
# weight initialisation plus helpers that translate raw input lengths into feature lengths.
@auto_docstring
class WavLMPreTrainedModel(PreTrainedModel):
    config: WavLMConfig
    base_model_prefix = "wavlm"
    main_input_name = "input_values"
    input_modalities = "audio"
    supports_gradient_checkpointing = True
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_flex_attn = False

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialise the parameters of `module` according to its type."""
        # The more specific WavLM modules are checked before the generic torch layers.
        if isinstance(module, WavLMGumbelVectorQuantizer):
            # Gumbel softmax requires a special init.
            init.normal_(module.weight_proj.weight, mean=0.0, std=1)
            init.zeros_(module.weight_proj.bias)
            init.uniform_(module.codevectors)
            return

        if isinstance(module, WavLMPositionalConvEmbedding):
            conv = module.conv
            init.normal_(
                conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (conv.kernel_size[0] * conv.in_channels)),
            )
            init.constant_(conv.bias, 0)
            return

        if isinstance(module, WavLMFeatureProjection):
            bound = math.sqrt(1 / module.projection.in_features)
            init.uniform_(module.projection.weight, a=-bound, b=bound)
            init.uniform_(module.projection.bias, a=-bound, b=bound)
            return

        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
            return

        if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            return

        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight)
            if module.bias is not None:
                bound = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                init.uniform_(module.bias, a=-bound, b=bound)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int, add_adapter: bool | None = None):
        """
        Computes the output length of the convolutional layers (and of the adapter layers when
        `add_adapter`, defaulting to `config.add_adapter`, is truthy).
        """
        if add_adapter is None:
            add_adapter = self.config.add_adapter

        # Each step applies the 1D convolution output length formula from
        # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html:
        # floor((length - kernel_size) / stride) + 1
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = torch.div(input_lengths - kernel_size, stride, rounding_mode="floor") + 1

        if add_adapter:
            # Adapter layers are accounted for with a kernel size of 1 and the adapter stride.
            for _ in range(self.config.num_adapter_layers):
                input_lengths = torch.div(input_lengths - 1, self.config.adapter_stride, rounding_mode="floor") + 1

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        """Build a boolean `(batch, feature_vector_length)` mask from an input-level `attention_mask`."""
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]
        feature_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )

        # Mark the last valid position of every row, then propagate that mark to all earlier
        # positions with a reversed cumulative sum, so everything before the output length is attended to.
        row_idx = torch.arange(batch_size, device=feature_mask.device)
        feature_mask[row_idx, output_lengths - 1] = 1
        return feature_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
class WavLMNoLayerNormConvLayer(GradientCheckpointingLayer):
    """1-D convolution followed by the configured activation, with no normalisation in between."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # Layer 0 has a single input channel; later layers take the previous layer's width.
        if layer_id > 0:
            self.in_conv_dim = config.conv_dim[layer_id - 1]
        else:
            self.in_conv_dim = 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Return `activation(conv(hidden_states))`."""
        return self.activation(self.conv(hidden_states))
class WavLMLayerNormConvLayer(GradientCheckpointingLayer):
    """1-D convolution, LayerNorm over the channel axis, then the configured activation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # Layer 0 has a single input channel; later layers take the previous layer's width.
        if layer_id > 0:
            self.in_conv_dim = config.conv_dim[layer_id - 1]
        else:
            self.in_conv_dim = 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        """Convolve, normalise each position across channels, and activate."""
        convolved = self.conv(hidden_states)
        # LayerNorm acts on the last axis, so the last two axes are swapped around it.
        normalised = self.layer_norm(convolved.transpose(-2, -1)).transpose(-2, -1)
        return self.activation(normalised)
class WavLMGroupNormConvLayer(GradientCheckpointingLayer):
    """1-D convolution, GroupNorm with one group per output channel, then the configured activation."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        # Layer 0 has a single input channel; later layers take the previous layer's width.
        if layer_id > 0:
            self.in_conv_dim = config.conv_dim[layer_id - 1]
        else:
            self.in_conv_dim = 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]
        self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)

    def forward(self, hidden_states):
        """Return `activation(group_norm(conv(hidden_states)))`."""
        return self.activation(self.layer_norm(self.conv(hidden_states)))
class WavLMFeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    def __init__(self, config):
        """
        Build the stack of convolutional layers selected by `config.feat_extract_norm`.

        Raises:
            ValueError: If `config.feat_extract_norm` is neither `"group"` nor `"layer"`.
        """
        super().__init__()

        norm_kind = config.feat_extract_norm
        if norm_kind == "group":
            # only the first layer is group-normalized; the remaining ones have no normalization
            conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)]
            conv_layers.extend(
                WavLMNoLayerNormConvLayer(config, layer_id=layer_id)
                for layer_id in range(1, config.num_feat_extract_layers)
            )
        elif norm_kind == "layer":
            conv_layers = [
                WavLMLayerNormConvLayer(config, layer_id=layer_id)
                for layer_id in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

        self.conv_layers = nn.ModuleList(conv_layers)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        """Turn off gradients for every parameter and remember that the encoder is frozen."""
        for parameter in self.parameters():
            parameter.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        """Insert an axis at position 1 of `input_values` and run it through all conv layers."""
        features = input_values[:, None]

        # make sure the input requires grad so that gradient checkpointing works
        if self._requires_grad and self.training:
            features.requires_grad = True

        for layer in self.conv_layers:
            features = layer(features)

        return features
class WavLMAdapterLayer(nn.Module):
    """Strided 1D convolution that doubles the channels, followed by a GLU that halves them again."""

    def __init__(self, config):
        """
        Args:
            config: Model configuration; `output_hidden_size`, `adapter_kernel_size` and
                `adapter_stride` are read from it.
        """
        super().__init__()
        channels = config.output_hidden_size
        self.conv = nn.Conv1d(
            channels,
            2 * channels,
            config.adapter_kernel_size,
            stride=config.adapter_stride,
            padding=1,
        )

    def forward(self, hidden_states):
        """Convolve `hidden_states` and gate the result with a GLU over axis 1."""
        doubled = self.conv(hidden_states)
        return nn.functional.glu(doubled, dim=1)
class WavLMAdapter(nn.Module):
    """Optional projection followed by a stack of `WavLMAdapterLayer`s with LayerDrop during training."""

    def __init__(self, config):
        """
        Args:
            config: Model configuration; `hidden_size`, `output_hidden_size`, `num_adapter_layers`
                and `layerdrop` are read from it.
        """
        super().__init__()

        # the feature dimension only needs a projection when the two sizes differ
        if config.output_hidden_size == config.hidden_size:
            self.proj = None
            self.proj_layer_norm = None
        else:
            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
            self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)

        self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers))
        self.layerdrop = config.layerdrop

    def forward(self, hidden_states):
        """Project (if configured) and run the adapter layers on `hidden_states`."""
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))

        # the adapter layers are convolutions, so swap axes 1 and 2 for them
        hidden_states = hidden_states.transpose(1, 2)

        for layer in self.layers:
            # one random draw per layer, taken whether or not the module is training
            layerdrop_prob = np.random.random()
            if self.training and not (layerdrop_prob > self.layerdrop):
                # LayerDrop: skip this layer for the current step
                continue
            hidden_states = layer(hidden_states)

        return hidden_states.transpose(1, 2)
- def _compute_mask_indices(
- shape: tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: torch.LongTensor | None = None,
- min_masks: int = 0,
- ) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
- return num_masked_span
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
- max_num_masked_span = compute_num_masked_span(sequence_length)
- if max_num_masked_span == 0:
- return spec_aug_mask
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
- return spec_aug_mask
# WavLM reuses the Wav2Vec2 base output container under its own name.
WavLMBaseModelOutput = Wav2Vec2BaseModelOutput
@auto_docstring
class WavLMModel(WavLMPreTrainedModel):
    def __init__(self, config: WavLMConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = WavLMFeatureEncoder(config)
        self.feature_projection = WavLMFeatureProjection(config)

        # the learned mask embedding is only needed when some masking probability is > 0.0
        needs_mask_embedding = config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0
        if needs_mask_embedding:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        encoder_class = WavLMEncoderStableLayerNorm if config.do_stable_layer_norm else WavLMEncoder
        self.encoder = encoder_class(config)

        if config.add_adapter:
            self.adapter = WavLMAdapter(config)
        else:
            self.adapter = None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        """
        Disable gradient computation for the feature encoder so that its parameters are not updated during
        training.
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: torch.FloatTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779). `hidden_states` is modified in place and returned.
        """
        # `config.apply_spec_augment` can switch masking off entirely
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        batch_size, sequence_length, hidden_size = hidden_states.size()

        # time axis: use the caller's indices when given, otherwise sample them while training
        if mask_time_indices is None and self.config.mask_time_prob > 0 and self.training:
            sampled_time_mask = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(sampled_time_mask, device=hidden_states.device, dtype=torch.bool)

        if mask_time_indices is not None:
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        # feature axis: only sampled while training; masked features are zeroed
        if self.config.mask_feature_prob > 0 and self.training:
            sampled_feature_mask = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(sampled_feature_mask, device=hidden_states.device, dtype=torch.bool)
            # the same feature mask is applied to every time step
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        mask_time_indices: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | WavLMBaseModelOutput:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        """
        # unset flags fall back to the configuration
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        extract_features = self.feature_extractor(input_values).transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        if self.adapter is not None:
            last_hidden_state = self.adapter(last_hidden_state)

        if not return_dict:
            return (last_hidden_state, extract_features) + encoder_outputs[1:]

        return WavLMBaseModelOutput(
            last_hidden_state=last_hidden_state,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Index into the base model's output at which the head models below read the hidden states
# (`outputs[_HIDDEN_STATES_START_POSITION]`) and from which they forward the remaining outputs.
_HIDDEN_STATES_START_POSITION = 2
@auto_docstring(
    custom_intro="""
    WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    """
)
class WavLMForCTC(WavLMPreTrainedModel):
    def __init__(self, config, target_lang: str | None = None):
        r"""
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`WavLMForCTC`] with adapters. Uses 'eng' by
            default.
        """
        super().__init__(config)

        self.wavlm = WavLMModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `WavLMForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )

        # with an adapter, the head sees the adapter's output size instead of the encoder's hidden size
        if hasattr(config, "add_adapter") and config.add_adapter:
            output_hidden_size = config.output_hidden_size
        else:
            output_hidden_size = config.hidden_size
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self, **kwargs):
        """
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        """
        if get_torch_context_manager_or_global_device() == torch.device("meta"):
            return

        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
        # correctly load adapter layers for WavLM so that we do not have to introduce a new API to
        # [`PreTrainedModel`]. While slightly hacky, WavLM never has to tie input and output embeddings, so that it is
        # ok to repurpose this function here.
        target_lang = self.target_lang
        adapter_attn_dim = getattr(self.config, "adapter_attn_dim", None)

        if target_lang is None:
            if adapter_attn_dim is not None:
                logger.info("By default `target_lang` is set to 'eng'.")
            return

        if adapter_attn_dim is None:
            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")

        self.load_adapter(target_lang, force_load=True)

    def freeze_feature_encoder(self):
        """
        Disable gradient computation for the feature encoder so that its parameters are not updated during
        training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Disable gradient computation for all parameters of the base model (`self.wavlm`), so that only the
        head defined on top of it is updated during training.
        """
        for parameter in self.wavlm.parameters():
            parameter.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | CausalLMOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            # without an attention mask every input position counts as valid
            if attention_mask is None:
                attention_mask = torch.ones_like(input_values, dtype=torch.long)
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # padded label positions are assumed to be filled with -100, i.e. negative
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            if loss is None:
                return output
            return (loss,) + output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
@auto_docstring(
    custom_intro="""
    WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    """
)
class WavLMForSequenceClassification(WavLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)"
            )

        self.wavlm = WavLMModel(config)

        if config.use_weighted_layer_sum:
            # one weight per transformer layer plus one for the input embeddings, initialized uniformly
            num_layers = config.num_hidden_layers + 1
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        """
        Disable gradient computation for the feature encoder so that its parameters are not updated during
        training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Disable gradient computation for all parameters of the base model (`self.wavlm`), so that only the
        classification head is updated during training.
        """
        for parameter in self.wavlm.parameters():
            parameter.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | SequenceClassifierOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        # the weighted layer sum needs every layer's hidden states
        if self.config.use_weighted_layer_sum:
            output_hidden_states = True

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # softmax-weighted average over all layers
            stacked_layers = torch.stack(outputs[_HIDDEN_STATES_START_POSITION], dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (stacked_layers * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # mean over valid frames only: zero the padded frames, then divide by the valid-frame count
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_padding_mask] = 0.0
            valid_frames = padding_mask.sum(dim=1).view(-1, 1)
            pooled_output = hidden_states.sum(dim=1) / valid_frames

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            if loss is None:
                return output
            return (loss,) + output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Audio frame classification does not support the use of WavLM adapters (config.add_adapter=True)"
            )

        self.wavlm = WavLMModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.num_labels = config.num_labels

        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | TokenClassifierOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
        labels (`torch.Tensor` whose last dimension is `config.num_labels`, *optional*):
            Frame-level targets for the classification loss. The tensor is reshaped to `(-1, config.num_labels)` and
            the `argmax` of every row is used as the class index of that frame, so one row per output frame (e.g. a
            one-hot vector) is expected. A Cross-Entropy loss is computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        # the weighted layer sum needs every layer's hidden states
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # softmax-weighted average over all layers
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # per-frame label rows are turned into class indices with argmax
            frame_targets = torch.argmax(labels.view(-1, self.num_labels), axis=1)
            loss = loss_fct(logits.view(-1, self.num_labels), frame_targets)

        if not return_dict:
            # NOTE(review): unlike the other heads in this file, the tuple output does not include `loss`
            # even when `labels` is given. Kept as is because prepending it would shift the positions
            # existing callers index into.
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class AMSoftmaxLoss(nn.Module):
    """
    Additive-margin softmax loss: cosine similarities between normalized inputs and normalized class
    weights, with `margin` subtracted at the target class, scaled by `scale` and fed to cross-entropy.
    """

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        """
        Args:
            input_dim: Size of the input embeddings.
            num_labels: Number of classes.
            scale: Factor the (margin-adjusted) cosine similarities are multiplied with.
            margin: Value subtracted from the cosine similarity of the target class.
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        """Return the loss for 2D `hidden_states` (one embedding per row) and integer class `labels`."""
        targets = labels.flatten()

        # unit-length class vectors (columns) and embeddings (rows) make the product a cosine similarity
        unit_weight = nn.functional.normalize(self.weight, dim=0)
        unit_embeddings = nn.functional.normalize(hidden_states, dim=1)
        cos_theta = torch.mm(unit_embeddings, unit_weight)

        # the margin is only applied at each row's target class
        is_target = nn.functional.one_hot(targets, self.num_labels).bool()
        margin_adjusted = torch.where(is_target, cos_theta - self.margin, cos_theta)

        return self.loss(self.scale * margin_adjusted, targets)
class TDNNLayer(nn.Module):
    """
    TDNN layer: a dilated 1D convolution along axis 1 of the input followed by ReLU.

    The weights are stored in an `nn.Linear` (kept for backward compatibility) and are reshaped into a
    convolution kernel on every forward pass.
    """

    def __init__(self, config, layer_id=0):
        """
        Args:
            config: Model configuration; `tdnn_dim`, `tdnn_kernel` and `tdnn_dilation` are read from it.
            layer_id: Index of this layer inside the per-layer configuration lists.
        """
        super().__init__()
        # the first layer keeps its input width equal to its own output width
        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply the dilated convolution and ReLU.

        Axes 1 and 2 of `hidden_states` are swapped before the convolution and swapped back afterwards.
        A warning is emitted when the linear kernel has been wrapped by LoRA, because the wrapped
        weights are bypassed here.
        """
        # single availability check: import and inspect only when peft is installed
        if is_peft_available():
            from peft.tuners.lora import LoraLayer

            if isinstance(self.kernel, LoraLayer):
                warnings.warn(
                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
                    "You should exclude TDNNLayer from LoRA's target modules.",
                )

        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
        hidden_states = hidden_states.transpose(1, 2)
        # linear weight (out, kernel * in) -> conv weight (out, in, kernel)
        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.activation(hidden_states)
        return hidden_states
@auto_docstring(
    custom_intro="""
    WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """
)
class WavLMForXVector(WavLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wavlm = WavLMModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # statistic pooling concatenates mean and std, hence twice the last TDNN width
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: torch.LongTensor | int):
        """
        Computes the output length of the TDNN layers.

        The TDNN layers convolve with stride 1, no padding and a per-layer dilation, so the length
        shrinks by `dilation * (kernel_size - 1)` at every layer.
        """

        def _conv_out_length(input_length, kernel_size, stride, dilation):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - dilation * (kernel_size - 1) - 1) // stride + 1

        for kernel_size, dilation in zip(self.config.tdnn_kernel, self.config.tdnn_dilation):
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1, dilation)

        return input_lengths

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | XVectorOutput:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices in `[0, ..., config.num_labels - 1]`. When given, the additive-margin softmax objective
            is computed on the output of the classifier layer.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        # the weighted layer sum needs every layer's hidden states
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # softmax-weighted average over all layers
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            # pool every example over its own valid frames only
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                valid_frames = hidden_states[i, :length]
                mean_features.append(valid_frames.mean(dim=0))
                std_features.append(valid_frames.std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Names exported by `from <this module> import *`.
__all__ = [
    "WavLMForAudioFrameClassification",
    "WavLMForCTC",
    "WavLMForSequenceClassification",
    "WavLMForXVector",
    "WavLMModel",
    "WavLMPreTrainedModel",
]
|