modeling_wavlm.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/wavlm/modular_wavlm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_wavlm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. import math
  8. import warnings
  9. import numpy as np
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from torch.nn import CrossEntropyLoss
  14. from ... import initialization as init
  15. from ...activations import ACT2FN
  16. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  17. from ...integrations.fsdp import is_fsdp_managed_module
  18. from ...modeling_layers import GradientCheckpointingLayer
  19. from ...modeling_outputs import (
  20. BaseModelOutput,
  21. CausalLMOutput,
  22. SequenceClassifierOutput,
  23. TokenClassifierOutput,
  24. Wav2Vec2BaseModelOutput,
  25. XVectorOutput,
  26. )
  27. from ...modeling_utils import PreTrainedModel, get_torch_context_manager_or_global_device
  28. from ...utils import auto_docstring, is_peft_available, logging
  29. from .configuration_wavlm import WavLMConfig
  30. logger = logging.get_logger(__name__)
  31. class WavLMSamePadLayer(nn.Module):
  32. def __init__(self, num_conv_pos_embeddings):
  33. super().__init__()
  34. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  35. def forward(self, hidden_states):
  36. if self.num_pad_remove > 0:
  37. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  38. return hidden_states
  39. class WavLMPositionalConvEmbedding(nn.Module):
  40. def __init__(self, config):
  41. super().__init__()
  42. self.conv = nn.Conv1d(
  43. config.hidden_size,
  44. config.hidden_size,
  45. kernel_size=config.num_conv_pos_embeddings,
  46. padding=config.num_conv_pos_embeddings // 2,
  47. groups=config.num_conv_pos_embedding_groups,
  48. )
  49. weight_norm = nn.utils.weight_norm
  50. if hasattr(nn.utils.parametrizations, "weight_norm"):
  51. weight_norm = nn.utils.parametrizations.weight_norm
  52. if is_deepspeed_zero3_enabled():
  53. import deepspeed
  54. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  55. self.conv = weight_norm(self.conv, name="weight", dim=2)
  56. if hasattr(self.conv, "parametrizations"):
  57. weight_g = self.conv.parametrizations.weight.original0
  58. weight_v = self.conv.parametrizations.weight.original1
  59. else:
  60. weight_g = self.conv.weight_g
  61. weight_v = self.conv.weight_v
  62. deepspeed.zero.register_external_parameter(self, weight_v)
  63. deepspeed.zero.register_external_parameter(self, weight_g)
  64. else:
  65. self.conv = weight_norm(self.conv, name="weight", dim=2)
  66. self.padding = WavLMSamePadLayer(config.num_conv_pos_embeddings)
  67. self.activation = ACT2FN[config.feat_extract_activation]
  68. def forward(self, hidden_states):
  69. hidden_states = hidden_states.transpose(1, 2)
  70. hidden_states = self.conv(hidden_states)
  71. hidden_states = self.padding(hidden_states)
  72. hidden_states = self.activation(hidden_states)
  73. hidden_states = hidden_states.transpose(1, 2)
  74. return hidden_states
  75. class WavLMFeatureProjection(nn.Module):
  76. def __init__(self, config):
  77. super().__init__()
  78. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  79. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  80. self.dropout = nn.Dropout(config.feat_proj_dropout)
  81. def forward(self, hidden_states):
  82. # non-projected hidden states are needed for quantization
  83. norm_hidden_states = self.layer_norm(hidden_states)
  84. hidden_states = self.projection(norm_hidden_states)
  85. hidden_states = self.dropout(hidden_states)
  86. return hidden_states, norm_hidden_states
  87. class WavLMAttention(nn.Module):
  88. """Multi-headed attention from 'Attention Is All You Need' paper"""
  89. def __init__(
  90. self,
  91. embed_dim: int,
  92. num_heads: int,
  93. dropout: float | int = 0.0,
  94. num_buckets: int = 320,
  95. max_distance: int = 800,
  96. has_relative_position_bias: bool = True,
  97. ):
  98. super().__init__()
  99. self.embed_dim = embed_dim
  100. self.num_heads = num_heads
  101. self.dropout = dropout
  102. self.head_dim = embed_dim // num_heads
  103. if (self.head_dim * num_heads) != self.embed_dim:
  104. raise ValueError(
  105. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  106. f" and `num_heads`: {num_heads})."
  107. )
  108. self.scaling = self.head_dim**-0.5
  109. self.k_proj = nn.Linear(embed_dim, embed_dim)
  110. self.v_proj = nn.Linear(embed_dim, embed_dim)
  111. self.q_proj = nn.Linear(embed_dim, embed_dim)
  112. self.out_proj = nn.Linear(embed_dim, embed_dim)
  113. self.num_buckets = num_buckets
  114. self.max_distance = max_distance
  115. self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
  116. self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
  117. if has_relative_position_bias:
  118. self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
  119. def forward(
  120. self,
  121. hidden_states: torch.Tensor,
  122. attention_mask: torch.Tensor | None = None,
  123. position_bias: torch.Tensor | None = None,
  124. output_attentions: bool = False,
  125. index=0,
  126. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  127. """Attention layer with relative attention"""
  128. bsz, tgt_len, _ = hidden_states.size()
  129. # first pass of attention layer creates position bias
  130. if position_bias is None:
  131. position_bias = self.compute_bias(tgt_len, tgt_len)
  132. position_bias = (
  133. position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
  134. )
  135. # Compute relative position bias:
  136. # 1) get reshape hidden_states
  137. gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
  138. gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)
  139. # 2) project hidden states
  140. relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
  141. relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)
  142. # 3) compute gate for position bias from projected hidden states
  143. gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
  144. gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
  145. # 4) apply gate to position bias to compute gated position_bias
  146. gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
  147. gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))
  148. attn_output, attn_weights = self.torch_multi_head_self_attention(
  149. hidden_states, attention_mask, gated_position_bias, output_attentions
  150. )
  151. return attn_output, attn_weights, position_bias
  152. def torch_multi_head_self_attention(
  153. self,
  154. hidden_states: torch.FloatTensor,
  155. attention_mask: torch.LongTensor | torch.BoolTensor,
  156. gated_position_bias: torch.FloatTensor,
  157. output_attentions: bool,
  158. ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
  159. """simple wrapper around torch's multi_head_attention_forward function"""
  160. # self-attention assumes q = k = v
  161. query = key = value = hidden_states.transpose(0, 1)
  162. key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None
  163. # disable bias and add_zero_attn
  164. bias_k = bias_v = None
  165. add_zero_attn = False
  166. # PyTorch 1.3.0 has F.multi_head_attention_forward defined
  167. # so no problem with backwards compatibility
  168. attn_output, attn_weights = F.multi_head_attention_forward(
  169. query,
  170. key,
  171. value,
  172. self.embed_dim,
  173. self.num_heads,
  174. torch.empty([0]),
  175. torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
  176. bias_k,
  177. bias_v,
  178. add_zero_attn,
  179. self.dropout,
  180. self.out_proj.weight,
  181. self.out_proj.bias,
  182. self.training,
  183. key_padding_mask,
  184. output_attentions,
  185. gated_position_bias,
  186. use_separate_proj_weight=True,
  187. q_proj_weight=self.q_proj.weight,
  188. k_proj_weight=self.k_proj.weight,
  189. v_proj_weight=self.v_proj.weight,
  190. )
  191. # [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...]
  192. attn_output = attn_output.transpose(0, 1)
  193. if attn_weights is not None:
  194. # IMPORTANT: Attention weights are averaged weights
  195. # here which should not be the case. This is an open issue
  196. # on PyTorch: https://github.com/pytorch/pytorch/issues/32590
  197. attn_weights = attn_weights[:, None].broadcast_to(
  198. attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
  199. )
  200. return attn_output, attn_weights
  201. def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
  202. context_position = torch.arange(query_length, dtype=torch.long)[:, None]
  203. memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
  204. relative_position = memory_position - context_position
  205. relative_position_bucket = self._relative_positions_bucket(relative_position)
  206. relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
  207. values = self.rel_attn_embed(relative_position_bucket)
  208. values = values.permute([2, 0, 1])
  209. return values
  210. def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor:
  211. num_buckets = self.num_buckets // 2
  212. relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
  213. relative_positions = torch.abs(relative_positions)
  214. max_exact = num_buckets // 2
  215. is_small = relative_positions < max_exact
  216. relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
  217. relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
  218. relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
  219. relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
  220. relative_position_if_large = torch.min(
  221. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  222. )
  223. relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
  224. return relative_buckets
  225. class WavLMFeedForward(nn.Module):
  226. def __init__(self, config):
  227. super().__init__()
  228. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  229. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  230. if isinstance(config.hidden_act, str):
  231. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  232. else:
  233. self.intermediate_act_fn = config.hidden_act
  234. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  235. self.output_dropout = nn.Dropout(config.hidden_dropout)
  236. def forward(self, hidden_states):
  237. hidden_states = self.intermediate_dense(hidden_states)
  238. hidden_states = self.intermediate_act_fn(hidden_states)
  239. hidden_states = self.intermediate_dropout(hidden_states)
  240. hidden_states = self.output_dense(hidden_states)
  241. hidden_states = self.output_dropout(hidden_states)
  242. return hidden_states
  243. class WavLMEncoderLayer(GradientCheckpointingLayer):
  244. def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
  245. super().__init__()
  246. self.attention = WavLMAttention(
  247. embed_dim=config.hidden_size,
  248. num_heads=config.num_attention_heads,
  249. dropout=config.attention_dropout,
  250. num_buckets=config.num_buckets,
  251. max_distance=config.max_bucket_distance,
  252. has_relative_position_bias=has_relative_position_bias,
  253. )
  254. self.dropout = nn.Dropout(config.hidden_dropout)
  255. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  256. self.feed_forward = WavLMFeedForward(config)
  257. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  258. def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
  259. attn_residual = hidden_states
  260. hidden_states, attn_weights, position_bias = self.attention(
  261. hidden_states,
  262. attention_mask=attention_mask,
  263. position_bias=position_bias,
  264. output_attentions=output_attentions,
  265. index=index,
  266. )
  267. hidden_states = self.dropout(hidden_states)
  268. hidden_states = attn_residual + hidden_states
  269. hidden_states = self.layer_norm(hidden_states)
  270. hidden_states = hidden_states + self.feed_forward(hidden_states)
  271. hidden_states = self.final_layer_norm(hidden_states)
  272. outputs = (hidden_states, position_bias)
  273. if output_attentions:
  274. outputs += (attn_weights,)
  275. return outputs
  276. class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
  277. def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
  278. super().__init__()
  279. self.attention = WavLMAttention(
  280. embed_dim=config.hidden_size,
  281. num_heads=config.num_attention_heads,
  282. dropout=config.attention_dropout,
  283. num_buckets=config.num_buckets,
  284. max_distance=config.max_bucket_distance,
  285. has_relative_position_bias=has_relative_position_bias,
  286. )
  287. self.dropout = nn.Dropout(config.hidden_dropout)
  288. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  289. self.feed_forward = WavLMFeedForward(config)
  290. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  291. def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
  292. attn_residual = hidden_states
  293. hidden_states = self.layer_norm(hidden_states)
  294. hidden_states, attn_weights, position_bias = self.attention(
  295. hidden_states,
  296. attention_mask=attention_mask,
  297. position_bias=position_bias,
  298. output_attentions=output_attentions,
  299. )
  300. hidden_states = self.dropout(hidden_states)
  301. hidden_states = attn_residual + hidden_states
  302. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  303. outputs = (hidden_states, position_bias)
  304. if output_attentions:
  305. outputs += (attn_weights,)
  306. return outputs
  307. class WavLMEncoder(nn.Module):
  308. def __init__(self, config):
  309. super().__init__()
  310. self.config = config
  311. self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
  312. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  313. self.dropout = nn.Dropout(config.hidden_dropout)
  314. self.layers = nn.ModuleList(
  315. [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
  316. )
  317. self.gradient_checkpointing = False
  318. def forward(
  319. self,
  320. hidden_states,
  321. attention_mask=None,
  322. output_attentions=False,
  323. output_hidden_states=False,
  324. return_dict=True,
  325. ):
  326. all_hidden_states = () if output_hidden_states else None
  327. all_self_attentions = () if output_attentions else None
  328. if attention_mask is not None:
  329. # make sure padded tokens output 0
  330. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  331. hidden_states[~expand_attention_mask] = 0
  332. position_embeddings = self.pos_conv_embed(hidden_states)
  333. hidden_states = hidden_states + position_embeddings
  334. hidden_states = self.layer_norm(hidden_states)
  335. hidden_states = self.dropout(hidden_states)
  336. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  337. position_bias = None
  338. for i, layer in enumerate(self.layers):
  339. if output_hidden_states:
  340. all_hidden_states = all_hidden_states + (hidden_states,)
  341. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  342. dropout_probability = torch.rand([])
  343. skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
  344. if not skip_the_layer or synced_gpus:
  345. # under fsdp or deepspeed zero3 all gpus must run in sync
  346. layer_outputs = layer(
  347. hidden_states,
  348. attention_mask=attention_mask,
  349. position_bias=position_bias,
  350. output_attentions=output_attentions,
  351. index=i,
  352. )
  353. hidden_states, position_bias = layer_outputs[:2]
  354. if skip_the_layer:
  355. layer_outputs = (None, None, None)
  356. if output_attentions:
  357. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  358. if output_hidden_states:
  359. all_hidden_states = all_hidden_states + (hidden_states,)
  360. if not return_dict:
  361. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  362. return BaseModelOutput(
  363. last_hidden_state=hidden_states,
  364. hidden_states=all_hidden_states,
  365. attentions=all_self_attentions,
  366. )
  367. class WavLMEncoderStableLayerNorm(nn.Module):
  368. def __init__(self, config):
  369. super().__init__()
  370. self.config = config
  371. self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
  372. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  373. self.dropout = nn.Dropout(config.hidden_dropout)
  374. self.layers = nn.ModuleList(
  375. [
  376. WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
  377. for i in range(config.num_hidden_layers)
  378. ]
  379. )
  380. self.gradient_checkpointing = False
  381. def forward(
  382. self,
  383. hidden_states,
  384. attention_mask=None,
  385. output_attentions=False,
  386. output_hidden_states=False,
  387. return_dict=True,
  388. ):
  389. all_hidden_states = () if output_hidden_states else None
  390. all_self_attentions = () if output_attentions else None
  391. if attention_mask is not None:
  392. # make sure padded tokens are not attended to
  393. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  394. hidden_states[~expand_attention_mask] = 0
  395. position_embeddings = self.pos_conv_embed(hidden_states)
  396. hidden_states = hidden_states + position_embeddings
  397. hidden_states = self.dropout(hidden_states)
  398. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  399. position_bias = None
  400. for i, layer in enumerate(self.layers):
  401. if output_hidden_states:
  402. all_hidden_states = all_hidden_states + (hidden_states,)
  403. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  404. dropout_probability = torch.rand([])
  405. skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
  406. if not skip_the_layer or synced_gpus:
  407. # under fsdp or deepspeed zero3 all gpus must run in sync
  408. # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
  409. layer_outputs = layer(
  410. hidden_states,
  411. attention_mask=attention_mask,
  412. output_attentions=output_attentions,
  413. position_bias=position_bias,
  414. )
  415. hidden_states, position_bias = layer_outputs[:2]
  416. if skip_the_layer:
  417. layer_outputs = (None, None, None)
  418. if output_attentions:
  419. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  420. hidden_states = self.layer_norm(hidden_states)
  421. if output_hidden_states:
  422. all_hidden_states = all_hidden_states + (hidden_states,)
  423. if not return_dict:
  424. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  425. return BaseModelOutput(
  426. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
  427. )
  428. class WavLMGumbelVectorQuantizer(nn.Module):
  429. """
  430. Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
  431. GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
  432. """
  433. def __init__(self, config):
  434. super().__init__()
  435. self.num_groups = config.num_codevector_groups
  436. self.num_vars = config.num_codevectors_per_group
  437. if config.codevector_dim % self.num_groups != 0:
  438. raise ValueError(
  439. f"`config.codevector_dim {config.codevector_dim} must be divisible"
  440. f" by `config.num_codevector_groups` {self.num_groups} "
  441. "for concatenation."
  442. )
  443. # storage for codebook variables (codewords)
  444. self.codevectors = nn.Parameter(
  445. torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
  446. )
  447. self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
  448. # can be decayed for training
  449. self.temperature = 2
  450. @staticmethod
  451. def _compute_perplexity(probs):
  452. marginal_probs = probs.mean(dim=0)
  453. perplexity = torch.exp(-torch.sum(torch.xlogy(marginal_probs, marginal_probs), dim=-1)).sum()
  454. return perplexity
  455. def forward(self, hidden_states):
  456. batch_size, sequence_length, hidden_size = hidden_states.shape
  457. # project to codevector dim
  458. hidden_states = self.weight_proj(hidden_states)
  459. hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
  460. if self.training:
  461. # sample code vector probs via gumbel in differentiateable way
  462. codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True)
  463. codevector_probs = codevector_probs.type_as(hidden_states)
  464. # compute perplexity
  465. codevector_soft_dist = torch.softmax(
  466. hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
  467. )
  468. perplexity = self._compute_perplexity(codevector_soft_dist)
  469. else:
  470. # take argmax in non-differentiable way
  471. # comptute hard codevector distribution (one hot)
  472. codevector_idx = hidden_states.argmax(dim=-1)
  473. codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
  474. -1, codevector_idx.view(-1, 1), 1.0
  475. )
  476. codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
  477. perplexity = self._compute_perplexity(codevector_probs)
  478. codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
  479. # use probs to retrieve codevectors
  480. codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
  481. codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
  482. codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
  483. return codevectors, perplexity
  484. @auto_docstring
  485. class WavLMPreTrainedModel(PreTrainedModel):
  486. config: WavLMConfig
  487. base_model_prefix = "wavlm"
  488. main_input_name = "input_values"
  489. input_modalities = "audio"
  490. supports_gradient_checkpointing = True
  491. _supports_flash_attn = False
  492. _supports_sdpa = False
  493. _supports_flex_attn = False
  494. @torch.no_grad()
  495. def _init_weights(self, module):
  496. """Initialize the weights"""
  497. # gumbel softmax requires special init
  498. if isinstance(module, WavLMGumbelVectorQuantizer):
  499. init.normal_(module.weight_proj.weight, mean=0.0, std=1)
  500. init.zeros_(module.weight_proj.bias)
  501. init.uniform_(module.codevectors)
  502. elif isinstance(module, WavLMPositionalConvEmbedding):
  503. init.normal_(
  504. module.conv.weight,
  505. mean=0,
  506. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  507. )
  508. init.constant_(module.conv.bias, 0)
  509. elif isinstance(module, WavLMFeatureProjection):
  510. k = math.sqrt(1 / module.projection.in_features)
  511. init.uniform_(module.projection.weight, a=-k, b=k)
  512. init.uniform_(module.projection.bias, a=-k, b=k)
  513. elif isinstance(module, nn.Linear):
  514. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  515. if module.bias is not None:
  516. init.zeros_(module.bias)
  517. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  518. init.zeros_(module.bias)
  519. init.ones_(module.weight)
  520. elif isinstance(module, nn.Conv1d):
  521. init.kaiming_normal_(module.weight)
  522. if module.bias is not None:
  523. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  524. init.uniform_(module.bias, a=-k, b=k)
  525. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int, add_adapter: bool | None = None):
  526. """
  527. Computes the output length of the convolutional layers
  528. """
  529. add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
  530. def _conv_out_length(input_length, kernel_size, stride):
  531. # 1D convolutional layer output length formula taken
  532. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  533. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  534. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  535. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  536. if add_adapter:
  537. for _ in range(self.config.num_adapter_layers):
  538. input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
  539. return input_lengths
  540. def _get_feature_vector_attention_mask(
  541. self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
  542. ):
  543. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  544. # on inference mode.
  545. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  546. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
  547. output_lengths = output_lengths.to(torch.long)
  548. batch_size = attention_mask.shape[0]
  549. attention_mask = torch.zeros(
  550. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  551. )
  552. # these two operations makes sure that all values before the output lengths idxs are attended to
  553. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  554. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  555. return attention_mask
  556. class WavLMNoLayerNormConvLayer(GradientCheckpointingLayer):
  557. def __init__(self, config, layer_id=0):
  558. super().__init__()
  559. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  560. self.out_conv_dim = config.conv_dim[layer_id]
  561. self.conv = nn.Conv1d(
  562. self.in_conv_dim,
  563. self.out_conv_dim,
  564. kernel_size=config.conv_kernel[layer_id],
  565. stride=config.conv_stride[layer_id],
  566. bias=config.conv_bias,
  567. )
  568. self.activation = ACT2FN[config.feat_extract_activation]
  569. def forward(self, hidden_states):
  570. hidden_states = self.conv(hidden_states)
  571. hidden_states = self.activation(hidden_states)
  572. return hidden_states
  573. class WavLMLayerNormConvLayer(GradientCheckpointingLayer):
  574. def __init__(self, config, layer_id=0):
  575. super().__init__()
  576. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  577. self.out_conv_dim = config.conv_dim[layer_id]
  578. self.conv = nn.Conv1d(
  579. self.in_conv_dim,
  580. self.out_conv_dim,
  581. kernel_size=config.conv_kernel[layer_id],
  582. stride=config.conv_stride[layer_id],
  583. bias=config.conv_bias,
  584. )
  585. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  586. self.activation = ACT2FN[config.feat_extract_activation]
  587. def forward(self, hidden_states):
  588. hidden_states = self.conv(hidden_states)
  589. hidden_states = hidden_states.transpose(-2, -1)
  590. hidden_states = self.layer_norm(hidden_states)
  591. hidden_states = hidden_states.transpose(-2, -1)
  592. hidden_states = self.activation(hidden_states)
  593. return hidden_states
  594. class WavLMGroupNormConvLayer(GradientCheckpointingLayer):
  595. def __init__(self, config, layer_id=0):
  596. super().__init__()
  597. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  598. self.out_conv_dim = config.conv_dim[layer_id]
  599. self.conv = nn.Conv1d(
  600. self.in_conv_dim,
  601. self.out_conv_dim,
  602. kernel_size=config.conv_kernel[layer_id],
  603. stride=config.conv_stride[layer_id],
  604. bias=config.conv_bias,
  605. )
  606. self.activation = ACT2FN[config.feat_extract_activation]
  607. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  608. def forward(self, hidden_states):
  609. hidden_states = self.conv(hidden_states)
  610. hidden_states = self.layer_norm(hidden_states)
  611. hidden_states = self.activation(hidden_states)
  612. return hidden_states
  613. class WavLMFeatureEncoder(nn.Module):
  614. """Construct the features from raw audio waveform"""
  615. def __init__(self, config):
  616. super().__init__()
  617. if config.feat_extract_norm == "group":
  618. conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [
  619. WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  620. ]
  621. elif config.feat_extract_norm == "layer":
  622. conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
  623. else:
  624. raise ValueError(
  625. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  626. )
  627. self.conv_layers = nn.ModuleList(conv_layers)
  628. self.gradient_checkpointing = False
  629. self._requires_grad = True
  630. def _freeze_parameters(self):
  631. for param in self.parameters():
  632. param.requires_grad = False
  633. self._requires_grad = False
  634. def forward(self, input_values):
  635. hidden_states = input_values[:, None]
  636. # make sure hidden_states require grad for gradient_checkpointing
  637. if self._requires_grad and self.training:
  638. hidden_states.requires_grad = True
  639. for conv_layer in self.conv_layers:
  640. hidden_states = conv_layer(hidden_states)
  641. return hidden_states
  642. class WavLMAdapterLayer(nn.Module):
  643. def __init__(self, config):
  644. super().__init__()
  645. self.conv = nn.Conv1d(
  646. config.output_hidden_size,
  647. 2 * config.output_hidden_size,
  648. config.adapter_kernel_size,
  649. stride=config.adapter_stride,
  650. padding=1,
  651. )
  652. def forward(self, hidden_states):
  653. hidden_states = self.conv(hidden_states)
  654. hidden_states = nn.functional.glu(hidden_states, dim=1)
  655. return hidden_states
  656. class WavLMAdapter(nn.Module):
  657. def __init__(self, config):
  658. super().__init__()
  659. # feature dim might need to be down-projected
  660. if config.output_hidden_size != config.hidden_size:
  661. self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
  662. self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
  663. else:
  664. self.proj = self.proj_layer_norm = None
  665. self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers))
  666. self.layerdrop = config.layerdrop
  667. def forward(self, hidden_states):
  668. # down project hidden_states if necessary
  669. if self.proj is not None and self.proj_layer_norm is not None:
  670. hidden_states = self.proj(hidden_states)
  671. hidden_states = self.proj_layer_norm(hidden_states)
  672. hidden_states = hidden_states.transpose(1, 2)
  673. for layer in self.layers:
  674. layerdrop_prob = np.random.random()
  675. if not self.training or (layerdrop_prob > self.layerdrop):
  676. hidden_states = layer(hidden_states)
  677. hidden_states = hidden_states.transpose(1, 2)
  678. return hidden_states
  679. def _compute_mask_indices(
  680. shape: tuple[int, int],
  681. mask_prob: float,
  682. mask_length: int,
  683. attention_mask: torch.LongTensor | None = None,
  684. min_masks: int = 0,
  685. ) -> np.ndarray:
  686. """
  687. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  688. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  689. CPU as part of the preprocessing during training.
  690. Args:
  691. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  692. the first element is the batch size and the second element is the length of the axis to span.
  693. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  694. independently generated mask spans of length `mask_length` is computed by
  695. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  696. actual percentage will be smaller.
  697. mask_length: size of the mask
  698. min_masks: minimum number of masked spans
  699. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  700. each batch dimension.
  701. """
  702. batch_size, sequence_length = shape
  703. if mask_length < 1:
  704. raise ValueError("`mask_length` has to be bigger than 0.")
  705. if mask_length > sequence_length:
  706. raise ValueError(
  707. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  708. f" and `sequence_length`: {sequence_length}`"
  709. )
  710. # epsilon is used for probabilistic rounding
  711. epsilon = np.random.rand(1).item()
  712. def compute_num_masked_span(input_length):
  713. """Given input length, compute how many spans should be masked"""
  714. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  715. num_masked_span = max(num_masked_span, min_masks)
  716. # make sure num masked span <= sequence_length
  717. if num_masked_span * mask_length > sequence_length:
  718. num_masked_span = sequence_length // mask_length
  719. # make sure num_masked span is also <= input_length - (mask_length - 1)
  720. if input_length - (mask_length - 1) < num_masked_span:
  721. num_masked_span = max(input_length - (mask_length - 1), 0)
  722. return num_masked_span
  723. # compute number of masked spans in batch
  724. input_lengths = (
  725. attention_mask.detach().sum(-1).tolist()
  726. if attention_mask is not None
  727. else [sequence_length for _ in range(batch_size)]
  728. )
  729. # SpecAugment mask to fill
  730. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  731. spec_aug_mask_idxs = []
  732. max_num_masked_span = compute_num_masked_span(sequence_length)
  733. if max_num_masked_span == 0:
  734. return spec_aug_mask
  735. for input_length in input_lengths:
  736. # compute num of masked spans for this input
  737. num_masked_span = compute_num_masked_span(input_length)
  738. # get random indices to mask
  739. spec_aug_mask_idx = np.random.choice(
  740. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  741. )
  742. # pick first sampled index that will serve as a dummy index to pad vector
  743. # to ensure same dimension for all batches due to probabilistic rounding
  744. # Picking first sample just pads those vectors twice.
  745. if len(spec_aug_mask_idx) == 0:
  746. # this case can only happen if `input_length` is strictly smaller then
  747. # `sequence_length` in which case the last token has to be a padding
  748. # token which we can use as a dummy mask id
  749. dummy_mask_idx = sequence_length - 1
  750. else:
  751. dummy_mask_idx = spec_aug_mask_idx[0]
  752. spec_aug_mask_idx = np.concatenate(
  753. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  754. )
  755. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  756. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  757. # expand masked indices to masked spans
  758. spec_aug_mask_idxs = np.broadcast_to(
  759. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  760. )
  761. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  762. # add offset to the starting indexes so that indexes now create a span
  763. offsets = np.arange(mask_length)[None, None, :]
  764. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  765. batch_size, max_num_masked_span * mask_length
  766. )
  767. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  768. # ensure that we cannot have indices larger than sequence_length
  769. if spec_aug_mask_idxs.max() > sequence_length - 1:
  770. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  771. # scatter indices to mask
  772. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  773. return spec_aug_mask
