modeling_unispeech.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/unispeech/modular_unispeech.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_unispeech.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. import numpy as np
  24. import torch
  25. import torch.nn as nn
  26. from torch.nn import CrossEntropyLoss
  27. from ... import initialization as init
  28. from ...activations import ACT2FN
  29. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  30. from ...integrations.fsdp import is_fsdp_managed_module
  31. from ...masking_utils import create_bidirectional_mask
  32. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  33. from ...modeling_layers import GradientCheckpointingLayer
  34. from ...modeling_outputs import (
  35. BaseModelOutput,
  36. CausalLMOutput,
  37. ModelOutput,
  38. SequenceClassifierOutput,
  39. Wav2Vec2BaseModelOutput,
  40. )
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, get_torch_context_manager_or_global_device
  42. from ...processing_utils import Unpack
  43. from ...utils import TransformersKwargs, auto_docstring, logging
  44. from .configuration_unispeech import UniSpeechConfig
  45. logger = logging.get_logger(__name__)
  46. @dataclass
  47. @auto_docstring(
  48. custom_intro="""
  49. Output type of [`UniSpeechForPreTrainingOutput`], with potential hidden states and attentions.
  50. """
  51. )
  52. class UniSpeechForPreTrainingOutput(ModelOutput):
  53. r"""
  54. loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
  55. Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
  56. paper](https://huggingface.co/papers/2006.11477).
  57. projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  58. Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
  59. projected quantized states.
  60. projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  61. Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
  62. target vectors for contrastive loss.
  63. codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
  64. The perplexity of the codevector distribution, used to measure the diversity of the codebook.
  65. """
  66. loss: torch.FloatTensor | None = None
  67. projected_states: torch.FloatTensor | None = None
  68. projected_quantized_states: torch.FloatTensor | None = None
  69. codevector_perplexity: torch.FloatTensor | None = None
  70. hidden_states: tuple[torch.FloatTensor] | None = None
  71. attentions: tuple[torch.FloatTensor] | None = None
  72. class UniSpeechSamePadLayer(nn.Module):
  73. def __init__(self, num_conv_pos_embeddings):
  74. super().__init__()
  75. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  76. def forward(self, hidden_states):
  77. if self.num_pad_remove > 0:
  78. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  79. return hidden_states
  80. class UniSpeechPositionalConvEmbedding(nn.Module):
  81. def __init__(self, config):
  82. super().__init__()
  83. self.conv = nn.Conv1d(
  84. config.hidden_size,
  85. config.hidden_size,
  86. kernel_size=config.num_conv_pos_embeddings,
  87. padding=config.num_conv_pos_embeddings // 2,
  88. groups=config.num_conv_pos_embedding_groups,
  89. )
  90. weight_norm = nn.utils.weight_norm
  91. if hasattr(nn.utils.parametrizations, "weight_norm"):
  92. weight_norm = nn.utils.parametrizations.weight_norm
  93. if is_deepspeed_zero3_enabled():
  94. import deepspeed
  95. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  96. self.conv = weight_norm(self.conv, name="weight", dim=2)
  97. if hasattr(self.conv, "parametrizations"):
  98. weight_g = self.conv.parametrizations.weight.original0
  99. weight_v = self.conv.parametrizations.weight.original1
  100. else:
  101. weight_g = self.conv.weight_g
  102. weight_v = self.conv.weight_v
  103. deepspeed.zero.register_external_parameter(self, weight_v)
  104. deepspeed.zero.register_external_parameter(self, weight_g)
  105. else:
  106. self.conv = weight_norm(self.conv, name="weight", dim=2)
  107. self.padding = UniSpeechSamePadLayer(config.num_conv_pos_embeddings)
  108. self.activation = ACT2FN[config.feat_extract_activation]
  109. def forward(self, hidden_states):
  110. hidden_states = hidden_states.transpose(1, 2)
  111. hidden_states = self.conv(hidden_states)
  112. hidden_states = self.padding(hidden_states)
  113. hidden_states = self.activation(hidden_states)
  114. hidden_states = hidden_states.transpose(1, 2)
  115. return hidden_states
  116. class UniSpeechNoLayerNormConvLayer(GradientCheckpointingLayer):
  117. def __init__(self, config, layer_id=0):
  118. super().__init__()
  119. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  120. self.out_conv_dim = config.conv_dim[layer_id]
  121. self.conv = nn.Conv1d(
  122. self.in_conv_dim,
  123. self.out_conv_dim,
  124. kernel_size=config.conv_kernel[layer_id],
  125. stride=config.conv_stride[layer_id],
  126. bias=config.conv_bias,
  127. )
  128. self.activation = ACT2FN[config.feat_extract_activation]
  129. def forward(self, hidden_states):
  130. hidden_states = self.conv(hidden_states)
  131. hidden_states = self.activation(hidden_states)
  132. return hidden_states
  133. class UniSpeechLayerNormConvLayer(GradientCheckpointingLayer):
  134. def __init__(self, config, layer_id=0):
  135. super().__init__()
  136. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  137. self.out_conv_dim = config.conv_dim[layer_id]
  138. self.conv = nn.Conv1d(
  139. self.in_conv_dim,
  140. self.out_conv_dim,
  141. kernel_size=config.conv_kernel[layer_id],
  142. stride=config.conv_stride[layer_id],
  143. bias=config.conv_bias,
  144. )
  145. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  146. self.activation = ACT2FN[config.feat_extract_activation]
  147. def forward(self, hidden_states):
  148. hidden_states = self.conv(hidden_states)
  149. hidden_states = hidden_states.transpose(-2, -1)
  150. hidden_states = self.layer_norm(hidden_states)
  151. hidden_states = hidden_states.transpose(-2, -1)
  152. hidden_states = self.activation(hidden_states)
  153. return hidden_states
  154. class UniSpeechGroupNormConvLayer(GradientCheckpointingLayer):
  155. def __init__(self, config, layer_id=0):
  156. super().__init__()
  157. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  158. self.out_conv_dim = config.conv_dim[layer_id]
  159. self.conv = nn.Conv1d(
  160. self.in_conv_dim,
  161. self.out_conv_dim,
  162. kernel_size=config.conv_kernel[layer_id],
  163. stride=config.conv_stride[layer_id],
  164. bias=config.conv_bias,
  165. )
  166. self.activation = ACT2FN[config.feat_extract_activation]
  167. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  168. def forward(self, hidden_states):
  169. hidden_states = self.conv(hidden_states)
  170. hidden_states = self.layer_norm(hidden_states)
  171. hidden_states = self.activation(hidden_states)
  172. return hidden_states
  173. class UniSpeechFeatureEncoder(nn.Module):
  174. """Construct the features from raw audio waveform"""
  175. def __init__(self, config):
  176. super().__init__()
  177. if config.feat_extract_norm == "group":
  178. conv_layers = [UniSpeechGroupNormConvLayer(config, layer_id=0)] + [
  179. UniSpeechNoLayerNormConvLayer(config, layer_id=i + 1)
  180. for i in range(config.num_feat_extract_layers - 1)
  181. ]
  182. elif config.feat_extract_norm == "layer":
  183. conv_layers = [
  184. UniSpeechLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
  185. ]
  186. else:
  187. raise ValueError(
  188. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  189. )
  190. self.conv_layers = nn.ModuleList(conv_layers)
  191. self.gradient_checkpointing = False
  192. self._requires_grad = True
  193. def _freeze_parameters(self):
  194. for param in self.parameters():
  195. param.requires_grad = False
  196. self._requires_grad = False
  197. def forward(self, input_values):
  198. hidden_states = input_values[:, None]
  199. # make sure hidden_states require grad for gradient_checkpointing
  200. if self._requires_grad and self.training:
  201. hidden_states.requires_grad = True
  202. for conv_layer in self.conv_layers:
  203. hidden_states = conv_layer(hidden_states)
  204. return hidden_states
  205. class UniSpeechFeatureProjection(nn.Module):
  206. def __init__(self, config):
  207. super().__init__()
  208. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  209. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  210. self.dropout = nn.Dropout(config.feat_proj_dropout)
  211. def forward(self, hidden_states):
  212. # non-projected hidden states are needed for quantization
  213. norm_hidden_states = self.layer_norm(hidden_states)
  214. hidden_states = self.projection(norm_hidden_states)
  215. hidden_states = self.dropout(hidden_states)
  216. return hidden_states, norm_hidden_states
  217. def eager_attention_forward(
  218. module: nn.Module,
  219. query: torch.Tensor,
  220. key: torch.Tensor,
  221. value: torch.Tensor,
  222. attention_mask: torch.Tensor | None,
  223. scaling: float | None = None,
  224. dropout: float = 0.0,
  225. **kwargs: Unpack[TransformersKwargs],
  226. ):
  227. if scaling is None:
  228. scaling = query.size(-1) ** -0.5
  229. # Take the dot product between "query" and "key" to get the raw attention scores.
  230. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  231. if attention_mask is not None:
  232. attn_weights = attn_weights + attention_mask
  233. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  234. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  235. attn_output = torch.matmul(attn_weights, value)
  236. attn_output = attn_output.transpose(1, 2).contiguous()
  237. return attn_output, attn_weights
  238. class UniSpeechAttention(nn.Module):
  239. """Multi-headed attention from 'Attention Is All You Need' paper"""
  240. def __init__(
  241. self,
  242. embed_dim: int,
  243. num_heads: int,
  244. dropout: float = 0.0,
  245. is_decoder: bool = False,
  246. bias: bool = True,
  247. is_causal: bool = False,
  248. config: UniSpeechConfig | None = None,
  249. ):
  250. super().__init__()
  251. self.embed_dim = embed_dim
  252. self.num_heads = num_heads
  253. self.dropout = dropout
  254. self.head_dim = embed_dim // num_heads
  255. self.config = config
  256. if (self.head_dim * num_heads) != self.embed_dim:
  257. raise ValueError(
  258. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  259. f" and `num_heads`: {num_heads})."
  260. )
  261. self.scaling = self.head_dim**-0.5
  262. self.is_decoder = is_decoder
  263. self.is_causal = is_causal
  264. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  265. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  266. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  267. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  268. def forward(
  269. self,
  270. hidden_states: torch.Tensor,
  271. key_value_states: torch.Tensor | None = None,
  272. attention_mask: torch.Tensor | None = None,
  273. output_attentions: bool | None = False,
  274. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  275. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  276. **kwargs: Unpack[FlashAttentionKwargs],
  277. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  278. """Input shape: Batch x Time x Channel"""
  279. # if key_value_states are provided this layer is used as a cross-attention layer
  280. # for the decoder
  281. is_cross_attention = key_value_states is not None
  282. # determine input shapes
  283. input_shape = hidden_states.shape[:-1]
  284. hidden_shape = (*input_shape, -1, self.head_dim)
  285. # get query proj
  286. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  287. current_states = key_value_states if is_cross_attention else hidden_states
  288. kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
  289. key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2)
  290. value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2)
  291. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  292. self.config._attn_implementation, eager_attention_forward
  293. )
  294. attn_output, attn_weights = attention_interface(
  295. self,
  296. query_states,
  297. key_states,
  298. value_states,
  299. attention_mask,
  300. dropout=0.0 if not self.training else self.dropout,
  301. scaling=self.scaling,
  302. output_attentions=output_attentions,
  303. **kwargs,
  304. )
  305. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  306. attn_output = self.out_proj(attn_output)
  307. return attn_output, attn_weights, None
  308. class UniSpeechFeedForward(nn.Module):
  309. def __init__(self, config):
  310. super().__init__()
  311. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  312. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  313. if isinstance(config.hidden_act, str):
  314. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  315. else:
  316. self.intermediate_act_fn = config.hidden_act
  317. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  318. self.output_dropout = nn.Dropout(config.hidden_dropout)
  319. def forward(self, hidden_states):
  320. hidden_states = self.intermediate_dense(hidden_states)
  321. hidden_states = self.intermediate_act_fn(hidden_states)
  322. hidden_states = self.intermediate_dropout(hidden_states)
  323. hidden_states = self.output_dense(hidden_states)
  324. hidden_states = self.output_dropout(hidden_states)
  325. return hidden_states
  326. class UniSpeechEncoderLayer(GradientCheckpointingLayer):
  327. def __init__(self, config):
  328. super().__init__()
  329. self.attention = UniSpeechAttention(
  330. embed_dim=config.hidden_size,
  331. num_heads=config.num_attention_heads,
  332. dropout=config.attention_dropout,
  333. is_decoder=False,
  334. config=config,
  335. )
  336. self.dropout = nn.Dropout(config.hidden_dropout)
  337. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  338. self.feed_forward = UniSpeechFeedForward(config)
  339. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  340. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  341. attn_residual = hidden_states
  342. hidden_states, attn_weights, _ = self.attention(
  343. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  344. )
  345. hidden_states = self.dropout(hidden_states)
  346. hidden_states = attn_residual + hidden_states
  347. hidden_states = self.layer_norm(hidden_states)
  348. hidden_states = hidden_states + self.feed_forward(hidden_states)
  349. hidden_states = self.final_layer_norm(hidden_states)
  350. outputs = (hidden_states,)
  351. if output_attentions:
  352. outputs += (attn_weights,)
  353. return outputs
  354. class UniSpeechEncoder(nn.Module):
  355. def __init__(self, config):
  356. super().__init__()
  357. self.config = config
  358. self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config)
  359. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  360. self.dropout = nn.Dropout(config.hidden_dropout)
  361. self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  362. self.gradient_checkpointing = False
  363. def forward(
  364. self,
  365. hidden_states: torch.tensor,
  366. attention_mask: torch.Tensor | None = None,
  367. output_attentions: bool = False,
  368. output_hidden_states: bool = False,
  369. return_dict: bool = True,
  370. ):
  371. all_hidden_states = () if output_hidden_states else None
  372. all_self_attentions = () if output_attentions else None
  373. if attention_mask is not None:
  374. # make sure padded tokens output 0
  375. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  376. hidden_states[~expand_attention_mask] = 0
  377. attention_mask = create_bidirectional_mask(
  378. config=self.config,
  379. inputs_embeds=hidden_states,
  380. attention_mask=attention_mask,
  381. )
  382. position_embeddings = self.pos_conv_embed(hidden_states)
  383. hidden_states = hidden_states + position_embeddings.to(hidden_states.device)
  384. hidden_states = self.layer_norm(hidden_states)
  385. hidden_states = self.dropout(hidden_states)
  386. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  387. for layer in self.layers:
  388. if output_hidden_states:
  389. all_hidden_states = all_hidden_states + (hidden_states,)
  390. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  391. dropout_probability = torch.rand([])
  392. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  393. if not skip_the_layer or synced_gpus:
  394. # under fsdp or deepspeed zero3 all gpus must run in sync
  395. layer_outputs = layer(
  396. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  397. )
  398. hidden_states = layer_outputs[0]
  399. if skip_the_layer:
  400. layer_outputs = (None, None)
  401. if output_attentions:
  402. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  403. if output_hidden_states:
  404. all_hidden_states = all_hidden_states + (hidden_states,)
  405. if not return_dict:
  406. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  407. return BaseModelOutput(
  408. last_hidden_state=hidden_states,
  409. hidden_states=all_hidden_states,
  410. attentions=all_self_attentions,
  411. )
  412. class UniSpeechAttnAdapterLayer(nn.Module):
  413. def __init__(self, config):
  414. """
  415. Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
  416. up training throughput.
  417. """
  418. super().__init__()
  419. self.input_dim = config.adapter_attn_dim
  420. self.hidden_dim = config.hidden_size
  421. self.norm = nn.LayerNorm(self.hidden_dim)
  422. self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
  423. self.act_fn = nn.ReLU()
  424. self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
  425. def forward(self, hidden_states: torch.FloatTensor):
  426. hidden_states = self.norm(hidden_states)
  427. hidden_states = self.linear_1(hidden_states)
  428. hidden_states = self.act_fn(hidden_states)
  429. hidden_states = self.linear_2(hidden_states)
  430. return hidden_states
  431. class UniSpeechEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
  432. def __init__(self, config):
  433. super().__init__()
  434. self.attention = UniSpeechAttention(
  435. embed_dim=config.hidden_size,
  436. num_heads=config.num_attention_heads,
  437. dropout=config.attention_dropout,
  438. is_decoder=False,
  439. config=config,
  440. )
  441. self.dropout = nn.Dropout(config.hidden_dropout)
  442. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  443. self.feed_forward = UniSpeechFeedForward(config)
  444. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  445. if getattr(config, "adapter_attn_dim", None) is not None:
  446. self.adapter_layer = UniSpeechAttnAdapterLayer(config)
  447. else:
  448. self.adapter_layer = None
  449. def forward(
  450. self,
  451. hidden_states: torch.Tensor,
  452. attention_mask: torch.Tensor | None = None,
  453. output_attentions: bool = False,
  454. ):
  455. attn_residual = hidden_states
  456. hidden_states = self.layer_norm(hidden_states)
  457. hidden_states, attn_weights, _ = self.attention(
  458. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  459. )
  460. hidden_states = self.dropout(hidden_states)
  461. hidden_states = attn_residual + hidden_states
  462. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  463. if self.adapter_layer is not None:
  464. hidden_states = hidden_states + self.adapter_layer(hidden_states)
  465. outputs = (hidden_states,)
  466. if output_attentions:
  467. outputs += (attn_weights,)
  468. return outputs
  469. class UniSpeechEncoderStableLayerNorm(nn.Module):
  470. def __init__(self, config):
  471. super().__init__()
  472. self.config = config
  473. self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config)
  474. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  475. self.dropout = nn.Dropout(config.hidden_dropout)
  476. self.layers = nn.ModuleList(
  477. [UniSpeechEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
  478. )
  479. self.gradient_checkpointing = False
  480. def forward(
  481. self,
  482. hidden_states,
  483. attention_mask=None,
  484. output_attentions=False,
  485. output_hidden_states=False,
  486. return_dict=True,
  487. ):
  488. all_hidden_states = () if output_hidden_states else None
  489. all_self_attentions = () if output_attentions else None
  490. if attention_mask is not None:
  491. # make sure padded tokens output 0
  492. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  493. hidden_states[~expand_attention_mask] = 0
  494. attention_mask = create_bidirectional_mask(
  495. config=self.config,
  496. inputs_embeds=hidden_states,
  497. attention_mask=attention_mask,
  498. )
  499. position_embeddings = self.pos_conv_embed(hidden_states)
  500. hidden_states = hidden_states + position_embeddings
  501. hidden_states = self.dropout(hidden_states)
  502. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  503. for layer in self.layers:
  504. if output_hidden_states:
  505. all_hidden_states = all_hidden_states + (hidden_states,)
  506. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  507. dropout_probability = torch.rand([])
  508. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  509. if not skip_the_layer or synced_gpus:
  510. # under fsdp or deepspeed zero3 all gpus must run in sync
  511. # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
  512. layer_outputs = layer(
  513. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  514. )
  515. hidden_states = layer_outputs[0]
  516. if skip_the_layer:
  517. layer_outputs = (None, None)
  518. if output_attentions:
  519. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  520. hidden_states = self.layer_norm(hidden_states)
  521. if output_hidden_states:
  522. all_hidden_states = all_hidden_states + (hidden_states,)
  523. if not return_dict:
  524. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  525. return BaseModelOutput(
  526. last_hidden_state=hidden_states,
  527. hidden_states=all_hidden_states,
  528. attentions=all_self_attentions,
  529. )
  530. class UniSpeechGumbelVectorQuantizer(nn.Module):
  531. """
  532. Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
  533. GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
  534. """
  535. def __init__(self, config):
  536. super().__init__()
  537. self.num_groups = config.num_codevector_groups
  538. self.num_vars = config.num_codevectors_per_group
  539. if config.codevector_dim % self.num_groups != 0:
  540. raise ValueError(
  541. f"`config.codevector_dim {config.codevector_dim} must be divisible "
  542. f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
  543. )
  544. # storage for codebook variables (codewords)
  545. self.codevectors = nn.Parameter(
  546. torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
  547. )
  548. self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
  549. # can be decayed for training
  550. self.temperature = 2
  551. @staticmethod
  552. def _compute_perplexity(probs):
  553. marginal_probs = probs.mean(dim=0)
  554. perplexity = torch.exp(-torch.sum(torch.xlogy(marginal_probs, marginal_probs), dim=-1)).sum()
  555. return perplexity
  def forward(self, hidden_states):
      """
      Quantize `hidden_states` into codevectors.

      Args:
          hidden_states: 3-D tensor unpacked as `(batch_size, sequence_length, hidden_size)`.

      Returns:
          `(codevectors, perplexity)`:
          - `codevectors` has shape `(batch_size, sequence_length, -1)`, i.e. the codeword picked
            in every group, concatenated on the last axis.
          - `perplexity` is the value returned by `self._compute_perplexity`.

      In training mode, the code in each group is sampled with `gumbel_softmax(..., hard=True)` at
      `self.temperature`. Perplexity is then measured on the plain softmax of the logits.

      Outside training, the argmax code is turned into a one-hot vector. Perplexity is measured on
      those one-hot vectors.
      """
      batch_size, sequence_length, _ = hidden_states.shape
      num_frames = batch_size * sequence_length

      # One row of logits per (frame, group) pair.
      logits = self.weight_proj(hidden_states).view(num_frames * self.num_groups, -1)

      if not self.training:
          # Deterministic path: one-hot of the highest-scoring code in each group.
          best_code = logits.argmax(dim=-1)
          one_hot = logits.new_zeros(*logits.shape).scatter_(-1, best_code.view(-1, 1), 1.0)
          one_hot = one_hot.view(num_frames, self.num_groups, -1)
          perplexity = self._compute_perplexity(one_hot)
          code_probs = one_hot
      else:
          # The Gumbel sample is drawn in float32, then cast back to the logits' dtype.
          code_probs = nn.functional.gumbel_softmax(logits.float(), tau=self.temperature, hard=True).type_as(logits)
          soft_dist = torch.softmax(logits.view(num_frames, self.num_groups, -1).float(), dim=-1)
          perplexity = self._compute_perplexity(soft_dist)

      # Weight every codeword by its selection probability, then sum within each group.
      weighted = code_probs.view(num_frames, -1).unsqueeze(-1) * self.codevectors
      codevectors = weighted.view(num_frames, self.num_groups, self.num_vars, -1).sum(-2)
      return codevectors.view(batch_size, sequence_length, -1), perplexity
  @auto_docstring
  class UniSpeechPreTrainedModel(PreTrainedModel):
      # Shared base of every UniSpeech head. It declares capability flags, the custom weight
      # initialization, and helpers that turn raw-audio lengths/masks into feature-frame lengths/masks.
      config: UniSpeechConfig
      base_model_prefix = "unispeech"
      main_input_name = "input_values"
      input_modalities = "audio"
      supports_gradient_checkpointing = True
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True

      @torch.no_grad()
      def _init_weights(self, module):
          """
          Initialize the parameters of `module` in place, according to its type.

          The first matching type wins, so the UniSpeech-specific modules are checked before
          the generic `nn.Linear` / norm / `nn.Conv1d` rules.
          """
          if isinstance(module, UniSpeechGumbelVectorQuantizer):
              # The Gumbel-softmax logits projection gets a unit-variance normal init.
              init.normal_(module.weight_proj.weight, mean=0.0, std=1)
              init.zeros_(module.weight_proj.bias)
              init.uniform_(module.codevectors)
              return

          if isinstance(module, UniSpeechPositionalConvEmbedding):
              conv = module.conv
              fan_in = conv.kernel_size[0] * conv.in_channels
              init.normal_(conv.weight, mean=0, std=2 * math.sqrt(1 / fan_in))
              init.constant_(conv.bias, 0)
              return

          if isinstance(module, UniSpeechFeatureProjection):
              bound = math.sqrt(1 / module.projection.in_features)
              init.uniform_(module.projection.weight, a=-bound, b=bound)
              init.uniform_(module.projection.bias, a=-bound, b=bound)
              return

          if isinstance(module, nn.Linear):
              init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  init.zeros_(module.bias)
              return

          if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
              init.zeros_(module.bias)
              init.ones_(module.weight)
              return

          if isinstance(module, nn.Conv1d):
              init.kaiming_normal_(module.weight)
              if module.bias is not None:
                  bound = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                  init.uniform_(module.bias, a=-bound, b=bound)

      def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
          """
          Map raw input lengths to the number of frames produced by the convolutional feature encoder.

          For each `(kernel, stride)` pair in `config.conv_kernel` / `config.conv_stride`, this applies
          the Conv1d output-length formula without padding or dilation:
          `floor((length - kernel) / stride) + 1`.
          """
          lengths = input_lengths
          for kernel, stride in zip(self.config.conv_kernel, self.config.conv_stride):
              # formula from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
              lengths = torch.div(lengths - kernel, stride, rounding_mode="floor") + 1
          return lengths

      def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
          """
          Downsample a sample-level attention mask to a boolean mask over `feature_vector_length` frames.

          The number of attended samples in each row is converted to a frame count with
          `_get_feat_extract_output_lengths`. Frames `0 .. count - 1` are then marked True.

          NOTE: a row whose frame count is 0 writes to index -1 and therefore ends up all True.
          """
          # The last element of the running sum equals attention_mask.sum(-1). cumsum is used so
          # the computation is not in place and can run under inference mode.
          attended_samples = attention_mask.cumsum(dim=-1)[:, -1]
          frame_lengths = self._get_feat_extract_output_lengths(attended_samples).to(torch.long)

          frame_mask = torch.zeros(
              (attention_mask.shape[0], feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
          )
          rows = torch.arange(frame_mask.shape[0], device=attention_mask.device)
          # Place a 1 at the last valid frame of each row. A reversed cumulative sum then spreads
          # it to every earlier frame.
          frame_mask[(rows, frame_lengths - 1)] = 1
          return frame_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  646. )
  647. # these two operations makes sure that all values before the output lengths idxs are attended to
  648. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  649. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  650. return attention_mask
  651. def _compute_mask_indices(
  652. shape: tuple[int, int],
  653. mask_prob: float,
  654. mask_length: int,
  655. attention_mask: torch.LongTensor | None = None,
  656. min_masks: int = 0,
  657. ) -> np.ndarray:
  658. """
  659. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  660. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  661. CPU as part of the preprocessing during training.
  662. Args:
  663. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  664. the first element is the batch size and the second element is the length of the axis to span.
  665. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  666. independently generated mask spans of length `mask_length` is computed by
  667. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  668. actual percentage will be smaller.
  669. mask_length: size of the mask
  670. min_masks: minimum number of masked spans
  671. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  672. each batch dimension.
  673. """
  674. batch_size, sequence_length = shape
  675. if mask_length < 1:
  676. raise ValueError("`mask_length` has to be bigger than 0.")
  677. if mask_length > sequence_length:
  678. raise ValueError(
  679. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  680. f" and `sequence_length`: {sequence_length}`"
  681. )
  682. # epsilon is used for probabilistic rounding
  683. epsilon = np.random.rand(1).item()
  684. def compute_num_masked_span(input_length):
  685. """Given input length, compute how many spans should be masked"""
  686. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  687. num_masked_span = max(num_masked_span, min_masks)
  688. # make sure num masked span <= sequence_length
  689. if num_masked_span * mask_length > sequence_length:
  690. num_masked_span = sequence_length // mask_length
  691. # make sure num_masked span is also <= input_length - (mask_length - 1)
  692. if input_length - (mask_length - 1) < num_masked_span:
  693. num_masked_span = max(input_length - (mask_length - 1), 0)
  694. return num_masked_span
  695. # compute number of masked spans in batch
  696. input_lengths = (
  697. attention_mask.detach().sum(-1).tolist()
  698. if attention_mask is not None
  699. else [sequence_length for _ in range(batch_size)]
  700. )
  701. # SpecAugment mask to fill
  702. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  703. spec_aug_mask_idxs = []
  704. max_num_masked_span = compute_num_masked_span(sequence_length)
  705. if max_num_masked_span == 0:
  706. return spec_aug_mask
  707. for input_length in input_lengths:
  708. # compute num of masked spans for this input
  709. num_masked_span = compute_num_masked_span(input_length)
  710. # get random indices to mask
  711. spec_aug_mask_idx = np.random.choice(
  712. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  713. )
  714. # pick first sampled index that will serve as a dummy index to pad vector
  715. # to ensure same dimension for all batches due to probabilistic rounding
  716. # Picking first sample just pads those vectors twice.
  717. if len(spec_aug_mask_idx) == 0:
  718. # this case can only happen if `input_length` is strictly smaller then
  719. # `sequence_length` in which case the last token has to be a padding
  720. # token which we can use as a dummy mask id
  721. dummy_mask_idx = sequence_length - 1
  722. else:
  723. dummy_mask_idx = spec_aug_mask_idx[0]
  724. spec_aug_mask_idx = np.concatenate(
  725. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  726. )
  727. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  728. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  729. # expand masked indices to masked spans
  730. spec_aug_mask_idxs = np.broadcast_to(
  731. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  732. )
  733. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  734. # add offset to the starting indexes so that indexes now create a span
  735. offsets = np.arange(mask_length)[None, None, :]
  736. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  737. batch_size, max_num_masked_span * mask_length
  738. )
  739. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  740. # ensure that we cannot have indices larger than sequence_length
  741. if spec_aug_mask_idxs.max() > sequence_length - 1:
  742. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  743. # scatter indices to mask
  744. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  745. return spec_aug_mask
  # UniSpeech reuses the Wav2Vec2 base-model output container as-is.
  UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput


  @auto_docstring
  class UniSpeechModel(UniSpeechPreTrainedModel):
      def __init__(self, config: UniSpeechConfig):
          super().__init__(config)
          self.config = config
          self.feature_extractor = UniSpeechFeatureEncoder(config)
          self.feature_projection = UniSpeechFeatureProjection(config)

          # Learned vector written over time steps masked by SpecAugment; it only exists when
          # some masking probability is configured.
          if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
              self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

          encoder_cls = UniSpeechEncoderStableLayerNorm if config.do_stable_layer_norm else UniSpeechEncoder
          self.encoder = encoder_cls(config)

          # Initialize weights and apply final processing
          self.post_init()

      def _mask_hidden_states(
          self,
          hidden_states: torch.FloatTensor,
          mask_time_indices: torch.FloatTensor | None = None,
          attention_mask: torch.LongTensor | None = None,
      ):
          """
          Apply SpecAugment ([paper](https://huggingface.co/papers/1904.08779)) to `hidden_states` in place.

          Time axis:
          - If `mask_time_indices` is given, those positions are always overwritten with `masked_spec_embed`.
          - Otherwise, positions are sampled when `config.mask_time_prob > 0` and the module is training.

          Feature axis: sampled channels are zeroed across all time steps when
          `config.mask_feature_prob > 0` and the module is training.
          """
          # `config.apply_spec_augment` can switch masking off entirely.
          if not getattr(self.config, "apply_spec_augment", True):
              return hidden_states

          batch_size, sequence_length, hidden_size = hidden_states.size()

          if mask_time_indices is None and self.config.mask_time_prob > 0 and self.training:
              sampled_time_mask = _compute_mask_indices(
                  (batch_size, sequence_length),
                  mask_prob=self.config.mask_time_prob,
                  mask_length=self.config.mask_time_length,
                  attention_mask=attention_mask,
                  min_masks=self.config.mask_time_min_masks,
              )
              mask_time_indices = torch.tensor(sampled_time_mask, device=hidden_states.device, dtype=torch.bool)

          if mask_time_indices is not None:
              hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

          if self.config.mask_feature_prob > 0 and self.training:
              sampled_feature_mask = _compute_mask_indices(
                  (batch_size, hidden_size),
                  mask_prob=self.config.mask_feature_prob,
                  mask_length=self.config.mask_feature_length,
                  min_masks=self.config.mask_feature_min_masks,
              )
              feature_mask = torch.tensor(sampled_feature_mask, device=hidden_states.device, dtype=torch.bool)
              # Broadcast the per-sample channel mask over every time step.
              hidden_states[feature_mask[:, None].expand(-1, sequence_length, -1)] = 0

          return hidden_states

      @auto_docstring
      def forward(
          self,
          input_values: torch.Tensor | None,
          attention_mask: torch.Tensor | None = None,
          mask_time_indices: torch.FloatTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | UniSpeechBaseModelOutput:
          r"""
          mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
              masked extracted features in *config.proj_codevector_dim* space.
          """
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if return_dict is None:
              return_dict = self.config.return_dict

          # After the transpose, dim 1 indexes feature frames (it is used as the frame count below).
          extract_features = self.feature_extractor(input_values).transpose(1, 2)

          if attention_mask is not None:
              # Reduce the sample-level mask to one entry per feature frame.
              attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

          hidden_states, extract_features = self.feature_projection(extract_features)
          hidden_states = self._mask_hidden_states(
              hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
          )

          encoder_outputs = self.encoder(
              hidden_states,
              attention_mask=attention_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          last_hidden_state = encoder_outputs[0]

          if return_dict:
              return UniSpeechBaseModelOutput(
                  last_hidden_state=last_hidden_state,
                  extract_features=extract_features,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )
          return (last_hidden_state, extract_features) + encoder_outputs[1:]
  @auto_docstring(
      custom_intro="""
      UniSpeech Model with a vector-quantization module and ctc loss for pre-training.
      """
  )
  class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
      def __init__(self, config: UniSpeechConfig):
          super().__init__(config)
          self.unispeech = UniSpeechModel(config)
          self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)

          self.quantizer = UniSpeechGumbelVectorQuantizer(config)
          # Quantized features are projected codevector_dim -> proj_codevector_dim -> hidden_size.
          self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
          self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)

          self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
          self.dropout = nn.Dropout(config.final_dropout)

          # Initialize weights and apply final processing
          self.post_init()

      def set_gumbel_temperature(self, temperature: int):
          """
          Set the temperature used by the quantizer's Gumbel softmax.

          The quantizer only samples with Gumbel softmax in training mode, so this only matters for training.
          """
          self.quantizer.temperature = temperature

      def freeze_feature_encoder(self):
          """
          Stop gradient updates for the convolutional feature encoder so its parameters stay fixed during training.
          """
          self.unispeech.feature_extractor._freeze_parameters()

      @staticmethod
      def compute_contrastive_logits(
          target_features: torch.FloatTensor,
          negative_features: torch.FloatTensor,
          predicted_features: torch.FloatTensor,
          temperature: int = 1,
      ):
          """
          Cosine-similarity logits between `predicted_features` and `[target_features, negative_features]`.

          The positive and negative candidates are concatenated along dim 0. Similarity is computed over
          the last dimension in float32, cast back to the candidates' dtype, and divided by `temperature`.
          """
          candidates = torch.cat([target_features, negative_features], dim=0)
          similarity = torch.cosine_similarity(predicted_features.float(), candidates.float(), dim=-1)
          return similarity.type_as(candidates) / temperature

      @auto_docstring
      def forward(
          self,
          input_values: torch.Tensor | None,
          attention_mask: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | UniSpeechForPreTrainingOutput:
          r"""
          Example:

          ```python
          >>> import torch
          >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining

          >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
          >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
          >>> # TODO: Add full pretraining example
          ```"""
          if return_dict is None:
              return_dict = self.config.return_dict

          outputs = self.unispeech(
              input_values,
              attention_mask=attention_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          transformer_features, extract_features = outputs[0], outputs[1]

          # Quantize all (unmasked) extracted features, then project them twice up to hidden_size.
          quantized_features, codevector_perplexity = self.quantizer(self.dropout_features(extract_features))
          quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
          quantized_features = self.project_hid(quantized_features)

          # Bernoulli(replace_prob) draw per (batch, frame) position. Where it is True the quantized
          # feature is used; elsewhere the transformer feature is kept.
          replace_probs = (
              torch.empty(transformer_features.size(0), transformer_features.size(1))
              .fill_(self.config.replace_prob)
              .transpose(0, 1)
          )
          replace_mask = torch.bernoulli(replace_probs).bool().to(transformer_features.device)
          replace_mask = replace_mask.transpose(0, 1).unsqueeze(-1)
          mixed_features = transformer_features.masked_fill(replace_mask, 0.0) + (
              quantized_features.masked_fill(~replace_mask, 0.0)
          )

          # Project to ctc units.
          # NOTE(review): these logits are not part of any returned value.
          logits = self.ctc_proj(self.dropout(mixed_features))

          # TODO(PVP) - add negative sampling & loss computation
          loss = None

          if return_dict:
              return UniSpeechForPreTrainingOutput(
                  loss=loss,
                  projected_states=transformer_features,
                  projected_quantized_states=quantized_features,
                  codevector_perplexity=codevector_perplexity,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
              )

          head = (transformer_features, quantized_features, codevector_perplexity)
          if loss is not None:
              head = (loss,) + head
          return head + outputs[2:]
  # Index of the first optional encoder output inside the tuple returned by `UniSpeechModel` when
  # `return_dict=False`. Positions 0 and 1 hold `last_hidden_state` and `extract_features`.
  _HIDDEN_STATES_START_POSITION = 2


  @auto_docstring(
      custom_intro="""
      UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
      """
  )
  class UniSpeechForCTC(UniSpeechPreTrainedModel):
      def __init__(self, config, target_lang: str | None = None):
          r"""
          target_lang (`str`, *optional*):
              Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
              adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng' by
              default.
          """
          super().__init__(config)

          self.unispeech = UniSpeechModel(config)
          self.dropout = nn.Dropout(config.final_dropout)

          self.target_lang = target_lang

          if config.vocab_size is None:
              raise ValueError(
                  f"You are trying to instantiate {self.__class__} with a configuration that "
                  "does not define the vocabulary size of the language model head. Please "
                  "instantiate the model as follows: `UniSpeechForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                  "or define `vocab_size` of your model's configuration."
              )
          # With an adapter enabled, the head reads `output_hidden_size` features instead of `hidden_size`.
          output_hidden_size = (
              config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
          )
          self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

          # Initialize weights and apply final processing
          self.post_init()

      def tie_weights(self, **kwargs):
          """
          This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
          passing `target_lang=...` to `from_pretrained(...)`.

          This method is **not** supposed to be called by the user and is prone to be changed in the future.

          Raises:
              ValueError: if `target_lang` is set but `config.adapter_attn_dim` is not defined.
          """
          # Skip adapter loading entirely while the active default device is `meta`.
          if get_torch_context_manager_or_global_device() == torch.device("meta"):
              return

          # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
          # correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to
          # [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is
          # ok to repurpose this function here.
          target_lang = self.target_lang

          if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
              raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
          elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
              logger.info("By default `target_lang` is set to 'eng'.")
          elif target_lang is not None:
              self.load_adapter(target_lang, force_load=True)

      def freeze_feature_encoder(self):
          """
          Calling this function will disable the gradient computation for the feature encoder so that its parameter will
          not be updated during training.
          """
          self.unispeech.feature_extractor._freeze_parameters()

      def freeze_base_model(self):
          """
          Calling this function will disable the gradient computation for the base model so that its parameters will not
          be updated during training. Only the classification head will be updated.
          """
          for param in self.unispeech.parameters():
              param.requires_grad = False

      @auto_docstring
      def forward(
          self,
          input_values: torch.Tensor | None,
          attention_mask: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          labels: torch.Tensor | None = None,
          **kwargs,
      ) -> tuple | CausalLMOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
              Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
              the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
              All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
              config.vocab_size - 1]`.
          """
          return_dict = return_dict if return_dict is not None else self.config.return_dict

          # Valid label ids are 0 .. vocab_size - 1, so the error must say "<" rather than "<=".
          if labels is not None and labels.max() >= self.config.vocab_size:
              raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

          outputs = self.unispeech(
              input_values,
              attention_mask=attention_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          hidden_states = outputs[0]
          hidden_states = self.dropout(hidden_states)

          logits = self.lm_head(hidden_states)

          loss = None
          if labels is not None:
              # retrieve loss input_lengths from attention_mask (all samples count as attended when absent)
              attention_mask = (
                  attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
              )
              input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

              # assuming that padded tokens are filled with -100
              # when not being attended to
              labels_mask = labels >= 0
              target_lengths = labels_mask.sum(-1)
              flattened_targets = labels.masked_select(labels_mask)

              # ctc_loss doesn't support fp16; it receives (time, batch, vocab) float32 log-probs.
              log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

              # The CTC loss is computed with cuDNN disabled.
              with torch.backends.cudnn.flags(enabled=False):
                  loss = nn.functional.ctc_loss(
                      log_probs,
                      flattened_targets,
                      input_lengths,
                      target_lengths,
                      blank=self.config.pad_token_id,
                      reduction=self.config.ctc_loss_reduction,
                      zero_infinity=self.config.ctc_zero_infinity,
                  )

          if not return_dict:
              output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
              return ((loss,) + output) if loss is not None else output

          return CausalLMOutput(
              loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
          )
  @auto_docstring(
      custom_intro="""
      UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
      SUPERB Keyword Spotting.
      """
  )
  class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          if hasattr(config, "add_adapter") and config.add_adapter:
              raise ValueError(
                  "Sequence classification does not support the use of UniSpeech adapters (config.add_adapter=True)"
              )
          self.unispeech = UniSpeechModel(config)
          num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
          if config.use_weighted_layer_sum:
              # One learnable weight per hidden-state layer, initialized to a uniform average.
              self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
          self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
          self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      def freeze_feature_encoder(self):
          """
          Calling this function will disable the gradient computation for the feature encoder so that its parameter will
          not be updated during training.
          """
          self.unispeech.feature_extractor._freeze_parameters()

      def freeze_base_model(self):
          """
          Calling this function will disable the gradient computation for the base model so that its parameters will not
          be updated during training. Only the classification head will be updated.
          """
          for param in self.unispeech.parameters():
              param.requires_grad = False

      @auto_docstring
      def forward(
          self,
          input_values: torch.Tensor | None,
          attention_mask: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          labels: torch.Tensor | None = None,
          **kwargs,
      ) -> tuple | SequenceClassifierOutput:
          r"""
          input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
              Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
              into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
              (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
              To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
              into a tensor of type `torch.FloatTensor`. See [`UniSpeechProcessor.__call__`] for details.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. A classification loss (Cross-Entropy) is always computed, including when
              `config.num_labels == 1`; this head does not compute a regression (Mean-Square) loss.
          """
          return_dict = return_dict if return_dict is not None else self.config.return_dict
          # The weighted layer sum needs every layer's hidden states.
          output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

          outputs = self.unispeech(
              input_values,
              attention_mask=attention_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          if self.config.use_weighted_layer_sum:
              # Softmax-normalized learnable mix of all hidden-state layers.
              hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
              hidden_states = torch.stack(hidden_states, dim=1)
              norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
              hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
          else:
              hidden_states = outputs[0]

          hidden_states = self.projector(hidden_states)
          if attention_mask is None:
              pooled_output = hidden_states.mean(dim=1)
          else:
              # Mean over attended frames only: padded frames are zeroed, then the sum is divided by
              # the number of attended frames.
              # NOTE: a row with no attended frames divides by zero.
              padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
              expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
              hidden_states[~expand_padding_mask] = 0.0
              pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

          logits = self.classifier(pooled_output)

          loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

          if not return_dict:
              output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
              return ((loss,) + output) if loss is not None else output

          return SequenceClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )


  __all__ = [
      "UniSpeechForCTC",
      "UniSpeechForPreTraining",
      "UniSpeechForSequenceClassification",
      "UniSpeechModel",
      "UniSpeechPreTrainedModel",
  ]