modeling_wav2vec2.py 90 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153
  1. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Wav2Vec2 model."""
  15. import math
  16. import warnings
  17. from collections.abc import Callable
  18. from dataclasses import dataclass
  19. import numpy as np
  20. import torch
  21. from safetensors.torch import load_file as safe_load_file
  22. from torch import nn
  23. from torch.nn import CrossEntropyLoss
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  27. from ...integrations.fsdp import is_fsdp_managed_module
  28. from ...masking_utils import create_bidirectional_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutput,
  33. CausalLMOutput,
  34. SequenceClassifierOutput,
  35. TokenClassifierOutput,
  36. Wav2Vec2BaseModelOutput,
  37. XVectorOutput,
  38. )
  39. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, get_torch_context_manager_or_global_device
  40. from ...processing_utils import Unpack
  41. from ...utils import (
  42. ModelOutput,
  43. TransformersKwargs,
  44. auto_docstring,
  45. cached_file,
  46. check_torch_load_is_safe,
  47. is_peft_available,
  48. logging,
  49. )
  50. from .configuration_wav2vec2 import Wav2Vec2Config
# File-name templates for adapter weights in the two supported serialization formats. The `{}` placeholder
# is filled in by code outside this excerpt (presumably with a target-language identifier - TODO confirm).
WAV2VEC2_ADAPTER_PT_FILE = "adapter.{}.bin"
WAV2VEC2_ADAPTER_SAFE_FILE = "adapter.{}.safetensors"

# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# NOTE(review): presumably the position at which hidden states start inside a model output tuple -
# confirm against the use sites, which are not part of this excerpt.
_HIDDEN_STATES_START_POSITION = 2
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`Wav2Vec2ForPreTraining`], with potential hidden states and attentions.
    """
)
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
        paper](https://huggingface.co/papers/2006.11477).
    projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
        projected quantized states.
    projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
        target vectors for contrastive loss.
    codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
        The perplexity of the codevector distribution, used to measure the diversity of the codebook.
    contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
        The contrastive loss (L_m) as stated in the [official paper](https://huggingface.co/papers/2006.11477).
    diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
        The diversity loss (L_d) as stated in the [official paper](https://huggingface.co/papers/2006.11477).
    """

    # Every field defaults to `None`, so each entry is optional when the output object is built.
    # NOTE(review): the declaration order differs from the docstring order; keep the declaration order
    # as-is, since consumers of `ModelOutput` may rely on it - confirm before reordering.
    loss: torch.FloatTensor | None = None
    projected_states: torch.FloatTensor | None = None
    projected_quantized_states: torch.FloatTensor | None = None
    codevector_perplexity: torch.FloatTensor | None = None
    # `hidden_states` and `attentions` are not described in the docstring above.
    # NOTE(review): presumably `auto_docstring` supplies their descriptions - confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    contrastive_loss: torch.FloatTensor | None = None
    diversity_loss: torch.FloatTensor | None = None
  87. def _compute_mask_indices(
  88. shape: tuple[int, int],
  89. mask_prob: float,
  90. mask_length: int,
  91. attention_mask: torch.LongTensor | None = None,
  92. min_masks: int = 0,
  93. ) -> np.ndarray:
  94. """
  95. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  96. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  97. CPU as part of the preprocessing during training.
  98. Args:
  99. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  100. the first element is the batch size and the second element is the length of the axis to span.
  101. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  102. independently generated mask spans of length `mask_length` is computed by
  103. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  104. actual percentage will be smaller.
  105. mask_length: size of the mask
  106. min_masks: minimum number of masked spans
  107. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  108. each batch dimension.
  109. """
  110. batch_size, sequence_length = shape
  111. if mask_length < 1:
  112. raise ValueError("`mask_length` has to be bigger than 0.")
  113. if mask_length > sequence_length:
  114. raise ValueError(
  115. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  116. f" and `sequence_length`: {sequence_length}`"
  117. )
  118. # epsilon is used for probabilistic rounding
  119. epsilon = np.random.rand(1).item()
  120. def compute_num_masked_span(input_length):
  121. """Given input length, compute how many spans should be masked"""
  122. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  123. num_masked_span = max(num_masked_span, min_masks)
  124. # make sure num masked span <= sequence_length
  125. if num_masked_span * mask_length > sequence_length:
  126. num_masked_span = sequence_length // mask_length
  127. # make sure num_masked span is also <= input_length - (mask_length - 1)
  128. if input_length - (mask_length - 1) < num_masked_span:
  129. num_masked_span = max(input_length - (mask_length - 1), 0)
  130. return num_masked_span
  131. # compute number of masked spans in batch
  132. input_lengths = (
  133. attention_mask.detach().sum(-1).tolist()
  134. if attention_mask is not None
  135. else [sequence_length for _ in range(batch_size)]
  136. )
  137. # SpecAugment mask to fill
  138. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  139. spec_aug_mask_idxs = []
  140. max_num_masked_span = compute_num_masked_span(sequence_length)
  141. if max_num_masked_span == 0:
  142. return spec_aug_mask
  143. for input_length in input_lengths:
  144. # compute num of masked spans for this input
  145. num_masked_span = compute_num_masked_span(input_length)
  146. # get random indices to mask
  147. spec_aug_mask_idx = np.random.choice(
  148. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  149. )
  150. # pick first sampled index that will serve as a dummy index to pad vector
  151. # to ensure same dimension for all batches due to probabilistic rounding
  152. # Picking first sample just pads those vectors twice.
  153. if len(spec_aug_mask_idx) == 0:
  154. # this case can only happen if `input_length` is strictly smaller then
  155. # `sequence_length` in which case the last token has to be a padding
  156. # token which we can use as a dummy mask id
  157. dummy_mask_idx = sequence_length - 1
  158. else:
  159. dummy_mask_idx = spec_aug_mask_idx[0]
  160. spec_aug_mask_idx = np.concatenate(
  161. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  162. )
  163. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  164. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  165. # expand masked indices to masked spans
  166. spec_aug_mask_idxs = np.broadcast_to(
  167. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  168. )
  169. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  170. # add offset to the starting indexes so that indexes now create a span
  171. offsets = np.arange(mask_length)[None, None, :]
  172. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  173. batch_size, max_num_masked_span * mask_length
  174. )
  175. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  176. # ensure that we cannot have indices larger than sequence_length
  177. if spec_aug_mask_idxs.max() > sequence_length - 1:
  178. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  179. # scatter indices to mask
  180. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  181. return spec_aug_mask
  182. def _sample_negative_indices(features_shape: tuple, num_negatives: int, mask_time_indices: np.ndarray | None = None):
  183. """
  184. Sample `num_negatives` vectors from feature vectors.
  185. """
  186. batch_size, sequence_length = features_shape
  187. # generate indices of the positive vectors themselves, repeat them `num_negatives` times
  188. sequence_length_range = np.arange(sequence_length)
  189. # get `num_negatives` random vector indices from the same utterance
  190. sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
  191. mask_time_indices = (
  192. mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
  193. )
  194. for batch_idx in range(batch_size):
  195. high = mask_time_indices[batch_idx].sum() - 1
  196. mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
  197. feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
  198. sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
  199. # avoid sampling the same positive vector, but keep the distribution uniform
  200. sampled_indices[sampled_indices >= feature_indices] += 1
  201. # remap to actual indices
  202. sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
  203. # correct for batch size
  204. sampled_negative_indices[batch_idx] += batch_idx * sequence_length
  205. return sampled_negative_indices
  206. class Wav2Vec2NoLayerNormConvLayer(GradientCheckpointingLayer):
  207. def __init__(self, config, layer_id=0):
  208. super().__init__()
  209. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  210. self.out_conv_dim = config.conv_dim[layer_id]
  211. self.conv = nn.Conv1d(
  212. self.in_conv_dim,
  213. self.out_conv_dim,
  214. kernel_size=config.conv_kernel[layer_id],
  215. stride=config.conv_stride[layer_id],
  216. bias=config.conv_bias,
  217. )
  218. self.activation = ACT2FN[config.feat_extract_activation]
  219. def forward(self, hidden_states):
  220. hidden_states = self.conv(hidden_states)
  221. hidden_states = self.activation(hidden_states)
  222. return hidden_states
  223. class Wav2Vec2LayerNormConvLayer(GradientCheckpointingLayer):
  224. def __init__(self, config, layer_id=0):
  225. super().__init__()
  226. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  227. self.out_conv_dim = config.conv_dim[layer_id]
  228. self.conv = nn.Conv1d(
  229. self.in_conv_dim,
  230. self.out_conv_dim,
  231. kernel_size=config.conv_kernel[layer_id],
  232. stride=config.conv_stride[layer_id],
  233. bias=config.conv_bias,
  234. )
  235. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  236. self.activation = ACT2FN[config.feat_extract_activation]
  237. def forward(self, hidden_states):
  238. hidden_states = self.conv(hidden_states)
  239. hidden_states = hidden_states.transpose(-2, -1)
  240. hidden_states = self.layer_norm(hidden_states)
  241. hidden_states = hidden_states.transpose(-2, -1)
  242. hidden_states = self.activation(hidden_states)
  243. return hidden_states
  244. class Wav2Vec2GroupNormConvLayer(GradientCheckpointingLayer):
  245. def __init__(self, config, layer_id=0):
  246. super().__init__()
  247. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  248. self.out_conv_dim = config.conv_dim[layer_id]
  249. self.conv = nn.Conv1d(
  250. self.in_conv_dim,
  251. self.out_conv_dim,
  252. kernel_size=config.conv_kernel[layer_id],
  253. stride=config.conv_stride[layer_id],
  254. bias=config.conv_bias,
  255. )
  256. self.activation = ACT2FN[config.feat_extract_activation]
  257. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  258. def forward(self, hidden_states):
  259. hidden_states = self.conv(hidden_states)
  260. hidden_states = self.layer_norm(hidden_states)
  261. hidden_states = self.activation(hidden_states)
  262. return hidden_states
  263. class Wav2Vec2PositionalConvEmbedding(nn.Module):
  264. def __init__(self, config):
  265. super().__init__()
  266. self.conv = nn.Conv1d(
  267. config.hidden_size,
  268. config.hidden_size,
  269. kernel_size=config.num_conv_pos_embeddings,
  270. padding=config.num_conv_pos_embeddings // 2,
  271. groups=config.num_conv_pos_embedding_groups,
  272. )
  273. weight_norm = nn.utils.weight_norm
  274. if hasattr(nn.utils.parametrizations, "weight_norm"):
  275. weight_norm = nn.utils.parametrizations.weight_norm
  276. if is_deepspeed_zero3_enabled():
  277. import deepspeed
  278. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  279. self.conv = weight_norm(self.conv, name="weight", dim=2)
  280. if hasattr(self.conv, "parametrizations"):
  281. weight_g = self.conv.parametrizations.weight.original0
  282. weight_v = self.conv.parametrizations.weight.original1
  283. else:
  284. weight_g = self.conv.weight_g
  285. weight_v = self.conv.weight_v
  286. deepspeed.zero.register_external_parameter(self, weight_v)
  287. deepspeed.zero.register_external_parameter(self, weight_g)
  288. else:
  289. self.conv = weight_norm(self.conv, name="weight", dim=2)
  290. self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
  291. self.activation = ACT2FN[config.feat_extract_activation]
  292. def forward(self, hidden_states):
  293. hidden_states = hidden_states.transpose(1, 2)
  294. hidden_states = self.conv(hidden_states)
  295. hidden_states = self.padding(hidden_states)
  296. hidden_states = self.activation(hidden_states)
  297. hidden_states = hidden_states.transpose(1, 2)
  298. return hidden_states
  299. class Wav2Vec2SamePadLayer(nn.Module):
  300. def __init__(self, num_conv_pos_embeddings):
  301. super().__init__()
  302. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  303. def forward(self, hidden_states):
  304. if self.num_pad_remove > 0:
  305. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  306. return hidden_states
  307. class Wav2Vec2FeatureEncoder(nn.Module):
  308. """Construct the features from raw audio waveform"""
  309. def __init__(self, config):
  310. super().__init__()
  311. if config.feat_extract_norm == "group":
  312. conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
  313. Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  314. ]
  315. elif config.feat_extract_norm == "layer":
  316. conv_layers = [
  317. Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
  318. ]
  319. else:
  320. raise ValueError(
  321. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  322. )
  323. self.conv_layers = nn.ModuleList(conv_layers)
  324. self.gradient_checkpointing = False
  325. self._requires_grad = True
  326. def _freeze_parameters(self):
  327. for param in self.parameters():
  328. param.requires_grad = False
  329. self._requires_grad = False
  330. def forward(self, input_values):
  331. hidden_states = input_values[:, None]
  332. # make sure hidden_states require grad for gradient_checkpointing
  333. if self._requires_grad and self.training:
  334. hidden_states.requires_grad = True
  335. for conv_layer in self.conv_layers:
  336. hidden_states = conv_layer(hidden_states)
  337. return hidden_states
  338. class Wav2Vec2FeatureProjection(nn.Module):
  339. def __init__(self, config):
  340. super().__init__()
  341. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  342. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  343. self.dropout = nn.Dropout(config.feat_proj_dropout)
  344. def forward(self, hidden_states):
  345. # non-projected hidden states are needed for quantization
  346. norm_hidden_states = self.layer_norm(hidden_states)
  347. hidden_states = self.projection(norm_hidden_states)
  348. hidden_states = self.dropout(hidden_states)
  349. return hidden_states, norm_hidden_states
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Scaled dot-product attention written with plain tensor operations.

    Args:
        module: Only `module.training` is read; it decides whether dropout is active.
        query: Query tensor. Assumed to be `(batch, heads, query_len, head_dim)` - TODO confirm with callers.
        key: Key tensor; its axes 2 and 3 are swapped before the product with `query`.
        value: Value tensor, multiplied by the attention weights.
        attention_mask: Optional mask that is *added* to the raw scores before the softmax.
        scaling: Factor applied to the raw scores. Defaults to `query.size(-1) ** -0.5`.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted but not used in this function.

    Returns:
        `(attn_output, attn_weights)`: the attended values with axes 1 and 2 swapped and made contiguous, and
        the attention weights after softmax and dropout.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        # additive mask: the mask values are summed onto the scores
        attn_weights = attn_weights + attention_mask

    # normalize over the last axis, then apply dropout only while the module is in training mode
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    # swap axes 1 and 2 and make the result contiguous before handing it back
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
  372. class Wav2Vec2Attention(nn.Module):
  373. """Multi-headed attention from 'Attention Is All You Need' paper"""
  374. def __init__(
  375. self,
  376. embed_dim: int,
  377. num_heads: int,
  378. dropout: float = 0.0,
  379. is_decoder: bool = False,
  380. bias: bool = True,
  381. is_causal: bool = False,
  382. config: Wav2Vec2Config | None = None,
  383. ):
  384. super().__init__()
  385. self.embed_dim = embed_dim
  386. self.num_heads = num_heads
  387. self.dropout = dropout
  388. self.head_dim = embed_dim // num_heads
  389. self.config = config
  390. if (self.head_dim * num_heads) != self.embed_dim:
  391. raise ValueError(
  392. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  393. f" and `num_heads`: {num_heads})."
  394. )
  395. self.scaling = self.head_dim**-0.5
  396. self.is_decoder = is_decoder
  397. self.is_causal = is_causal
  398. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  399. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  400. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  401. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  402. def forward(
  403. self,
  404. hidden_states: torch.Tensor,
  405. key_value_states: torch.Tensor | None = None,
  406. attention_mask: torch.Tensor | None = None,
  407. output_attentions: bool | None = False,
  408. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  409. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  410. **kwargs: Unpack[FlashAttentionKwargs],
  411. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  412. """Input shape: Batch x Time x Channel"""
  413. # if key_value_states are provided this layer is used as a cross-attention layer
  414. # for the decoder
  415. is_cross_attention = key_value_states is not None
  416. # determine input shapes
  417. input_shape = hidden_states.shape[:-1]
  418. hidden_shape = (*input_shape, -1, self.head_dim)
  419. # get query proj
  420. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  421. current_states = key_value_states if is_cross_attention else hidden_states
  422. kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
  423. key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2)
  424. value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2)
  425. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  426. self.config._attn_implementation, eager_attention_forward
  427. )
  428. attn_output, attn_weights = attention_interface(
  429. self,
  430. query_states,
  431. key_states,
  432. value_states,
  433. attention_mask,
  434. dropout=0.0 if not self.training else self.dropout,
  435. scaling=self.scaling,
  436. output_attentions=output_attentions,
  437. **kwargs,
  438. )
  439. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  440. attn_output = self.out_proj(attn_output)
  441. return attn_output, attn_weights, None
  442. class Wav2Vec2FeedForward(nn.Module):
  443. def __init__(self, config):
  444. super().__init__()
  445. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  446. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  447. if isinstance(config.hidden_act, str):
  448. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  449. else:
  450. self.intermediate_act_fn = config.hidden_act
  451. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  452. self.output_dropout = nn.Dropout(config.hidden_dropout)
  453. def forward(self, hidden_states):
  454. hidden_states = self.intermediate_dense(hidden_states)
  455. hidden_states = self.intermediate_act_fn(hidden_states)
  456. hidden_states = self.intermediate_dropout(hidden_states)
  457. hidden_states = self.output_dense(hidden_states)
  458. hidden_states = self.output_dropout(hidden_states)
  459. return hidden_states
  460. class Wav2Vec2EncoderLayer(GradientCheckpointingLayer):
  461. def __init__(self, config):
  462. super().__init__()
  463. self.attention = Wav2Vec2Attention(
  464. embed_dim=config.hidden_size,
  465. num_heads=config.num_attention_heads,
  466. dropout=config.attention_dropout,
  467. is_decoder=False,
  468. config=config,
  469. )
  470. self.dropout = nn.Dropout(config.hidden_dropout)
  471. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  472. self.feed_forward = Wav2Vec2FeedForward(config)
  473. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  474. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  475. attn_residual = hidden_states
  476. hidden_states, attn_weights, _ = self.attention(
  477. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  478. )
  479. hidden_states = self.dropout(hidden_states)
  480. hidden_states = attn_residual + hidden_states
  481. hidden_states = self.layer_norm(hidden_states)
  482. hidden_states = hidden_states + self.feed_forward(hidden_states)
  483. hidden_states = self.final_layer_norm(hidden_states)
  484. outputs = (hidden_states,)
  485. if output_attentions:
  486. outputs += (attn_weights,)
  487. return outputs
  488. class Wav2Vec2EncoderLayerStableLayerNorm(GradientCheckpointingLayer):
  489. def __init__(self, config):
  490. super().__init__()
  491. self.attention = Wav2Vec2Attention(
  492. embed_dim=config.hidden_size,
  493. num_heads=config.num_attention_heads,
  494. dropout=config.attention_dropout,
  495. is_decoder=False,
  496. config=config,
  497. )
  498. self.dropout = nn.Dropout(config.hidden_dropout)
  499. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  500. self.feed_forward = Wav2Vec2FeedForward(config)
  501. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  502. if getattr(config, "adapter_attn_dim", None) is not None:
  503. self.adapter_layer = Wav2Vec2AttnAdapterLayer(config)
  504. else:
  505. self.adapter_layer = None
  506. def forward(
  507. self,
  508. hidden_states: torch.Tensor,
  509. attention_mask: torch.Tensor | None = None,
  510. output_attentions: bool = False,
  511. ):
  512. attn_residual = hidden_states
  513. hidden_states = self.layer_norm(hidden_states)
  514. hidden_states, attn_weights, _ = self.attention(
  515. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  516. )
  517. hidden_states = self.dropout(hidden_states)
  518. hidden_states = attn_residual + hidden_states
  519. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  520. if self.adapter_layer is not None:
  521. hidden_states = hidden_states + self.adapter_layer(hidden_states)
  522. outputs = (hidden_states,)
  523. if output_attentions:
  524. outputs += (attn_weights,)
  525. return outputs
  526. class Wav2Vec2Encoder(nn.Module):
  527. def __init__(self, config):
  528. super().__init__()
  529. self.config = config
  530. self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
  531. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  532. self.dropout = nn.Dropout(config.hidden_dropout)
  533. self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  534. self.gradient_checkpointing = False
  535. def forward(
  536. self,
  537. hidden_states: torch.tensor,
  538. attention_mask: torch.Tensor | None = None,
  539. output_attentions: bool = False,
  540. output_hidden_states: bool = False,
  541. return_dict: bool = True,
  542. ):
  543. all_hidden_states = () if output_hidden_states else None
  544. all_self_attentions = () if output_attentions else None
  545. if attention_mask is not None:
  546. # make sure padded tokens output 0
  547. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  548. hidden_states[~expand_attention_mask] = 0
  549. attention_mask = create_bidirectional_mask(
  550. config=self.config,
  551. inputs_embeds=hidden_states,
  552. attention_mask=attention_mask,
  553. )
  554. position_embeddings = self.pos_conv_embed(hidden_states)
  555. hidden_states = hidden_states + position_embeddings.to(hidden_states.device)
  556. hidden_states = self.layer_norm(hidden_states)
  557. hidden_states = self.dropout(hidden_states)
  558. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  559. for layer in self.layers:
  560. if output_hidden_states:
  561. all_hidden_states = all_hidden_states + (hidden_states,)
  562. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  563. dropout_probability = torch.rand([])
  564. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  565. if not skip_the_layer or synced_gpus:
  566. # under fsdp or deepspeed zero3 all gpus must run in sync
  567. layer_outputs = layer(
  568. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  569. )
  570. hidden_states = layer_outputs[0]
  571. if skip_the_layer:
  572. layer_outputs = (None, None)
  573. if output_attentions:
  574. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  575. if output_hidden_states:
  576. all_hidden_states = all_hidden_states + (hidden_states,)
  577. if not return_dict:
  578. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  579. return BaseModelOutput(
  580. last_hidden_state=hidden_states,
  581. hidden_states=all_hidden_states,
  582. attentions=all_self_attentions,
  583. )
  584. class Wav2Vec2EncoderStableLayerNorm(nn.Module):
  585. def __init__(self, config):
  586. super().__init__()
  587. self.config = config
  588. self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
  589. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  590. self.dropout = nn.Dropout(config.hidden_dropout)
  591. self.layers = nn.ModuleList(
  592. [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
  593. )
  594. self.gradient_checkpointing = False
  595. def forward(
  596. self,
  597. hidden_states,
  598. attention_mask=None,
  599. output_attentions=False,
  600. output_hidden_states=False,
  601. return_dict=True,
  602. ):
  603. all_hidden_states = () if output_hidden_states else None
  604. all_self_attentions = () if output_attentions else None
  605. if attention_mask is not None:
  606. # make sure padded tokens output 0
  607. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  608. hidden_states[~expand_attention_mask] = 0
  609. attention_mask = create_bidirectional_mask(
  610. config=self.config,
  611. inputs_embeds=hidden_states,
  612. attention_mask=attention_mask,
  613. )
  614. position_embeddings = self.pos_conv_embed(hidden_states)
  615. hidden_states = hidden_states + position_embeddings
  616. hidden_states = self.dropout(hidden_states)
  617. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  618. for layer in self.layers:
  619. if output_hidden_states:
  620. all_hidden_states = all_hidden_states + (hidden_states,)
  621. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  622. dropout_probability = torch.rand([])
  623. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  624. if not skip_the_layer or synced_gpus:
  625. # under fsdp or deepspeed zero3 all gpus must run in sync
  626. # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
  627. layer_outputs = layer(
  628. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  629. )
  630. hidden_states = layer_outputs[0]
  631. if skip_the_layer:
  632. layer_outputs = (None, None)
  633. if output_attentions:
  634. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  635. hidden_states = self.layer_norm(hidden_states)
  636. if output_hidden_states:
  637. all_hidden_states = all_hidden_states + (hidden_states,)
  638. if not return_dict:
  639. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  640. return BaseModelOutput(
  641. last_hidden_state=hidden_states,
  642. hidden_states=all_hidden_states,
  643. attentions=all_self_attentions,
  644. )
  645. class Wav2Vec2GumbelVectorQuantizer(nn.Module):
  646. """
  647. Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
  648. GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
  649. """
  650. def __init__(self, config):
  651. super().__init__()
  652. self.num_groups = config.num_codevector_groups
  653. self.num_vars = config.num_codevectors_per_group
  654. if config.codevector_dim % self.num_groups != 0:
  655. raise ValueError(
  656. f"`config.codevector_dim {config.codevector_dim} must be divisible "
  657. f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
  658. )
  659. # storage for codebook variables (codewords)
  660. self.codevectors = nn.Parameter(
  661. torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
  662. )
  663. self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
  664. # can be decayed for training
  665. self.temperature = 2
  666. @staticmethod
  667. def _compute_perplexity(probs, mask=None):
  668. if mask is not None:
  669. mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
  670. probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
  671. marginal_probs = probs.sum(dim=0) / mask.sum()
  672. else:
  673. marginal_probs = probs.mean(dim=0)
  674. perplexity = torch.exp(-torch.sum(torch.xlogy(marginal_probs, marginal_probs), dim=-1)).sum()
  675. return perplexity
  676. def forward(self, hidden_states, mask_time_indices=None):
  677. batch_size, sequence_length, hidden_size = hidden_states.shape
  678. # project to codevector dim
  679. hidden_states = self.weight_proj(hidden_states)
  680. hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
  681. if self.training:
  682. # sample code vector probs via gumbel in differentiateable way
  683. codevector_probs = nn.functional.gumbel_softmax(
  684. hidden_states.float(), tau=self.temperature, hard=True
  685. ).type_as(hidden_states)
  686. # compute perplexity
  687. codevector_soft_dist = torch.softmax(
  688. hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
  689. )
  690. perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
  691. else:
  692. # take argmax in non-differentiable way
  693. # comptute hard codevector distribution (one hot)
  694. codevector_idx = hidden_states.argmax(dim=-1)
  695. codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
  696. -1, codevector_idx.view(-1, 1), 1.0
  697. )
  698. codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
  699. perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
  700. codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
  701. # use probs to retrieve codevectors
  702. codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
  703. codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
  704. codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
  705. return codevectors, perplexity
  706. class Wav2Vec2Adapter(nn.Module):
  707. def __init__(self, config):
  708. super().__init__()
  709. # feature dim might need to be down-projected
  710. if config.output_hidden_size != config.hidden_size:
  711. self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
  712. self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
  713. else:
  714. self.proj = self.proj_layer_norm = None
  715. self.layers = nn.ModuleList(Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers))
  716. self.layerdrop = config.layerdrop
  717. def forward(self, hidden_states):
  718. # down project hidden_states if necessary
  719. if self.proj is not None and self.proj_layer_norm is not None:
  720. hidden_states = self.proj(hidden_states)
  721. hidden_states = self.proj_layer_norm(hidden_states)
  722. hidden_states = hidden_states.transpose(1, 2)
  723. for layer in self.layers:
  724. layerdrop_prob = np.random.random()
  725. if not self.training or (layerdrop_prob > self.layerdrop):
  726. hidden_states = layer(hidden_states)
  727. hidden_states = hidden_states.transpose(1, 2)
  728. return hidden_states
  729. class Wav2Vec2AdapterLayer(nn.Module):
  730. def __init__(self, config):
  731. super().__init__()
  732. self.conv = nn.Conv1d(
  733. config.output_hidden_size,
  734. 2 * config.output_hidden_size,
  735. config.adapter_kernel_size,
  736. stride=config.adapter_stride,
  737. padding=1,
  738. )
  739. def forward(self, hidden_states):
  740. hidden_states = self.conv(hidden_states)
  741. hidden_states = nn.functional.glu(hidden_states, dim=1)
  742. return hidden_states
  743. class Wav2Vec2AttnAdapterLayer(nn.Module):
  744. def __init__(self, config):
  745. """
  746. Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
  747. up training throughput.
  748. """
  749. super().__init__()
  750. self.input_dim = config.adapter_attn_dim
  751. self.hidden_dim = config.hidden_size
  752. self.norm = nn.LayerNorm(self.hidden_dim)
  753. self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
  754. self.act_fn = nn.ReLU()
  755. self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
  756. def forward(self, hidden_states: torch.FloatTensor):
  757. hidden_states = self.norm(hidden_states)
  758. hidden_states = self.linear_1(hidden_states)
  759. hidden_states = self.act_fn(hidden_states)
  760. hidden_states = self.linear_2(hidden_states)
  761. return hidden_states
  762. @auto_docstring
  763. class Wav2Vec2PreTrainedModel(PreTrainedModel):
  764. config: Wav2Vec2Config
  765. base_model_prefix = "wav2vec2"
  766. main_input_name = "input_values"
  767. input_modalities = "audio"
  768. supports_gradient_checkpointing = True
  769. _supports_flash_attn = True
  770. _supports_sdpa = True
  771. _supports_flex_attn = True
  772. @torch.no_grad()
  773. def _init_weights(self, module):
  774. """Initialize the weights"""
  775. # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
  776. if isinstance(module, Wav2Vec2ForPreTraining):
  777. module.project_hid.reset_parameters()
  778. module.project_q.reset_parameters()
  779. # gumbel softmax requires special init
  780. elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
  781. init.normal_(module.weight_proj.weight, mean=0.0, std=1)
  782. init.zeros_(module.weight_proj.bias)
  783. init.uniform_(module.codevectors)
  784. elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
  785. init.normal_(
  786. module.conv.weight,
  787. mean=0,
  788. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  789. )
  790. init.constant_(module.conv.bias, 0)
  791. elif isinstance(module, Wav2Vec2FeatureProjection):
  792. k = math.sqrt(1 / module.projection.in_features)
  793. init.uniform_(module.projection.weight, a=-k, b=k)
  794. init.uniform_(module.projection.bias, a=-k, b=k)
  795. elif isinstance(module, nn.Linear):
  796. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  797. if module.bias is not None:
  798. init.zeros_(module.bias)
  799. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  800. init.zeros_(module.bias)
  801. init.ones_(module.weight)
  802. elif isinstance(module, nn.Conv1d):
  803. init.kaiming_normal_(module.weight)
  804. if module.bias is not None:
  805. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  806. init.uniform_(module.bias, a=-k, b=k)
  807. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int, add_adapter: bool | None = None):
  808. """
  809. Computes the output length of the convolutional layers
  810. """
  811. add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
  812. def _conv_out_length(input_length, kernel_size, stride):
  813. # 1D convolutional layer output length formula taken
  814. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  815. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  816. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  817. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  818. if add_adapter:
  819. for _ in range(self.config.num_adapter_layers):
  820. input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
  821. return input_lengths
  822. def _get_feature_vector_attention_mask(
  823. self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
  824. ):
  825. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  826. # on inference mode.
  827. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  828. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
  829. output_lengths = output_lengths.to(torch.long)
  830. batch_size = attention_mask.shape[0]
  831. attention_mask = torch.zeros(
  832. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  833. )
  834. # these two operations makes sure that all values before the output lengths idxs are attended to
  835. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  836. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  837. return attention_mask
  838. def _get_adapters(self):
  839. if self.config.adapter_attn_dim is None:
  840. raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")
  841. adapter_weights = {}
  842. for name, module in self.named_modules():
  843. if isinstance(module, Wav2Vec2AttnAdapterLayer):
  844. for param_name, param in module.named_parameters():
  845. adapter_weights[".".join([name, param_name])] = param
  846. if isinstance(self, Wav2Vec2ForCTC):
  847. for name, param in self.lm_head.named_parameters():
  848. adapter_weights[".".join(["lm_head", name])] = param
  849. return adapter_weights
  850. def init_adapter_layers(self):
  851. """
  852. (Re-)initialize attention adapter layers and lm head for adapter-only fine-tuning
  853. """
  854. # init attention adapters
  855. for module in self.modules():
  856. if isinstance(module, Wav2Vec2AttnAdapterLayer):
  857. self._init_weights(module)
  858. # init lm head
  859. if isinstance(self, Wav2Vec2ForCTC):
  860. self._init_weights(self.lm_head)
  861. def load_adapter(self, target_lang: str, force_load=True, **kwargs):
  862. r"""
  863. Load a language adapter model from a pre-trained adapter model.
  864. Parameters:
  865. target_lang (`str`):
  866. Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
  867. adapter.<lang>.safetensors or adapter.<lang>.bin
  868. force_load (`bool`, defaults to `True`):
  869. Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
  870. cache_dir (`Union[str, os.PathLike]`, *optional*):
  871. Path to a directory in which a downloaded pretrained model configuration should be cached if the
  872. standard cache should not be used.
  873. force_download (`bool`, *optional*, defaults to `False`):
  874. Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  875. cached versions if they exist.
  876. proxies (`dict[str, str]`, *optional*):
  877. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  878. 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
  879. local_files_only(`bool`, *optional*, defaults to `False`):
  880. Whether or not to only look at local files (i.e., do not try to download the model).
  881. token (`str` or `bool`, *optional*):
  882. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  883. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  884. revision (`str`, *optional*, defaults to `"main"`):
  885. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  886. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  887. identifier allowed by git.
  888. <Tip>
  889. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  890. </Tip>
  891. mirror (`str`, *optional*):
  892. Mirror source to accelerate downloads in China. If you are from China and have an accessibility
  893. problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
  894. Please refer to the mirror site for more information.
  895. <Tip>
  896. Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
  897. use this method in a firewalled environment.
  898. </Tip>
  899. Examples:
  900. ```python
  901. >>> from transformers import Wav2Vec2ForCTC, AutoProcessor
  902. >>> ckpt = "facebook/mms-1b-all"
  903. >>> processor = AutoProcessor.from_pretrained(ckpt)
  904. >>> model = Wav2Vec2ForCTC.from_pretrained(ckpt, target_lang="eng")
  905. >>> # set specific language
  906. >>> processor.tokenizer.set_target_lang("spa")
  907. >>> model.load_adapter("spa")
  908. ```
  909. """
  910. if self.config.adapter_attn_dim is None:
  911. raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")
  912. if target_lang == self.target_lang and not force_load:
  913. logger.warning(f"Adapter weights are already set to {target_lang}.")
  914. return
  915. cache_dir = kwargs.pop("cache_dir", None)
  916. force_download = kwargs.pop("force_download", False)
  917. proxies = kwargs.pop("proxies", None)
  918. local_files_only = kwargs.pop("local_files_only", False)
  919. token = kwargs.pop("token", None)
  920. revision = kwargs.pop("revision", None)
  921. use_safetensors = kwargs.pop("use_safetensors", None)
  922. model_path_or_id = self.config._name_or_path
  923. state_dict = None
  924. # 1. Let's first try loading a safetensors adapter weight
  925. if use_safetensors is not False:
  926. filepath = WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang)
  927. try:
  928. weight_path = cached_file(
  929. model_path_or_id,
  930. filename=filepath,
  931. force_download=force_download,
  932. proxies=proxies,
  933. local_files_only=local_files_only,
  934. token=token,
  935. revision=revision,
  936. cache_dir=cache_dir,
  937. )
  938. state_dict = safe_load_file(weight_path)
  939. except OSError:
  940. if use_safetensors:
  941. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
  942. # to the original exception.
  943. raise
  944. except Exception:
  945. # For any other exception, we throw a generic error.
  946. if use_safetensors:
  947. raise OSError(
  948. f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
  949. " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  950. f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
  951. f" directory containing a file named {filepath}."
  952. )
  953. # 2. If this didn't work let's try loading a PyTorch adapter weight
  954. if state_dict is None:
  955. filepath = WAV2VEC2_ADAPTER_PT_FILE.format(target_lang)
  956. try:
  957. weight_path = cached_file(
  958. model_path_or_id,
  959. filename=filepath,
  960. force_download=force_download,
  961. proxies=proxies,
  962. local_files_only=local_files_only,
  963. token=token,
  964. revision=revision,
  965. cache_dir=cache_dir,
  966. )
  967. check_torch_load_is_safe()
  968. state_dict = torch.load(
  969. weight_path,
  970. map_location="cpu",
  971. weights_only=True,
  972. )
  973. except OSError:
  974. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
  975. # to the original exception.
  976. raise
  977. except ValueError:
  978. raise
  979. except Exception:
  980. # For any other exception, we throw a generic error.
  981. raise OSError(
  982. f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
  983. " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  984. f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
  985. f" directory containing a file named {filepath}."
  986. )
  987. adapter_weights = self._get_adapters()
  988. unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
  989. missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())
  990. if len(unexpected_keys) > 0:
  991. raise ValueError(f"The adapter weights {weight_path} has unexpected keys: {', '.join(unexpected_keys)}.")
  992. elif len(missing_keys) > 0:
  993. raise ValueError(f"The adapter weights {weight_path} has missing keys: {', '.join(missing_keys)}.")
  994. # make sure now vocab size is correct
  995. target_vocab_size = state_dict["lm_head.weight"].shape[0]
  996. if target_vocab_size != self.config.vocab_size:
  997. self.lm_head = nn.Linear(
  998. self.config.output_hidden_size, target_vocab_size, device=self.device, dtype=self.dtype
  999. )
  1000. self.config.vocab_size = target_vocab_size
  1001. # make sure that adapter weights are put in exactly the same precision and device placement and overwritten adapter weights
  1002. state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
  1003. self.load_state_dict(state_dict, strict=False)
  1004. # set target language correctly
  1005. self.target_lang = target_lang
  1006. @auto_docstring
  1007. class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
  1008. def __init__(self, config: Wav2Vec2Config):
  1009. super().__init__(config)
  1010. self.config = config
  1011. self.feature_extractor = Wav2Vec2FeatureEncoder(config)
  1012. self.feature_projection = Wav2Vec2FeatureProjection(config)
  1013. # model only needs masking vector if mask prob is > 0.0
  1014. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  1015. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  1016. if config.do_stable_layer_norm:
  1017. self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
  1018. else:
  1019. self.encoder = Wav2Vec2Encoder(config)
  1020. self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None
  1021. # Initialize weights and apply final processing
  1022. self.post_init()
  1023. def freeze_feature_encoder(self):
  1024. """
  1025. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1026. not be updated during training.
  1027. """
  1028. self.feature_extractor._freeze_parameters()
  1029. def _mask_hidden_states(
  1030. self,
  1031. hidden_states: torch.FloatTensor,
  1032. mask_time_indices: torch.FloatTensor | None = None,
  1033. attention_mask: torch.LongTensor | None = None,
  1034. ):
  1035. """
  1036. Masks extracted features along time axis and/or along feature axis according to
  1037. [SpecAugment](https://huggingface.co/papers/1904.08779).
  1038. """
  1039. # `config.apply_spec_augment` can set masking to False
  1040. if not getattr(self.config, "apply_spec_augment", True):
  1041. return hidden_states
  1042. # generate indices & apply SpecAugment along time axis
  1043. batch_size, sequence_length, hidden_size = hidden_states.size()
  1044. if mask_time_indices is not None:
  1045. # apply SpecAugment along time axis with given mask_time_indices
  1046. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  1047. elif self.config.mask_time_prob > 0 and self.training:
  1048. mask_time_indices = _compute_mask_indices(
  1049. (batch_size, sequence_length),
  1050. mask_prob=self.config.mask_time_prob,
  1051. mask_length=self.config.mask_time_length,
  1052. attention_mask=attention_mask,
  1053. min_masks=self.config.mask_time_min_masks,
  1054. )
  1055. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  1056. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  1057. if self.config.mask_feature_prob > 0 and self.training:
  1058. # generate indices & apply SpecAugment along feature axis
  1059. mask_feature_indices = _compute_mask_indices(
  1060. (batch_size, hidden_size),
  1061. mask_prob=self.config.mask_feature_prob,
  1062. mask_length=self.config.mask_feature_length,
  1063. min_masks=self.config.mask_feature_min_masks,
  1064. )
  1065. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  1066. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  1067. hidden_states[mask_feature_indices] = 0
  1068. return hidden_states
  1069. @auto_docstring
  1070. def forward(
  1071. self,
  1072. input_values: torch.Tensor | None,
  1073. attention_mask: torch.Tensor | None = None,
  1074. mask_time_indices: torch.FloatTensor | None = None,
  1075. output_attentions: bool | None = None,
  1076. output_hidden_states: bool | None = None,
  1077. return_dict: bool | None = None,
  1078. **kwargs,
  1079. ) -> tuple | Wav2Vec2BaseModelOutput:
  1080. r"""
  1081. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1082. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  1083. masked extracted features in *config.proj_codevector_dim* space.
  1084. """
  1085. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1086. output_hidden_states = (
  1087. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1088. )
  1089. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1090. extract_features = self.feature_extractor(input_values)
  1091. extract_features = extract_features.transpose(1, 2)
  1092. if attention_mask is not None:
  1093. # compute reduced attention_mask corresponding to feature vectors
  1094. attention_mask = self._get_feature_vector_attention_mask(
  1095. extract_features.shape[1], attention_mask, add_adapter=False
  1096. )
  1097. hidden_states, extract_features = self.feature_projection(extract_features)
  1098. hidden_states = self._mask_hidden_states(
  1099. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  1100. )
  1101. encoder_outputs = self.encoder(
  1102. hidden_states,
  1103. attention_mask=attention_mask,
  1104. output_attentions=output_attentions,
  1105. output_hidden_states=output_hidden_states,
  1106. return_dict=return_dict,
  1107. )
  1108. hidden_states = encoder_outputs[0]
  1109. if self.adapter is not None:
  1110. hidden_states = self.adapter(hidden_states)
  1111. if not return_dict:
  1112. return (hidden_states, extract_features) + encoder_outputs[1:]
  1113. return Wav2Vec2BaseModelOutput(
  1114. last_hidden_state=hidden_states,
  1115. extract_features=extract_features,
  1116. hidden_states=encoder_outputs.hidden_states,
  1117. attentions=encoder_outputs.attentions,
  1118. )
  1119. @auto_docstring(
  1120. custom_intro="""
  1121. Wav2Vec2 Model with a quantizer and `VQ` head on top.
  1122. """
  1123. )
  1124. class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
  1125. def __init__(self, config: Wav2Vec2Config):
  1126. super().__init__(config)
  1127. self.wav2vec2 = Wav2Vec2Model(config)
  1128. self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
  1129. self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
  1130. self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
  1131. self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
  1132. # Initialize weights and apply final processing
  1133. self.post_init()
  1134. def set_gumbel_temperature(self, temperature: int):
  1135. """
  1136. Set the Gumbel softmax temperature to a given value. Only necessary for training
  1137. """
  1138. self.quantizer.temperature = temperature
  1139. def freeze_feature_encoder(self):
  1140. """
  1141. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1142. not be updated during training.
  1143. """
  1144. self.wav2vec2.feature_extractor._freeze_parameters()
  1145. @staticmethod
  1146. def compute_contrastive_logits(
  1147. target_features: torch.FloatTensor,
  1148. negative_features: torch.FloatTensor,
  1149. predicted_features: torch.FloatTensor,
  1150. temperature: float = 0.1,
  1151. ):
  1152. """
  1153. Compute logits for contrastive loss based using cosine similarity as the distance measure between
  1154. `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
  1155. """
  1156. target_features = torch.cat([target_features, negative_features], dim=0)
  1157. logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
  1158. target_features
  1159. )
  1160. # apply temperature
  1161. logits = logits / temperature
  1162. return logits
  1163. @auto_docstring
  1164. def forward(
  1165. self,
  1166. input_values: torch.Tensor | None,
  1167. attention_mask: torch.Tensor | None = None,
  1168. mask_time_indices: torch.BoolTensor | None = None,
  1169. sampled_negative_indices: torch.BoolTensor | None = None,
  1170. output_attentions: bool | None = None,
  1171. output_hidden_states: bool | None = None,
  1172. return_dict: bool | None = None,
  1173. **kwargs,
  1174. ) -> tuple | Wav2Vec2ForPreTrainingOutput:
  1175. r"""
  1176. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1177. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  1178. masked extracted features in *config.proj_codevector_dim* space.
  1179. sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
  1180. Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
  1181. Required input for pre-training.
  1182. Example:
  1183. ```python
  1184. >>> import torch
  1185. >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
  1186. >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
  1187. >>> from datasets import load_dataset
  1188. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
  1189. >>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
  1190. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  1191. >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
  1192. >>> # compute masked indices
  1193. >>> batch_size, raw_sequence_length = input_values.shape
  1194. >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
  1195. >>> mask_time_indices = _compute_mask_indices(
  1196. ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
  1197. ... )
  1198. >>> sampled_negative_indices = _sample_negative_indices(
  1199. ... features_shape=(batch_size, sequence_length),
  1200. ... num_negatives=model.config.num_negatives,
  1201. ... mask_time_indices=mask_time_indices,
  1202. ... )
  1203. >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
  1204. >>> sampled_negative_indices = torch.tensor(
  1205. ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
  1206. ... )
  1207. >>> with torch.no_grad():
  1208. ... outputs = model(input_values, mask_time_indices=mask_time_indices)
  1209. >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
  1210. >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
  1211. >>> # show that cosine similarity is much higher than random
  1212. >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
  1213. tensor(True)
  1214. >>> # for contrastive loss training model should be put into train mode
  1215. >>> model = model.train()
  1216. >>> loss = model(
  1217. ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
  1218. ... ).loss
  1219. ```"""
  1220. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1221. if mask_time_indices is not None:
  1222. mask_time_indices = mask_time_indices.to(torch.bool)
  1223. outputs = self.wav2vec2(
  1224. input_values,
  1225. attention_mask=attention_mask,
  1226. output_attentions=output_attentions,
  1227. output_hidden_states=output_hidden_states,
  1228. mask_time_indices=mask_time_indices,
  1229. return_dict=return_dict,
  1230. )
  1231. # 1. project all transformed features (including masked) to final vq dim
  1232. transformer_features = self.project_hid(outputs[0])
  1233. # 2. quantize all (unmasked) extracted features and project to final vq dim
  1234. extract_features = self.dropout_features(outputs[1])
  1235. if attention_mask is not None:
  1236. # compute reduced attention_mask corresponding to feature vectors
  1237. attention_mask = self._get_feature_vector_attention_mask(
  1238. extract_features.shape[1], attention_mask, add_adapter=False
  1239. )
  1240. quantized_features, codevector_perplexity = self.quantizer(
  1241. extract_features, mask_time_indices=mask_time_indices
  1242. )
  1243. quantized_features = quantized_features.to(self.project_q.weight.dtype)
  1244. quantized_features = self.project_q(quantized_features)
  1245. loss = contrastive_loss = diversity_loss = None
  1246. if sampled_negative_indices is not None:
  1247. batch_size, sequence_length, hidden_size = quantized_features.shape
  1248. # for training, we sample negatives
  1249. # 3. sample K negatives (distractors) quantized states for contrastive loss
  1250. # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
  1251. # sample negative quantized vectors BTC => (BxT)C
  1252. negative_quantized_features = quantized_features.view(-1, hidden_size)[
  1253. sampled_negative_indices.long().view(-1)
  1254. ]
  1255. negative_quantized_features = negative_quantized_features.view(
  1256. batch_size, sequence_length, -1, hidden_size
  1257. ).permute(2, 0, 1, 3)
  1258. # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
  1259. # of equation (3) in https://huggingface.co/papers/2006.11477
  1260. logits = self.compute_contrastive_logits(
  1261. quantized_features[None, :],
  1262. negative_quantized_features,
  1263. transformer_features,
  1264. self.config.contrastive_logits_temperature,
  1265. )
  1266. # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
  1267. # its cosine similarity will be masked
  1268. neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
  1269. if neg_is_pos.any():
  1270. logits[1:][neg_is_pos] = float("-inf")
  1271. # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
  1272. # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
  1273. logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
  1274. target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
  1275. contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
  1276. # 7. compute diversity loss: \mathbf{L}_d
  1277. num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
  1278. diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
  1279. # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
  1280. loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
  1281. if not return_dict:
  1282. if loss is not None:
  1283. return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  1284. return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  1285. return Wav2Vec2ForPreTrainingOutput(
  1286. loss=loss,
  1287. projected_states=transformer_features,
  1288. projected_quantized_states=quantized_features,
  1289. codevector_perplexity=codevector_perplexity,
  1290. hidden_states=outputs.hidden_states,
  1291. attentions=outputs.attentions,
  1292. contrastive_loss=contrastive_loss,
  1293. diversity_loss=diversity_loss,
  1294. )
  1295. @auto_docstring(
  1296. custom_intro="""
  1297. Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
  1298. """
  1299. )
  1300. class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
  1301. def __init__(self, config, target_lang: str | None = None):
  1302. r"""
  1303. target_lang (`str`, *optional*):
  1304. Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
  1305. adapter.<lang>.bin. Only relevant when using an instance of [`Wav2Vec2ForCTC`] with adapters. Uses 'eng' by
  1306. default.
  1307. """
  1308. super().__init__(config)
  1309. self.wav2vec2 = Wav2Vec2Model(config)
  1310. self.dropout = nn.Dropout(config.final_dropout)
  1311. self.target_lang = target_lang
  1312. if config.vocab_size is None:
  1313. raise ValueError(
  1314. f"You are trying to instantiate {self.__class__} with a configuration that "
  1315. "does not define the vocabulary size of the language model head. Please "
  1316. "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  1317. "or define `vocab_size` of your model's configuration."
  1318. )
  1319. output_hidden_size = (
  1320. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  1321. )
  1322. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  1323. # Initialize weights and apply final processing
  1324. self.post_init()
  1325. def tie_weights(self, **kwargs):
  1326. """
  1327. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  1328. passing `target_lang=...` to `from_pretrained(...)`.
  1329. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  1330. """
  1331. if get_torch_context_manager_or_global_device() == torch.device("meta"):
  1332. return
  1333. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  1334. # correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to
  1335. # [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is
  1336. # ok to repurpose this function here.
  1337. target_lang = self.target_lang
  1338. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  1339. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  1340. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  1341. logger.info("By default `target_lang` is set to 'eng'.")
  1342. elif target_lang is not None:
  1343. self.load_adapter(target_lang, force_load=True)
  1344. def freeze_feature_encoder(self):
  1345. """
  1346. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1347. not be updated during training.
  1348. """
  1349. self.wav2vec2.feature_extractor._freeze_parameters()
  1350. def freeze_base_model(self):
  1351. """
  1352. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1353. be updated during training. Only the classification head will be updated.
  1354. """
  1355. for param in self.wav2vec2.parameters():
  1356. param.requires_grad = False
  1357. @auto_docstring
  1358. def forward(
  1359. self,
  1360. input_values: torch.Tensor | None,
  1361. attention_mask: torch.Tensor | None = None,
  1362. output_attentions: bool | None = None,
  1363. output_hidden_states: bool | None = None,
  1364. return_dict: bool | None = None,
  1365. labels: torch.Tensor | None = None,
  1366. **kwargs,
  1367. ) -> tuple | CausalLMOutput:
  1368. r"""
  1369. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  1370. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  1371. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  1372. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1373. config.vocab_size - 1]`.
  1374. """
  1375. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1376. if labels is not None and labels.max() >= self.config.vocab_size:
  1377. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  1378. outputs = self.wav2vec2(
  1379. input_values,
  1380. attention_mask=attention_mask,
  1381. output_attentions=output_attentions,
  1382. output_hidden_states=output_hidden_states,
  1383. return_dict=return_dict,
  1384. )
  1385. hidden_states = outputs[0]
  1386. hidden_states = self.dropout(hidden_states)
  1387. logits = self.lm_head(hidden_states)
  1388. loss = None
  1389. if labels is not None:
  1390. # retrieve loss input_lengths from attention_mask
  1391. attention_mask = (
  1392. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  1393. )
  1394. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  1395. # assuming that padded tokens are filled with -100
  1396. # when not being attended to
  1397. labels_mask = labels >= 0
  1398. target_lengths = labels_mask.sum(-1)
  1399. flattened_targets = labels.masked_select(labels_mask)
  1400. # ctc_loss doesn't support fp16
  1401. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  1402. with torch.backends.cudnn.flags(enabled=False):
  1403. loss = nn.functional.ctc_loss(
  1404. log_probs,
  1405. flattened_targets,
  1406. input_lengths,
  1407. target_lengths,
  1408. blank=self.config.pad_token_id,
  1409. reduction=self.config.ctc_loss_reduction,
  1410. zero_infinity=self.config.ctc_zero_infinity,
  1411. )
  1412. if not return_dict:
  1413. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1414. return ((loss,) + output) if loss is not None else output
  1415. return CausalLMOutput(
  1416. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  1417. )
  1418. @auto_docstring(
  1419. custom_intro="""
  1420. Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  1421. SUPERB Keyword Spotting.
  1422. """
  1423. )
  1424. class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
  1425. def __init__(self, config):
  1426. super().__init__(config)
  1427. if hasattr(config, "add_adapter") and config.add_adapter:
  1428. raise ValueError(
  1429. "Sequence classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
  1430. )
  1431. self.wav2vec2 = Wav2Vec2Model(config)
  1432. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1433. if config.use_weighted_layer_sum:
  1434. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1435. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1436. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1437. # Initialize weights and apply final processing
  1438. self.post_init()
  1439. def freeze_feature_encoder(self):
  1440. """
  1441. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1442. not be updated during training.
  1443. """
  1444. self.wav2vec2.feature_extractor._freeze_parameters()
  1445. def freeze_base_model(self):
  1446. """
  1447. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1448. be updated during training. Only the classification head will be updated.
  1449. """
  1450. for param in self.wav2vec2.parameters():
  1451. param.requires_grad = False
  1452. @auto_docstring
  1453. def forward(
  1454. self,
  1455. input_values: torch.Tensor | None,
  1456. attention_mask: torch.Tensor | None = None,
  1457. output_attentions: bool | None = None,
  1458. output_hidden_states: bool | None = None,
  1459. return_dict: bool | None = None,
  1460. labels: torch.Tensor | None = None,
  1461. **kwargs,
  1462. ) -> tuple | SequenceClassifierOutput:
  1463. r"""
  1464. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1465. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1466. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1467. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1468. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1469. into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
  1470. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1471. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1472. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1473. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1474. """
  1475. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1476. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1477. outputs = self.wav2vec2(
  1478. input_values,
  1479. attention_mask=attention_mask,
  1480. output_attentions=output_attentions,
  1481. output_hidden_states=output_hidden_states,
  1482. return_dict=return_dict,
  1483. )
  1484. if self.config.use_weighted_layer_sum:
  1485. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1486. hidden_states = torch.stack(hidden_states, dim=1)
  1487. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1488. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1489. else:
  1490. hidden_states = outputs[0]
  1491. hidden_states = self.projector(hidden_states)
  1492. if attention_mask is None:
  1493. pooled_output = hidden_states.mean(dim=1)
  1494. else:
  1495. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  1496. expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  1497. hidden_states[~expand_padding_mask] = 0.0
  1498. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  1499. logits = self.classifier(pooled_output)
  1500. loss = None
  1501. if labels is not None:
  1502. loss_fct = CrossEntropyLoss()
  1503. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1504. if not return_dict:
  1505. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1506. return ((loss,) + output) if loss is not None else output
  1507. return SequenceClassifierOutput(
  1508. loss=loss,
  1509. logits=logits,
  1510. hidden_states=outputs.hidden_states,
  1511. attentions=outputs.attentions,
  1512. )
  1513. @auto_docstring
  1514. class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
  1515. def __init__(self, config):
  1516. super().__init__(config)
  1517. if hasattr(config, "add_adapter") and config.add_adapter:
  1518. raise ValueError(
  1519. "Audio frame classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
  1520. )
  1521. self.wav2vec2 = Wav2Vec2Model(config)
  1522. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1523. if config.use_weighted_layer_sum:
  1524. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1525. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1526. self.num_labels = config.num_labels
  1527. self.post_init()
  1528. def freeze_feature_encoder(self):
  1529. """
  1530. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1531. not be updated during training.
  1532. """
  1533. self.wav2vec2.feature_extractor._freeze_parameters()
  1534. def freeze_base_model(self):
  1535. """
  1536. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1537. be updated during training. Only the classification head will be updated.
  1538. """
  1539. for param in self.wav2vec2.parameters():
  1540. param.requires_grad = False
  1541. @auto_docstring
  1542. def forward(
  1543. self,
  1544. input_values: torch.Tensor | None,
  1545. attention_mask: torch.Tensor | None = None,
  1546. labels: torch.Tensor | None = None,
  1547. output_attentions: bool | None = None,
  1548. output_hidden_states: bool | None = None,
  1549. return_dict: bool | None = None,
  1550. **kwargs,
  1551. ) -> tuple | TokenClassifierOutput:
  1552. r"""
  1553. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1554. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1555. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1556. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1557. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1558. into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
  1559. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1560. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1561. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1562. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1563. """
  1564. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1565. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1566. outputs = self.wav2vec2(
  1567. input_values,
  1568. attention_mask=attention_mask,
  1569. output_attentions=output_attentions,
  1570. output_hidden_states=output_hidden_states,
  1571. return_dict=return_dict,
  1572. )
  1573. if self.config.use_weighted_layer_sum:
  1574. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1575. hidden_states = torch.stack(hidden_states, dim=1)
  1576. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1577. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1578. else:
  1579. hidden_states = outputs[0]
  1580. logits = self.classifier(hidden_states)
  1581. loss = None
  1582. if labels is not None:
  1583. loss_fct = CrossEntropyLoss()
  1584. loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
  1585. if not return_dict:
  1586. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1587. return output
  1588. return TokenClassifierOutput(
  1589. loss=loss,
  1590. logits=logits,
  1591. hidden_states=outputs.hidden_states,
  1592. attentions=outputs.attentions,
  1593. )
  1594. class AMSoftmaxLoss(nn.Module):
  1595. def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
  1596. super().__init__()
  1597. self.scale = scale
  1598. self.margin = margin
  1599. self.num_labels = num_labels
  1600. self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
  1601. self.loss = nn.CrossEntropyLoss()
  1602. def forward(self, hidden_states, labels):
  1603. labels = labels.flatten()
  1604. weight = nn.functional.normalize(self.weight, dim=0)
  1605. hidden_states = nn.functional.normalize(hidden_states, dim=1)
  1606. cos_theta = torch.mm(hidden_states, weight)
  1607. psi = cos_theta - self.margin
  1608. onehot = nn.functional.one_hot(labels, self.num_labels)
  1609. logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
  1610. loss = self.loss(logits, labels)
  1611. return loss
  1612. class TDNNLayer(nn.Module):
  1613. def __init__(self, config, layer_id=0):
  1614. super().__init__()
  1615. self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
  1616. self.out_conv_dim = config.tdnn_dim[layer_id]
  1617. self.kernel_size = config.tdnn_kernel[layer_id]
  1618. self.dilation = config.tdnn_dilation[layer_id]
  1619. self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
  1620. self.activation = nn.ReLU()
  1621. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1622. if is_peft_available():
  1623. from peft.tuners.lora import LoraLayer
  1624. if is_peft_available():
  1625. if isinstance(self.kernel, LoraLayer):
  1626. warnings.warn(
  1627. "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
  1628. "You should exclude TDNNLayer from LoRA's target modules.",
  1629. )
  1630. # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
  1631. hidden_states = hidden_states.transpose(1, 2)
  1632. weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
  1633. hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
  1634. hidden_states = hidden_states.transpose(1, 2)
  1635. hidden_states = self.activation(hidden_states)
  1636. return hidden_states
  1637. @auto_docstring(
  1638. custom_intro="""
  1639. Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
  1640. """
  1641. )
  1642. class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
  1643. def __init__(self, config):
  1644. super().__init__(config)
  1645. self.wav2vec2 = Wav2Vec2Model(config)
  1646. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1647. if config.use_weighted_layer_sum:
  1648. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1649. self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
  1650. tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
  1651. self.tdnn = nn.ModuleList(tdnn_layers)
  1652. self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
  1653. self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
  1654. self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
  1655. self.post_init()
  1656. def freeze_feature_encoder(self):
  1657. """
  1658. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1659. not be updated during training.
  1660. """
  1661. self.wav2vec2.feature_extractor._freeze_parameters()
  1662. def freeze_base_model(self):
  1663. """
  1664. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1665. be updated during training. Only the classification head will be updated.
  1666. """
  1667. for param in self.wav2vec2.parameters():
  1668. param.requires_grad = False
  1669. def _get_tdnn_output_lengths(self, input_lengths: torch.LongTensor | int):
  1670. """
  1671. Computes the output length of the TDNN layers
  1672. """
  1673. def _conv_out_length(input_length, kernel_size, stride):
  1674. # 1D convolutional layer output length formula taken
  1675. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  1676. return (input_length - kernel_size) // stride + 1
  1677. for kernel_size in self.config.tdnn_kernel:
  1678. input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
  1679. return input_lengths
  1680. @auto_docstring
  1681. def forward(
  1682. self,
  1683. input_values: torch.Tensor | None,
  1684. attention_mask: torch.Tensor | None = None,
  1685. output_attentions: bool | None = None,
  1686. output_hidden_states: bool | None = None,
  1687. return_dict: bool | None = None,
  1688. labels: torch.Tensor | None = None,
  1689. **kwargs,
  1690. ) -> tuple | XVectorOutput:
  1691. r"""
  1692. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1693. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1694. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1695. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1696. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1697. into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
  1698. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1699. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1700. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1701. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1702. """
  1703. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1704. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1705. outputs = self.wav2vec2(
  1706. input_values,
  1707. attention_mask=attention_mask,
  1708. output_attentions=output_attentions,
  1709. output_hidden_states=output_hidden_states,
  1710. return_dict=return_dict,
  1711. )
  1712. if self.config.use_weighted_layer_sum:
  1713. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1714. hidden_states = torch.stack(hidden_states, dim=1)
  1715. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1716. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1717. else:
  1718. hidden_states = outputs[0]
  1719. hidden_states = self.projector(hidden_states)
  1720. for tdnn_layer in self.tdnn:
  1721. hidden_states = tdnn_layer(hidden_states)
  1722. # Statistic Pooling
  1723. if attention_mask is None:
  1724. mean_features = hidden_states.mean(dim=1)
  1725. std_features = hidden_states.std(dim=1)
  1726. else:
  1727. feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
  1728. tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
  1729. mean_features = []
  1730. std_features = []
  1731. for i, length in enumerate(tdnn_output_lengths):
  1732. mean_features.append(hidden_states[i, :length].mean(dim=0))
  1733. std_features.append(hidden_states[i, :length].std(dim=0))
  1734. mean_features = torch.stack(mean_features)
  1735. std_features = torch.stack(std_features)
  1736. statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
  1737. output_embeddings = self.feature_extractor(statistic_pooling)
  1738. logits = self.classifier(output_embeddings)
  1739. loss = None
  1740. if labels is not None:
  1741. loss = self.objective(logits, labels)
  1742. if not return_dict:
  1743. output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
  1744. return ((loss,) + output) if loss is not None else output
  1745. return XVectorOutput(
  1746. loss=loss,
  1747. logits=logits,
  1748. embeddings=output_embeddings,
  1749. hidden_states=outputs.hidden_states,
  1750. attentions=outputs.attentions,
  1751. )
# Public names of this module; this list is what `from <module> import *` exports.
__all__ = [
    "Wav2Vec2ForAudioFrameClassification",
    "Wav2Vec2ForCTC",
    "Wav2Vec2ForPreTraining",
    "Wav2Vec2ForSequenceClassification",
    "Wav2Vec2ForXVector",
    "Wav2Vec2Model",
    "Wav2Vec2PreTrainedModel",
]