# WavLM exposes the Wav2Vec2 base model output class under its own name.
WavLMBaseModelOutput = Wav2Vec2BaseModelOutput
  775. @auto_docstring
  776. class WavLMModel(WavLMPreTrainedModel):
  777. def __init__(self, config: WavLMConfig):
  778. super().__init__(config)
  779. self.config = config
  780. self.feature_extractor = WavLMFeatureEncoder(config)
  781. self.feature_projection = WavLMFeatureProjection(config)
  782. # model only needs masking vector if mask prob is > 0.0
  783. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  784. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  785. if config.do_stable_layer_norm:
  786. self.encoder = WavLMEncoderStableLayerNorm(config)
  787. else:
  788. self.encoder = WavLMEncoder(config)
  789. self.adapter = WavLMAdapter(config) if config.add_adapter else None
  790. # Initialize weights and apply final processing
  791. self.post_init()
  792. def freeze_feature_encoder(self):
  793. """
  794. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  795. not be updated during training.
  796. """
  797. self.feature_extractor._freeze_parameters()
  798. def _mask_hidden_states(
  799. self,
  800. hidden_states: torch.FloatTensor,
  801. mask_time_indices: torch.FloatTensor | None = None,
  802. attention_mask: torch.LongTensor | None = None,
  803. ):
  804. """
  805. Masks extracted features along time axis and/or along feature axis according to
  806. [SpecAugment](https://huggingface.co/papers/1904.08779).
  807. """
  808. # `config.apply_spec_augment` can set masking to False
  809. if not getattr(self.config, "apply_spec_augment", True):
  810. return hidden_states
  811. # generate indices & apply SpecAugment along time axis
  812. batch_size, sequence_length, hidden_size = hidden_states.size()
  813. if mask_time_indices is not None:
  814. # apply SpecAugment along time axis with given mask_time_indices
  815. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  816. elif self.config.mask_time_prob > 0 and self.training:
  817. mask_time_indices = _compute_mask_indices(
  818. (batch_size, sequence_length),
  819. mask_prob=self.config.mask_time_prob,
  820. mask_length=self.config.mask_time_length,
  821. attention_mask=attention_mask,
  822. min_masks=self.config.mask_time_min_masks,
  823. )
  824. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  825. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  826. if self.config.mask_feature_prob > 0 and self.training:
  827. # generate indices & apply SpecAugment along feature axis
  828. mask_feature_indices = _compute_mask_indices(
  829. (batch_size, hidden_size),
  830. mask_prob=self.config.mask_feature_prob,
  831. mask_length=self.config.mask_feature_length,
  832. min_masks=self.config.mask_feature_min_masks,
  833. )
  834. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  835. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  836. hidden_states[mask_feature_indices] = 0
  837. return hidden_states
  838. @auto_docstring
  839. def forward(
  840. self,
  841. input_values: torch.Tensor | None,
  842. attention_mask: torch.Tensor | None = None,
  843. mask_time_indices: torch.FloatTensor | None = None,
  844. output_attentions: bool | None = None,
  845. output_hidden_states: bool | None = None,
  846. return_dict: bool | None = None,
  847. **kwargs,
  848. ) -> tuple | WavLMBaseModelOutput:
  849. r"""
  850. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  851. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  852. masked extracted features in *config.proj_codevector_dim* space.
  853. """
  854. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  855. output_hidden_states = (
  856. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  857. )
  858. return_dict = return_dict if return_dict is not None else self.config.return_dict
  859. extract_features = self.feature_extractor(input_values)
  860. extract_features = extract_features.transpose(1, 2)
  861. if attention_mask is not None:
  862. # compute reduced attention_mask corresponding to feature vectors
  863. attention_mask = self._get_feature_vector_attention_mask(
  864. extract_features.shape[1], attention_mask, add_adapter=False
  865. )
  866. hidden_states, extract_features = self.feature_projection(extract_features)
  867. hidden_states = self._mask_hidden_states(
  868. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  869. )
  870. encoder_outputs = self.encoder(
  871. hidden_states,
  872. attention_mask=attention_mask,
  873. output_attentions=output_attentions,
  874. output_hidden_states=output_hidden_states,
  875. return_dict=return_dict,
  876. )
  877. hidden_states = encoder_outputs[0]
  878. if self.adapter is not None:
  879. hidden_states = self.adapter(hidden_states)
  880. if not return_dict:
  881. return (hidden_states, extract_features) + encoder_outputs[1:]
  882. return WavLMBaseModelOutput(
  883. last_hidden_state=hidden_states,
  884. extract_features=extract_features,
  885. hidden_states=encoder_outputs.hidden_states,
  886. attentions=encoder_outputs.attentions,
  887. )
