modeling_hubert.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/hubert/modular_hubert.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_hubert.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. import numpy as np
  22. import torch
  23. import torch.nn as nn
  24. from torch.nn import CrossEntropyLoss
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  28. from ...integrations.fsdp import is_fsdp_managed_module
  29. from ...masking_utils import create_bidirectional_mask
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, get_torch_context_manager_or_global_device
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, logging
  36. from .configuration_hubert import HubertConfig
# Module-level logger from transformers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
class HubertPositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding added to the encoder input.

    A grouped ``nn.Conv1d`` over ``hidden_size`` channels (kernel ``num_conv_pos_embeddings``, padding
    ``num_conv_pos_embeddings // 2``), followed by ``HubertSamePadLayer`` and the
    ``feat_extract_activation``. If ``config.conv_pos_batch_norm`` is set, the input is batch-normalised
    before the convolution; otherwise the convolution weight is weight-normalised instead.
    """

    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        self.batch_norm = None
        if config.conv_pos_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.hidden_size)
        else:
            # Prefer the parametrization-based weight_norm when this PyTorch build provides it;
            # otherwise fall back to the legacy `nn.utils.weight_norm` hook.
            weight_norm = nn.utils.weight_norm
            if hasattr(nn.utils.parametrizations, "weight_norm"):
                weight_norm = nn.utils.parametrizations.weight_norm

            if is_deepspeed_zero3_enabled():
                import deepspeed

                # Under ZeRO-3 the conv weight may be partitioned across ranks; gather it so that
                # weight_norm sees the full tensor when splitting it into magnitude and direction.
                with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                    self.conv = weight_norm(self.conv, name="weight", dim=2)
                # The magnitude (g) / direction (v) parameters are exposed under different names
                # depending on which weight_norm implementation was applied above.
                if hasattr(self.conv, "parametrizations"):
                    weight_g = self.conv.parametrizations.weight.original0
                    weight_v = self.conv.parametrizations.weight.original1
                else:
                    weight_g = self.conv.weight_g
                    weight_v = self.conv.weight_v
                # Tell DeepSpeed that this module uses g/v so they are gathered when it runs
                # (presumably needed because they are reached indirectly via weight_norm — confirm).
                deepspeed.zero.register_external_parameter(self, weight_v)
                deepspeed.zero.register_external_parameter(self, weight_g)
            else:
                self.conv = weight_norm(self.conv, name="weight", dim=2)

        # Drops the one surplus output frame produced when the kernel size is even.
        self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        # Conv1d wants channels on axis 1; assumes input is (batch, time, hidden_size) — TODO confirm.
        hidden_states = hidden_states.transpose(1, 2)
        if self.batch_norm is not None:
            hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.conv(hidden_states)
        hidden_states = self.padding(hidden_states)
        hidden_states = self.activation(hidden_states)
        # Restore the input layout.
        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states
  80. class HubertSamePadLayer(nn.Module):
  81. def __init__(self, num_conv_pos_embeddings):
  82. super().__init__()
  83. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  84. def forward(self, hidden_states):
  85. if self.num_pad_remove > 0:
  86. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  87. return hidden_states
  88. class HubertNoLayerNormConvLayer(GradientCheckpointingLayer):
  89. def __init__(self, config, layer_id=0):
  90. super().__init__()
  91. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  92. self.out_conv_dim = config.conv_dim[layer_id]
  93. self.conv = nn.Conv1d(
  94. self.in_conv_dim,
  95. self.out_conv_dim,
  96. kernel_size=config.conv_kernel[layer_id],
  97. stride=config.conv_stride[layer_id],
  98. bias=config.conv_bias,
  99. )
  100. self.activation = ACT2FN[config.feat_extract_activation]
  101. def forward(self, hidden_states):
  102. hidden_states = self.conv(hidden_states)
  103. hidden_states = self.activation(hidden_states)
  104. return hidden_states
  105. class HubertLayerNormConvLayer(GradientCheckpointingLayer):
  106. def __init__(self, config, layer_id=0):
  107. super().__init__()
  108. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  109. self.out_conv_dim = config.conv_dim[layer_id]
  110. self.conv = nn.Conv1d(
  111. self.in_conv_dim,
  112. self.out_conv_dim,
  113. kernel_size=config.conv_kernel[layer_id],
  114. stride=config.conv_stride[layer_id],
  115. bias=config.conv_bias,
  116. )
  117. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  118. self.activation = ACT2FN[config.feat_extract_activation]
  119. def forward(self, hidden_states):
  120. hidden_states = self.conv(hidden_states)
  121. hidden_states = hidden_states.transpose(-2, -1)
  122. hidden_states = self.layer_norm(hidden_states)
  123. hidden_states = hidden_states.transpose(-2, -1)
  124. hidden_states = self.activation(hidden_states)
  125. return hidden_states
  126. class HubertGroupNormConvLayer(GradientCheckpointingLayer):
  127. def __init__(self, config, layer_id=0):
  128. super().__init__()
  129. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  130. self.out_conv_dim = config.conv_dim[layer_id]
  131. self.conv = nn.Conv1d(
  132. self.in_conv_dim,
  133. self.out_conv_dim,
  134. kernel_size=config.conv_kernel[layer_id],
  135. stride=config.conv_stride[layer_id],
  136. bias=config.conv_bias,
  137. )
  138. self.activation = ACT2FN[config.feat_extract_activation]
  139. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  140. def forward(self, hidden_states):
  141. hidden_states = self.conv(hidden_states)
  142. hidden_states = self.layer_norm(hidden_states)
  143. hidden_states = self.activation(hidden_states)
  144. return hidden_states
  145. class HubertFeatureEncoder(nn.Module):
  146. """Construct the features from raw audio waveform"""
  147. def __init__(self, config):
  148. super().__init__()
  149. if config.feat_extract_norm == "group":
  150. conv_layers = [HubertGroupNormConvLayer(config, layer_id=0)] + [
  151. HubertNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  152. ]
  153. elif config.feat_extract_norm == "layer":
  154. conv_layers = [HubertLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
  155. else:
  156. raise ValueError(
  157. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  158. )
  159. self.conv_layers = nn.ModuleList(conv_layers)
  160. self.gradient_checkpointing = False
  161. self._requires_grad = True
  162. def _freeze_parameters(self):
  163. for param in self.parameters():
  164. param.requires_grad = False
  165. self._requires_grad = False
  166. def forward(self, input_values):
  167. hidden_states = input_values[:, None]
  168. # make sure hidden_states require grad for gradient_checkpointing
  169. if self._requires_grad and self.training:
  170. hidden_states.requires_grad = True
  171. for conv_layer in self.conv_layers:
  172. hidden_states = conv_layer(hidden_states)
  173. return hidden_states
  174. class HubertFeatureProjection(nn.Module):
  175. def __init__(self, config):
  176. super().__init__()
  177. self.feat_proj_layer_norm = config.feat_proj_layer_norm
  178. if self.feat_proj_layer_norm:
  179. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  180. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  181. self.dropout = nn.Dropout(config.feat_proj_dropout)
  182. def forward(self, hidden_states):
  183. # non-projected hidden states are needed for quantization
  184. if self.feat_proj_layer_norm:
  185. hidden_states = self.layer_norm(hidden_states)
  186. hidden_states = self.projection(hidden_states)
  187. hidden_states = self.dropout(hidden_states)
  188. return hidden_states
  189. def eager_attention_forward(
  190. module: nn.Module,
  191. query: torch.Tensor,
  192. key: torch.Tensor,
  193. value: torch.Tensor,
  194. attention_mask: torch.Tensor | None,
  195. scaling: float | None = None,
  196. dropout: float = 0.0,
  197. **kwargs: Unpack[TransformersKwargs],
  198. ):
  199. if scaling is None:
  200. scaling = query.size(-1) ** -0.5
  201. # Take the dot product between "query" and "key" to get the raw attention scores.
  202. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  203. if attention_mask is not None:
  204. attn_weights = attn_weights + attention_mask
  205. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  206. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  207. attn_output = torch.matmul(attn_weights, value)
  208. attn_output = attn_output.transpose(1, 2).contiguous()
  209. return attn_output, attn_weights
  210. class HubertAttention(nn.Module):
  211. """Multi-headed attention from 'Attention Is All You Need' paper"""
  212. def __init__(
  213. self,
  214. embed_dim: int,
  215. num_heads: int,
  216. dropout: float = 0.0,
  217. is_decoder: bool = False,
  218. bias: bool = True,
  219. is_causal: bool = False,
  220. config: HubertConfig | None = None,
  221. ):
  222. super().__init__()
  223. self.embed_dim = embed_dim
  224. self.num_heads = num_heads
  225. self.dropout = dropout
  226. self.head_dim = embed_dim // num_heads
  227. self.config = config
  228. if (self.head_dim * num_heads) != self.embed_dim:
  229. raise ValueError(
  230. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  231. f" and `num_heads`: {num_heads})."
  232. )
  233. self.scaling = self.head_dim**-0.5
  234. self.is_decoder = is_decoder
  235. self.is_causal = is_causal
  236. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  237. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  238. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  239. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  240. def forward(
  241. self,
  242. hidden_states: torch.Tensor,
  243. key_value_states: torch.Tensor | None = None,
  244. attention_mask: torch.Tensor | None = None,
  245. output_attentions: bool | None = False,
  246. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  247. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  248. **kwargs: Unpack[FlashAttentionKwargs],
  249. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  250. """Input shape: Batch x Time x Channel"""
  251. # if key_value_states are provided this layer is used as a cross-attention layer
  252. # for the decoder
  253. is_cross_attention = key_value_states is not None
  254. # determine input shapes
  255. input_shape = hidden_states.shape[:-1]
  256. hidden_shape = (*input_shape, -1, self.head_dim)
  257. # get query proj
  258. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  259. current_states = key_value_states if is_cross_attention else hidden_states
  260. kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
  261. key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2)
  262. value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2)
  263. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  264. self.config._attn_implementation, eager_attention_forward
  265. )
  266. attn_output, attn_weights = attention_interface(
  267. self,
  268. query_states,
  269. key_states,
  270. value_states,
  271. attention_mask,
  272. dropout=0.0 if not self.training else self.dropout,
  273. scaling=self.scaling,
  274. output_attentions=output_attentions,
  275. **kwargs,
  276. )
  277. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  278. attn_output = self.out_proj(attn_output)
  279. return attn_output, attn_weights, None
  280. class HubertFeedForward(nn.Module):
  281. def __init__(self, config):
  282. super().__init__()
  283. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  284. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  285. if isinstance(config.hidden_act, str):
  286. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  287. else:
  288. self.intermediate_act_fn = config.hidden_act
  289. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  290. self.output_dropout = nn.Dropout(config.hidden_dropout)
  291. def forward(self, hidden_states):
  292. hidden_states = self.intermediate_dense(hidden_states)
  293. hidden_states = self.intermediate_act_fn(hidden_states)
  294. hidden_states = self.intermediate_dropout(hidden_states)
  295. hidden_states = self.output_dense(hidden_states)
  296. hidden_states = self.output_dropout(hidden_states)
  297. return hidden_states
  298. class HubertEncoderLayer(GradientCheckpointingLayer):
  299. def __init__(self, config):
  300. super().__init__()
  301. self.attention = HubertAttention(
  302. embed_dim=config.hidden_size,
  303. num_heads=config.num_attention_heads,
  304. dropout=config.attention_dropout,
  305. is_decoder=False,
  306. config=config,
  307. )
  308. self.dropout = nn.Dropout(config.hidden_dropout)
  309. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  310. self.feed_forward = HubertFeedForward(config)
  311. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  312. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  313. attn_residual = hidden_states
  314. hidden_states, attn_weights, _ = self.attention(
  315. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  316. )
  317. hidden_states = self.dropout(hidden_states)
  318. hidden_states = attn_residual + hidden_states
  319. hidden_states = self.layer_norm(hidden_states)
  320. hidden_states = hidden_states + self.feed_forward(hidden_states)
  321. hidden_states = self.final_layer_norm(hidden_states)
  322. outputs = (hidden_states,)
  323. if output_attentions:
  324. outputs += (attn_weights,)
  325. return outputs
  326. class HubertEncoder(nn.Module):
  327. def __init__(self, config):
  328. super().__init__()
  329. self.config = config
  330. self.pos_conv_embed = HubertPositionalConvEmbedding(config)
  331. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  332. self.dropout = nn.Dropout(config.hidden_dropout)
  333. self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  334. self.gradient_checkpointing = False
  335. def forward(
  336. self,
  337. hidden_states: torch.tensor,
  338. attention_mask: torch.Tensor | None = None,
  339. output_attentions: bool = False,
  340. output_hidden_states: bool = False,
  341. return_dict: bool = True,
  342. ):
  343. all_hidden_states = () if output_hidden_states else None
  344. all_self_attentions = () if output_attentions else None
  345. if attention_mask is not None:
  346. # make sure padded tokens output 0
  347. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  348. hidden_states[~expand_attention_mask] = 0
  349. attention_mask = create_bidirectional_mask(
  350. config=self.config,
  351. inputs_embeds=hidden_states,
  352. attention_mask=attention_mask,
  353. )
  354. position_embeddings = self.pos_conv_embed(hidden_states)
  355. hidden_states = hidden_states + position_embeddings.to(hidden_states.device)
  356. hidden_states = self.layer_norm(hidden_states)
  357. hidden_states = self.dropout(hidden_states)
  358. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  359. for layer in self.layers:
  360. if output_hidden_states:
  361. all_hidden_states = all_hidden_states + (hidden_states,)
  362. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  363. dropout_probability = torch.rand([])
  364. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  365. if not skip_the_layer or synced_gpus:
  366. # under fsdp or deepspeed zero3 all gpus must run in sync
  367. layer_outputs = layer(
  368. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  369. )
  370. hidden_states = layer_outputs[0]
  371. if skip_the_layer:
  372. layer_outputs = (None, None)
  373. if output_attentions:
  374. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  375. if output_hidden_states:
  376. all_hidden_states = all_hidden_states + (hidden_states,)
  377. if not return_dict:
  378. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  379. return BaseModelOutput(
  380. last_hidden_state=hidden_states,
  381. hidden_states=all_hidden_states,
  382. attentions=all_self_attentions,
  383. )
  384. class HubertAttnAdapterLayer(nn.Module):
  385. def __init__(self, config):
  386. """
  387. Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
  388. up training throughput.
  389. """
  390. super().__init__()
  391. self.input_dim = config.adapter_attn_dim
  392. self.hidden_dim = config.hidden_size
  393. self.norm = nn.LayerNorm(self.hidden_dim)
  394. self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
  395. self.act_fn = nn.ReLU()
  396. self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
  397. def forward(self, hidden_states: torch.FloatTensor):
  398. hidden_states = self.norm(hidden_states)
  399. hidden_states = self.linear_1(hidden_states)
  400. hidden_states = self.act_fn(hidden_states)
  401. hidden_states = self.linear_2(hidden_states)
  402. return hidden_states
  403. class HubertEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
  404. def __init__(self, config):
  405. super().__init__()
  406. self.attention = HubertAttention(
  407. embed_dim=config.hidden_size,
  408. num_heads=config.num_attention_heads,
  409. dropout=config.attention_dropout,
  410. is_decoder=False,
  411. config=config,
  412. )
  413. self.dropout = nn.Dropout(config.hidden_dropout)
  414. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  415. self.feed_forward = HubertFeedForward(config)
  416. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  417. if getattr(config, "adapter_attn_dim", None) is not None:
  418. self.adapter_layer = HubertAttnAdapterLayer(config)
  419. else:
  420. self.adapter_layer = None
  421. def forward(
  422. self,
  423. hidden_states: torch.Tensor,
  424. attention_mask: torch.Tensor | None = None,
  425. output_attentions: bool = False,
  426. ):
  427. attn_residual = hidden_states
  428. hidden_states = self.layer_norm(hidden_states)
  429. hidden_states, attn_weights, _ = self.attention(
  430. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  431. )
  432. hidden_states = self.dropout(hidden_states)
  433. hidden_states = attn_residual + hidden_states
  434. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  435. if self.adapter_layer is not None:
  436. hidden_states = hidden_states + self.adapter_layer(hidden_states)
  437. outputs = (hidden_states,)
  438. if output_attentions:
  439. outputs += (attn_weights,)
  440. return outputs
  441. class HubertEncoderStableLayerNorm(nn.Module):
  442. def __init__(self, config):
  443. super().__init__()
  444. self.config = config
  445. self.pos_conv_embed = HubertPositionalConvEmbedding(config)
  446. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  447. self.dropout = nn.Dropout(config.hidden_dropout)
  448. self.layers = nn.ModuleList(
  449. [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
  450. )
  451. self.gradient_checkpointing = False
  452. def forward(
  453. self,
  454. hidden_states,
  455. attention_mask=None,
  456. output_attentions=False,
  457. output_hidden_states=False,
  458. return_dict=True,
  459. ):
  460. all_hidden_states = () if output_hidden_states else None
  461. all_self_attentions = () if output_attentions else None
  462. if attention_mask is not None:
  463. # make sure padded tokens output 0
  464. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  465. hidden_states[~expand_attention_mask] = 0
  466. attention_mask = create_bidirectional_mask(
  467. config=self.config,
  468. inputs_embeds=hidden_states,
  469. attention_mask=attention_mask,
  470. )
  471. position_embeddings = self.pos_conv_embed(hidden_states)
  472. hidden_states = hidden_states + position_embeddings
  473. hidden_states = self.dropout(hidden_states)
  474. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  475. for layer in self.layers:
  476. if output_hidden_states:
  477. all_hidden_states = all_hidden_states + (hidden_states,)
  478. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  479. dropout_probability = torch.rand([])
  480. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  481. if not skip_the_layer or synced_gpus:
  482. # under fsdp or deepspeed zero3 all gpus must run in sync
  483. # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
  484. layer_outputs = layer(
  485. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  486. )
  487. hidden_states = layer_outputs[0]
  488. if skip_the_layer:
  489. layer_outputs = (None, None)
  490. if output_attentions:
  491. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  492. hidden_states = self.layer_norm(hidden_states)
  493. if output_hidden_states:
  494. all_hidden_states = all_hidden_states + (hidden_states,)
  495. if not return_dict:
  496. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  497. return BaseModelOutput(
  498. last_hidden_state=hidden_states,
  499. hidden_states=all_hidden_states,
  500. attentions=all_self_attentions,
  501. )
@auto_docstring
class HubertPreTrainedModel(PreTrainedModel):
    # Shared base for HuBERT models: config class, weight initialisation, and helpers that map
    # sample-level lengths / masks to the frame rate produced by the convolutional feature encoder.
    config: HubertConfig
    base_model_prefix = "hubert"
    main_input_name = "input_values"
    input_modalities = "audio"
    _no_split_modules = ["HubertEncoderLayer", "ParametrizedConv1d"]
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights

        - ``nn.Linear``: weight ~ N(0, ``config.initializer_range``), bias zeroed.
        - LayerNorm / GroupNorm / BatchNorm1d: weight 1, bias 0; running statistics reset when present.
        - ``nn.Conv1d``: Kaiming-normal weight (gathered first under DeepSpeed ZeRO-3), bias zeroed.
        - ``HubertModel``: ``masked_spec_embed`` drawn uniformly, if the module has it.
        - ``HubertForSequenceClassification``: ``layer_weights`` set to ``1 / (num_hidden_layers + 1)``.
        """
        if isinstance(module, nn.Linear):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            # Only BatchNorm-style modules carry running statistics.
            if getattr(module, "running_mean", None) is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # `weight_v` / `weight_g` exist when the legacy `nn.utils.weight_norm` hook was applied
                # (see the positional conv embedding); gather those instead of the derived weight.
                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        init.kaiming_normal_(module.weight)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        init.kaiming_normal_(module.weight)
            else:
                init.kaiming_normal_(module.weight)

            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, HubertModel):
            if hasattr(module, "masked_spec_embed"):
                init.uniform_(module.masked_spec_embed)
        elif isinstance(module, HubertForSequenceClassification):
            if hasattr(module, "layer_weights"):
                # Uniform weighting over the hidden-state outputs (presumably embeddings + one per layer).
                init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
        """
        Computes the output length of the convolutional layers

        Applies ``floor((length - kernel) / stride) + 1`` once per feature-extractor layer, using
        ``config.conv_kernel`` and ``config.conv_stride`` (the feature convolutions are unpadded).
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """Turn a sample-level ``attention_mask`` into a bool frame-level mask of ``feature_vector_length``.

        Only the per-row count of nonzero entries is used, i.e. the mask is treated as right-padded.
        Frames ``[0, output_length)`` of each row are True, the rest False.

        NOTE(review): if a row's output length is 0, ``output_lengths - 1`` is -1, which marks the last
        frame and makes the whole row True after the flip/cumsum — confirm inputs can't be that short.
        """
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask
  567. def _compute_mask_indices(
  568. shape: tuple[int, int],
  569. mask_prob: float,
  570. mask_length: int,
  571. attention_mask: torch.LongTensor | None = None,
  572. min_masks: int = 0,
  573. ) -> np.ndarray:
  574. """
  575. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  576. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  577. CPU as part of the preprocessing during training.
  578. Args:
  579. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  580. the first element is the batch size and the second element is the length of the axis to span.
  581. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  582. independently generated mask spans of length `mask_length` is computed by
  583. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  584. actual percentage will be smaller.
  585. mask_length: size of the mask
  586. min_masks: minimum number of masked spans
  587. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  588. each batch dimension.
  589. """
  590. batch_size, sequence_length = shape
  591. if mask_length < 1:
  592. raise ValueError("`mask_length` has to be bigger than 0.")
  593. if mask_length > sequence_length:
  594. raise ValueError(
  595. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  596. f" and `sequence_length`: {sequence_length}`"
  597. )
  598. # epsilon is used for probabilistic rounding
  599. epsilon = np.random.rand(1).item()
  600. def compute_num_masked_span(input_length):
  601. """Given input length, compute how many spans should be masked"""
  602. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  603. num_masked_span = max(num_masked_span, min_masks)
  604. # make sure num masked span <= sequence_length
  605. if num_masked_span * mask_length > sequence_length:
  606. num_masked_span = sequence_length // mask_length
  607. # make sure num_masked span is also <= input_length - (mask_length - 1)
  608. if input_length - (mask_length - 1) < num_masked_span:
  609. num_masked_span = max(input_length - (mask_length - 1), 0)
  610. return num_masked_span
  611. # compute number of masked spans in batch
  612. input_lengths = (
  613. attention_mask.detach().sum(-1).tolist()
  614. if attention_mask is not None
  615. else [sequence_length for _ in range(batch_size)]
  616. )
  617. # SpecAugment mask to fill
  618. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  619. spec_aug_mask_idxs = []
  620. max_num_masked_span = compute_num_masked_span(sequence_length)
  621. if max_num_masked_span == 0:
  622. return spec_aug_mask
  623. for input_length in input_lengths:
  624. # compute num of masked spans for this input
  625. num_masked_span = compute_num_masked_span(input_length)
  626. # get random indices to mask
  627. spec_aug_mask_idx = np.random.choice(
  628. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  629. )
  630. # pick first sampled index that will serve as a dummy index to pad vector
  631. # to ensure same dimension for all batches due to probabilistic rounding
  632. # Picking first sample just pads those vectors twice.
  633. if len(spec_aug_mask_idx) == 0:
  634. # this case can only happen if `input_length` is strictly smaller then
  635. # `sequence_length` in which case the last token has to be a padding
  636. # token which we can use as a dummy mask id
  637. dummy_mask_idx = sequence_length - 1
  638. else:
  639. dummy_mask_idx = spec_aug_mask_idx[0]
  640. spec_aug_mask_idx = np.concatenate(
  641. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  642. )
  643. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  644. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  645. # expand masked indices to masked spans
  646. spec_aug_mask_idxs = np.broadcast_to(
  647. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  648. )
  649. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  650. # add offset to the starting indexes so that indexes now create a span
  651. offsets = np.arange(mask_length)[None, None, :]
  652. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  653. batch_size, max_num_masked_span * mask_length
  654. )
  655. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  656. # ensure that we cannot have indices larger than sequence_length
  657. if spec_aug_mask_idxs.max() > sequence_length - 1:
  658. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  659. # scatter indices to mask
  660. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  661. return spec_aug_mask
  662. @auto_docstring
  663. class HubertModel(HubertPreTrainedModel):
  664. def __init__(self, config: HubertConfig):
  665. super().__init__(config)
  666. self.config = config
  667. self.feature_extractor = HubertFeatureEncoder(config)
  668. self.feature_projection = HubertFeatureProjection(config)
  669. # model only needs masking vector if mask prob is > 0.0
  670. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  671. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  672. if config.do_stable_layer_norm:
  673. self.encoder = HubertEncoderStableLayerNorm(config)
  674. else:
  675. self.encoder = HubertEncoder(config)
  676. # Initialize weights and apply final processing
  677. self.post_init()
  678. def _mask_hidden_states(
  679. self,
  680. hidden_states: torch.FloatTensor,
  681. mask_time_indices: torch.FloatTensor | None = None,
  682. attention_mask: torch.LongTensor | None = None,
  683. ):
  684. """
  685. Masks extracted features along time axis and/or along feature axis according to
  686. [SpecAugment](https://huggingface.co/papers/1904.08779).
  687. """
  688. # `config.apply_spec_augment` can set masking to False
  689. if not getattr(self.config, "apply_spec_augment", True):
  690. return hidden_states
  691. # generate indices & apply SpecAugment along time axis
  692. batch_size, sequence_length, hidden_size = hidden_states.size()
  693. if mask_time_indices is not None:
  694. # apply SpecAugment along time axis with given mask_time_indices
  695. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  696. elif self.config.mask_time_prob > 0 and self.training:
  697. mask_time_indices = _compute_mask_indices(
  698. (batch_size, sequence_length),
  699. mask_prob=self.config.mask_time_prob,
  700. mask_length=self.config.mask_time_length,
  701. attention_mask=attention_mask,
  702. min_masks=self.config.mask_time_min_masks,
  703. )
  704. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  705. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  706. if self.config.mask_feature_prob > 0 and self.training:
  707. # generate indices & apply SpecAugment along feature axis
  708. mask_feature_indices = _compute_mask_indices(
  709. (batch_size, hidden_size),
  710. mask_prob=self.config.mask_feature_prob,
  711. mask_length=self.config.mask_feature_length,
  712. min_masks=self.config.mask_feature_min_masks,
  713. )
  714. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  715. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  716. hidden_states[mask_feature_indices] = 0
  717. return hidden_states
  718. @auto_docstring
  719. def forward(
  720. self,
  721. input_values: torch.Tensor | None,
  722. attention_mask: torch.Tensor | None = None,
  723. mask_time_indices: torch.FloatTensor | None = None,
  724. output_attentions: bool | None = None,
  725. output_hidden_states: bool | None = None,
  726. return_dict: bool | None = None,
  727. **kwargs,
  728. ) -> tuple | BaseModelOutput:
  729. r"""
  730. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  731. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  732. masked extracted features in *config.proj_codevector_dim* space.
  733. Example:
  734. ```python
  735. >>> from transformers import AutoProcessor, HubertModel
  736. >>> from datasets import load_dataset
  737. >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
  738. >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
  739. >>> def map_to_array(example):
  740. ... example["speech"] = example["audio"]["array"]
  741. ... return example
  742. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  743. >>> ds = ds.map(map_to_array)
  744. >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
  745. >>> hidden_states = model(input_values).last_hidden_state
  746. ```"""
  747. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  748. output_hidden_states = (
  749. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  750. )
  751. return_dict = return_dict if return_dict is not None else self.config.return_dict
  752. extract_features = self.feature_extractor(input_values)
  753. extract_features = extract_features.transpose(1, 2)
  754. if attention_mask is not None:
  755. # compute reduced attention_mask corresponding to feature vectors
  756. attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
  757. hidden_states = self.feature_projection(extract_features)
  758. hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
  759. encoder_outputs = self.encoder(
  760. hidden_states,
  761. attention_mask=attention_mask,
  762. output_attentions=output_attentions,
  763. output_hidden_states=output_hidden_states,
  764. return_dict=return_dict,
  765. )
  766. hidden_states = encoder_outputs[0]
  767. if not return_dict:
  768. return (hidden_states,) + encoder_outputs[1:]
  769. return BaseModelOutput(
  770. last_hidden_state=hidden_states,
  771. hidden_states=encoder_outputs.hidden_states,
  772. attentions=encoder_outputs.attentions,
  773. )
# Position in the base model's outputs right after `last_hidden_state` (index 0). The heads below use it to slice
# off the optional extras (hidden states / attentions) and, when hidden states were requested, to index them.
_HIDDEN_STATES_START_POSITION: int = 1
  775. @auto_docstring(
  776. custom_intro="""
  777. Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
  778. """
  779. )
  780. class HubertForCTC(HubertPreTrainedModel):
  781. def __init__(self, config, target_lang: str | None = None):
  782. r"""
  783. target_lang (`str`, *optional*):
  784. Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
  785. adapter.<lang>.bin. Only relevant when using an instance of [`HubertForCTC`] with adapters. Uses 'eng' by
  786. default.
  787. """
  788. super().__init__(config)
  789. self.hubert = HubertModel(config)
  790. self.dropout = nn.Dropout(config.final_dropout)
  791. self.target_lang = target_lang
  792. if config.vocab_size is None:
  793. raise ValueError(
  794. f"You are trying to instantiate {self.__class__} with a configuration that "
  795. "does not define the vocabulary size of the language model head. Please "
  796. "instantiate the model as follows: `HubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  797. "or define `vocab_size` of your model's configuration."
  798. )
  799. output_hidden_size = (
  800. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  801. )
  802. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  803. # Initialize weights and apply final processing
  804. self.post_init()
  805. def tie_weights(self, **kwargs):
  806. """
  807. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  808. passing `target_lang=...` to `from_pretrained(...)`.
  809. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  810. """
  811. if get_torch_context_manager_or_global_device() == torch.device("meta"):
  812. return
  813. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  814. # correctly load adapter layers for Hubert so that we do not have to introduce a new API to
  815. # [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is
  816. # ok to repurpose this function here.
  817. target_lang = self.target_lang
  818. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  819. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  820. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  821. logger.info("By default `target_lang` is set to 'eng'.")
  822. elif target_lang is not None:
  823. self.load_adapter(target_lang, force_load=True)
  824. def freeze_feature_encoder(self):
  825. """
  826. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  827. not be updated during training.
  828. """
  829. self.hubert.feature_extractor._freeze_parameters()
  830. def freeze_base_model(self):
  831. """
  832. Calling this function will disable the gradient computation for the base model so that its parameters will not
  833. be updated during training. Only the classification head will be updated.
  834. """
  835. for param in self.hubert.parameters():
  836. param.requires_grad = False
  837. @auto_docstring
  838. def forward(
  839. self,
  840. input_values: torch.Tensor | None,
  841. attention_mask: torch.Tensor | None = None,
  842. output_attentions: bool | None = None,
  843. output_hidden_states: bool | None = None,
  844. return_dict: bool | None = None,
  845. labels: torch.Tensor | None = None,
  846. **kwargs,
  847. ) -> tuple | CausalLMOutput:
  848. r"""
  849. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  850. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  851. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  852. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  853. config.vocab_size - 1]`.
  854. """
  855. return_dict = return_dict if return_dict is not None else self.config.return_dict
  856. if labels is not None and labels.max() >= self.config.vocab_size:
  857. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  858. outputs = self.hubert(
  859. input_values,
  860. attention_mask=attention_mask,
  861. output_attentions=output_attentions,
  862. output_hidden_states=output_hidden_states,
  863. return_dict=return_dict,
  864. )
  865. hidden_states = outputs[0]
  866. hidden_states = self.dropout(hidden_states)
  867. logits = self.lm_head(hidden_states)
  868. loss = None
  869. if labels is not None:
  870. # retrieve loss input_lengths from attention_mask
  871. attention_mask = (
  872. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  873. )
  874. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  875. # assuming that padded tokens are filled with -100
  876. # when not being attended to
  877. labels_mask = labels >= 0
  878. target_lengths = labels_mask.sum(-1)
  879. flattened_targets = labels.masked_select(labels_mask)
  880. # ctc_loss doesn't support fp16
  881. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  882. with torch.backends.cudnn.flags(enabled=False):
  883. loss = nn.functional.ctc_loss(
  884. log_probs,
  885. flattened_targets,
  886. input_lengths,
  887. target_lengths,
  888. blank=self.config.pad_token_id,
  889. reduction=self.config.ctc_loss_reduction,
  890. zero_infinity=self.config.ctc_zero_infinity,
  891. )
  892. if not return_dict:
  893. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  894. return ((loss,) + output) if loss is not None else output
  895. return CausalLMOutput(
  896. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  897. )
  898. @auto_docstring(
  899. custom_intro="""
  900. Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  901. SUPERB Keyword Spotting.
  902. """
  903. )
  904. class HubertForSequenceClassification(HubertPreTrainedModel):
  905. def __init__(self, config):
  906. super().__init__(config)
  907. if hasattr(config, "add_adapter") and config.add_adapter:
  908. raise ValueError(
  909. "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
  910. )
  911. self.hubert = HubertModel(config)
  912. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  913. if config.use_weighted_layer_sum:
  914. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  915. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  916. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  917. # Initialize weights and apply final processing
  918. self.post_init()
  919. def freeze_feature_encoder(self):
  920. """
  921. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  922. not be updated during training.
  923. """
  924. self.hubert.feature_extractor._freeze_parameters()
  925. def freeze_base_model(self):
  926. """
  927. Calling this function will disable the gradient computation for the base model so that its parameters will not
  928. be updated during training. Only the classification head will be updated.
  929. """
  930. for param in self.hubert.parameters():
  931. param.requires_grad = False
  932. @auto_docstring
  933. def forward(
  934. self,
  935. input_values: torch.Tensor | None,
  936. attention_mask: torch.Tensor | None = None,
  937. output_attentions: bool | None = None,
  938. output_hidden_states: bool | None = None,
  939. return_dict: bool | None = None,
  940. labels: torch.Tensor | None = None,
  941. **kwargs,
  942. ) -> tuple | SequenceClassifierOutput:
  943. r"""
  944. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  945. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  946. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  947. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  948. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  949. into a tensor of type `torch.FloatTensor`. See [`HubertProcessor.__call__`] for details.
  950. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  951. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  952. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  953. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  954. """
  955. return_dict = return_dict if return_dict is not None else self.config.return_dict
  956. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  957. outputs = self.hubert(
  958. input_values,
  959. attention_mask=attention_mask,
  960. output_attentions=output_attentions,
  961. output_hidden_states=output_hidden_states,
  962. return_dict=return_dict,
  963. )
  964. if self.config.use_weighted_layer_sum:
  965. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  966. hidden_states = torch.stack(hidden_states, dim=1)
  967. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  968. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  969. else:
  970. hidden_states = outputs[0]
  971. hidden_states = self.projector(hidden_states)
  972. if attention_mask is None:
  973. pooled_output = hidden_states.mean(dim=1)
  974. else:
  975. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  976. expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  977. hidden_states[~expand_padding_mask] = 0.0
  978. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  979. logits = self.classifier(pooled_output)
  980. loss = None
  981. if labels is not None:
  982. loss_fct = CrossEntropyLoss()
  983. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  984. if not return_dict:
  985. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  986. return ((loss,) + output) if loss is not None else output
  987. return SequenceClassifierOutput(
  988. loss=loss,
  989. logits=logits,
  990. hidden_states=outputs.hidden_states,
  991. attentions=outputs.attentions,
  992. )
# Public names of this module (what `from ... import *` exposes).
__all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"]