# Index at which the optional encoder outputs start in the tuple returned by `WavLMModel`
# (positions 0 and 1 hold the last hidden state and the extracted features).
_HIDDEN_STATES_START_POSITION = 2
  889. @auto_docstring(
  890. custom_intro="""
  891. WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
  892. """
  893. )
  894. class WavLMForCTC(WavLMPreTrainedModel):
  895. def __init__(self, config, target_lang: str | None = None):
  896. r"""
  897. target_lang (`str`, *optional*):
  898. Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
  899. adapter.<lang>.bin. Only relevant when using an instance of [`WavLMForCTC`] with adapters. Uses 'eng' by
  900. default.
  901. """
  902. super().__init__(config)
  903. self.wavlm = WavLMModel(config)
  904. self.dropout = nn.Dropout(config.final_dropout)
  905. self.target_lang = target_lang
  906. if config.vocab_size is None:
  907. raise ValueError(
  908. f"You are trying to instantiate {self.__class__} with a configuration that "
  909. "does not define the vocabulary size of the language model head. Please "
  910. "instantiate the model as follows: `WavLMForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  911. "or define `vocab_size` of your model's configuration."
  912. )
  913. output_hidden_size = (
  914. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  915. )
  916. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  917. # Initialize weights and apply final processing
  918. self.post_init()
  919. def tie_weights(self, **kwargs):
  920. """
  921. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  922. passing `target_lang=...` to `from_pretrained(...)`.
  923. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  924. """
  925. if get_torch_context_manager_or_global_device() == torch.device("meta"):
  926. return
  927. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  928. # correctly load adapter layers for WavLM so that we do not have to introduce a new API to
  929. # [`PreTrainedModel`]. While slightly hacky, WavLM never has to tie input and output embeddings, so that it is
  930. # ok to repurpose this function here.
  931. target_lang = self.target_lang
  932. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  933. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  934. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  935. logger.info("By default `target_lang` is set to 'eng'.")
  936. elif target_lang is not None:
  937. self.load_adapter(target_lang, force_load=True)
  938. def freeze_feature_encoder(self):
  939. """
  940. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  941. not be updated during training.
  942. """
  943. self.wavlm.feature_extractor._freeze_parameters()
  944. def freeze_base_model(self):
  945. """
  946. Calling this function will disable the gradient computation for the base model so that its parameters will not
  947. be updated during training. Only the classification head will be updated.
  948. """
  949. for param in self.wavlm.parameters():
  950. param.requires_grad = False
  951. @auto_docstring
  952. def forward(
  953. self,
  954. input_values: torch.Tensor | None,
  955. attention_mask: torch.Tensor | None = None,
  956. output_attentions: bool | None = None,
  957. output_hidden_states: bool | None = None,
  958. return_dict: bool | None = None,
  959. labels: torch.Tensor | None = None,
  960. **kwargs,
  961. ) -> tuple | CausalLMOutput:
  962. r"""
  963. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  964. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  965. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  966. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  967. config.vocab_size - 1]`.
  968. """
  969. return_dict = return_dict if return_dict is not None else self.config.return_dict
  970. if labels is not None and labels.max() >= self.config.vocab_size:
  971. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  972. outputs = self.wavlm(
  973. input_values,
  974. attention_mask=attention_mask,
  975. output_attentions=output_attentions,
  976. output_hidden_states=output_hidden_states,
  977. return_dict=return_dict,
  978. )
  979. hidden_states = outputs[0]
  980. hidden_states = self.dropout(hidden_states)
  981. logits = self.lm_head(hidden_states)
  982. loss = None
  983. if labels is not None:
  984. # retrieve loss input_lengths from attention_mask
  985. attention_mask = (
  986. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  987. )
  988. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  989. # assuming that padded tokens are filled with -100
  990. # when not being attended to
  991. labels_mask = labels >= 0
  992. target_lengths = labels_mask.sum(-1)
  993. flattened_targets = labels.masked_select(labels_mask)
  994. # ctc_loss doesn't support fp16
  995. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  996. with torch.backends.cudnn.flags(enabled=False):
  997. loss = nn.functional.ctc_loss(
  998. log_probs,
  999. flattened_targets,
  1000. input_lengths,
  1001. target_lengths,
  1002. blank=self.config.pad_token_id,
  1003. reduction=self.config.ctc_loss_reduction,
  1004. zero_infinity=self.config.ctc_zero_infinity,
  1005. )
  1006. if not return_dict:
  1007. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1008. return ((loss,) + output) if loss is not None else output
  1009. return CausalLMOutput(
  1010. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  1011. )
  1012. @auto_docstring(
  1013. custom_intro="""
  1014. WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  1015. SUPERB Keyword Spotting.
  1016. """
  1017. )
  1018. class WavLMForSequenceClassification(WavLMPreTrainedModel):
  1019. def __init__(self, config):
  1020. super().__init__(config)
  1021. if hasattr(config, "add_adapter") and config.add_adapter:
  1022. raise ValueError(
  1023. "Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)"
  1024. )
  1025. self.wavlm = WavLMModel(config)
  1026. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1027. if config.use_weighted_layer_sum:
  1028. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1029. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1030. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1031. # Initialize weights and apply final processing
  1032. self.post_init()
  1033. def freeze_feature_encoder(self):
  1034. """
  1035. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1036. not be updated during training.
  1037. """
  1038. self.wavlm.feature_extractor._freeze_parameters()
  1039. def freeze_base_model(self):
  1040. """
  1041. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1042. be updated during training. Only the classification head will be updated.
  1043. """
  1044. for param in self.wavlm.parameters():
  1045. param.requires_grad = False
  1046. @auto_docstring
  1047. def forward(
  1048. self,
  1049. input_values: torch.Tensor | None,
  1050. attention_mask: torch.Tensor | None = None,
  1051. output_attentions: bool | None = None,
  1052. output_hidden_states: bool | None = None,
  1053. return_dict: bool | None = None,
  1054. labels: torch.Tensor | None = None,
  1055. **kwargs,
  1056. ) -> tuple | SequenceClassifierOutput:
  1057. r"""
  1058. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1059. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1060. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1061. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1062. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1063. into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
  1064. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1065. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1066. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1067. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1068. """
  1069. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1070. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1071. outputs = self.wavlm(
  1072. input_values,
  1073. attention_mask=attention_mask,
  1074. output_attentions=output_attentions,
  1075. output_hidden_states=output_hidden_states,
  1076. return_dict=return_dict,
  1077. )
  1078. if self.config.use_weighted_layer_sum:
  1079. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1080. hidden_states = torch.stack(hidden_states, dim=1)
  1081. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1082. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1083. else:
  1084. hidden_states = outputs[0]
  1085. hidden_states = self.projector(hidden_states)
  1086. if attention_mask is None:
  1087. pooled_output = hidden_states.mean(dim=1)
  1088. else:
  1089. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  1090. expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  1091. hidden_states[~expand_padding_mask] = 0.0
  1092. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  1093. logits = self.classifier(pooled_output)
  1094. loss = None
  1095. if labels is not None:
  1096. loss_fct = CrossEntropyLoss()
  1097. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1098. if not return_dict:
  1099. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1100. return ((loss,) + output) if loss is not None else output
  1101. return SequenceClassifierOutput(
  1102. loss=loss,
  1103. logits=logits,
  1104. hidden_states=outputs.hidden_states,
  1105. attentions=outputs.attentions,
  1106. )
  1107. @auto_docstring
  1108. class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
  1109. def __init__(self, config):
  1110. super().__init__(config)
  1111. if hasattr(config, "add_adapter") and config.add_adapter:
  1112. raise ValueError(
  1113. "Audio frame classification does not support the use of WavLM adapters (config.add_adapter=True)"
  1114. )
  1115. self.wavlm = WavLMModel(config)
  1116. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1117. if config.use_weighted_layer_sum:
  1118. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1119. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1120. self.num_labels = config.num_labels
  1121. self.post_init()
  1122. def freeze_feature_encoder(self):
  1123. """
  1124. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1125. not be updated during training.
  1126. """
  1127. self.wavlm.feature_extractor._freeze_parameters()
  1128. def freeze_base_model(self):
  1129. """
  1130. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1131. be updated during training. Only the classification head will be updated.
  1132. """
  1133. for param in self.wavlm.parameters():
  1134. param.requires_grad = False
  1135. @auto_docstring
  1136. def forward(
  1137. self,
  1138. input_values: torch.Tensor | None,
  1139. attention_mask: torch.Tensor | None = None,
  1140. labels: torch.Tensor | None = None,
  1141. output_attentions: bool | None = None,
  1142. output_hidden_states: bool | None = None,
  1143. return_dict: bool | None = None,
  1144. **kwargs,
  1145. ) -> tuple | TokenClassifierOutput:
  1146. r"""
  1147. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1148. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1149. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1150. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1151. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1152. into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
  1153. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1154. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1155. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1156. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1157. """
  1158. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1159. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1160. outputs = self.wavlm(
  1161. input_values,
  1162. attention_mask=attention_mask,
  1163. output_attentions=output_attentions,
  1164. output_hidden_states=output_hidden_states,
  1165. return_dict=return_dict,
  1166. )
  1167. if self.config.use_weighted_layer_sum:
  1168. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1169. hidden_states = torch.stack(hidden_states, dim=1)
  1170. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1171. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1172. else:
  1173. hidden_states = outputs[0]
  1174. logits = self.classifier(hidden_states)
  1175. loss = None
  1176. if labels is not None:
  1177. loss_fct = CrossEntropyLoss()
  1178. loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
  1179. if not return_dict:
  1180. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1181. return output
  1182. return TokenClassifierOutput(
  1183. loss=loss,
  1184. logits=logits,
  1185. hidden_states=outputs.hidden_states,
  1186. attentions=outputs.attentions,
  1187. )
  1188. class AMSoftmaxLoss(nn.Module):
  1189. def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
  1190. super().__init__()
  1191. self.scale = scale
  1192. self.margin = margin
  1193. self.num_labels = num_labels
  1194. self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
  1195. self.loss = nn.CrossEntropyLoss()
  1196. def forward(self, hidden_states, labels):
  1197. labels = labels.flatten()
  1198. weight = nn.functional.normalize(self.weight, dim=0)
  1199. hidden_states = nn.functional.normalize(hidden_states, dim=1)
  1200. cos_theta = torch.mm(hidden_states, weight)
  1201. psi = cos_theta - self.margin
  1202. onehot = nn.functional.one_hot(labels, self.num_labels)
  1203. logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
  1204. loss = self.loss(logits, labels)
  1205. return loss
  1206. class TDNNLayer(nn.Module):
  1207. def __init__(self, config, layer_id=0):
  1208. super().__init__()
  1209. self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
  1210. self.out_conv_dim = config.tdnn_dim[layer_id]
  1211. self.kernel_size = config.tdnn_kernel[layer_id]
  1212. self.dilation = config.tdnn_dilation[layer_id]
  1213. self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
  1214. self.activation = nn.ReLU()
  1215. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1216. if is_peft_available():
  1217. from peft.tuners.lora import LoraLayer
  1218. if is_peft_available():
  1219. if isinstance(self.kernel, LoraLayer):
  1220. warnings.warn(
  1221. "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
  1222. "You should exclude TDNNLayer from LoRA's target modules.",
  1223. )
  1224. # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
  1225. hidden_states = hidden_states.transpose(1, 2)
  1226. weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
  1227. hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
  1228. hidden_states = hidden_states.transpose(1, 2)
  1229. hidden_states = self.activation(hidden_states)
  1230. return hidden_states
  1231. @auto_docstring(
  1232. custom_intro="""
  1233. WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
  1234. """
  1235. )
  1236. class WavLMForXVector(WavLMPreTrainedModel):
  1237. def __init__(self, config):
  1238. super().__init__(config)
  1239. self.wavlm = WavLMModel(config)
  1240. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1241. if config.use_weighted_layer_sum:
  1242. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1243. self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
  1244. tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
  1245. self.tdnn = nn.ModuleList(tdnn_layers)
  1246. self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
  1247. self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
  1248. self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
  1249. self.post_init()
  1250. def freeze_feature_encoder(self):
  1251. """
  1252. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1253. not be updated during training.
  1254. """
  1255. self.wavlm.feature_extractor._freeze_parameters()
  1256. def freeze_base_model(self):
  1257. """
  1258. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1259. be updated during training. Only the classification head will be updated.
  1260. """
  1261. for param in self.wavlm.parameters():
  1262. param.requires_grad = False
  1263. def _get_tdnn_output_lengths(self, input_lengths: torch.LongTensor | int):
  1264. """
  1265. Computes the output length of the TDNN layers
  1266. """
  1267. def _conv_out_length(input_length, kernel_size, stride):
  1268. # 1D convolutional layer output length formula taken
  1269. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  1270. return (input_length - kernel_size) // stride + 1
  1271. for kernel_size in self.config.tdnn_kernel:
  1272. input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
  1273. return input_lengths
  1274. @auto_docstring
  1275. def forward(
  1276. self,
  1277. input_values: torch.Tensor | None,
  1278. attention_mask: torch.Tensor | None = None,
  1279. output_attentions: bool | None = None,
  1280. output_hidden_states: bool | None = None,
  1281. return_dict: bool | None = None,
  1282. labels: torch.Tensor | None = None,
  1283. **kwargs,
  1284. ) -> tuple | XVectorOutput:
  1285. r"""
  1286. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1287. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1288. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1289. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1290. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1291. into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
  1292. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1293. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1294. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1295. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1296. """
  1297. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1298. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1299. outputs = self.wavlm(
  1300. input_values,
  1301. attention_mask=attention_mask,
  1302. output_attentions=output_attentions,
  1303. output_hidden_states=output_hidden_states,
  1304. return_dict=return_dict,
  1305. )
  1306. if self.config.use_weighted_layer_sum:
  1307. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1308. hidden_states = torch.stack(hidden_states, dim=1)
  1309. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1310. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1311. else:
  1312. hidden_states = outputs[0]
  1313. hidden_states = self.projector(hidden_states)
  1314. for tdnn_layer in self.tdnn:
  1315. hidden_states = tdnn_layer(hidden_states)
  1316. # Statistic Pooling
  1317. if attention_mask is None:
  1318. mean_features = hidden_states.mean(dim=1)
  1319. std_features = hidden_states.std(dim=1)
  1320. else:
  1321. feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
  1322. tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
  1323. mean_features = []
  1324. std_features = []
  1325. for i, length in enumerate(tdnn_output_lengths):
  1326. mean_features.append(hidden_states[i, :length].mean(dim=0))
  1327. std_features.append(hidden_states[i, :length].std(dim=0))
  1328. mean_features = torch.stack(mean_features)
  1329. std_features = torch.stack(std_features)
  1330. statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
  1331. output_embeddings = self.feature_extractor(statistic_pooling)
  1332. logits = self.classifier(output_embeddings)
  1333. loss = None
  1334. if labels is not None:
  1335. loss = self.objective(logits, labels)
  1336. if not return_dict:
  1337. output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
  1338. return ((loss,) + output) if loss is not None else output
  1339. return XVectorOutput(
  1340. loss=loss,
  1341. logits=logits,
  1342. embeddings=output_embeddings,
  1343. hidden_states=outputs.hidden_states,
  1344. attentions=outputs.attentions,
  1345. )
# Public API of this module: the names exported by a star-import of it.
__all__ = [
    "WavLMForAudioFrameClassification",
    "WavLMForCTC",
    "WavLMForSequenceClassification",
    "WavLMForXVector",
    "WavLMModel",
    "WavLMPreTrainedModel",
]