modeling_sew.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/sew/modular_sew.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_sew.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. import numpy as np
  23. import torch
  24. from torch import nn
  25. from torch.nn import CrossEntropyLoss
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  29. from ...integrations.fsdp import is_fsdp_managed_module
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, get_torch_context_manager_or_global_device
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, logging
  36. from ...utils.generic import is_flash_attention_requested
  37. from .configuration_sew import SEWConfig
# Module-level logger obtained from the library's `logging` utilities (imported from `...utils`).
logger = logging.get_logger(__name__)
  39. class SEWNoLayerNormConvLayer(GradientCheckpointingLayer):
  40. def __init__(self, config, layer_id=0):
  41. super().__init__()
  42. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  43. self.out_conv_dim = config.conv_dim[layer_id]
  44. self.conv = nn.Conv1d(
  45. self.in_conv_dim,
  46. self.out_conv_dim,
  47. kernel_size=config.conv_kernel[layer_id],
  48. stride=config.conv_stride[layer_id],
  49. bias=config.conv_bias,
  50. )
  51. self.activation = ACT2FN[config.feat_extract_activation]
  52. def forward(self, hidden_states):
  53. hidden_states = self.conv(hidden_states)
  54. hidden_states = self.activation(hidden_states)
  55. return hidden_states
  56. class SEWLayerNormConvLayer(GradientCheckpointingLayer):
  57. def __init__(self, config, layer_id=0):
  58. super().__init__()
  59. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  60. self.out_conv_dim = config.conv_dim[layer_id]
  61. self.conv = nn.Conv1d(
  62. self.in_conv_dim,
  63. self.out_conv_dim,
  64. kernel_size=config.conv_kernel[layer_id],
  65. stride=config.conv_stride[layer_id],
  66. bias=config.conv_bias,
  67. )
  68. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  69. self.activation = ACT2FN[config.feat_extract_activation]
  70. def forward(self, hidden_states):
  71. hidden_states = self.conv(hidden_states)
  72. hidden_states = hidden_states.transpose(-2, -1)
  73. hidden_states = self.layer_norm(hidden_states)
  74. hidden_states = hidden_states.transpose(-2, -1)
  75. hidden_states = self.activation(hidden_states)
  76. return hidden_states
  77. class SEWGroupNormConvLayer(GradientCheckpointingLayer):
  78. def __init__(self, config, layer_id=0):
  79. super().__init__()
  80. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  81. self.out_conv_dim = config.conv_dim[layer_id]
  82. self.conv = nn.Conv1d(
  83. self.in_conv_dim,
  84. self.out_conv_dim,
  85. kernel_size=config.conv_kernel[layer_id],
  86. stride=config.conv_stride[layer_id],
  87. bias=config.conv_bias,
  88. )
  89. self.activation = ACT2FN[config.feat_extract_activation]
  90. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  91. def forward(self, hidden_states):
  92. hidden_states = self.conv(hidden_states)
  93. hidden_states = self.layer_norm(hidden_states)
  94. hidden_states = self.activation(hidden_states)
  95. return hidden_states
  96. class SEWPositionalConvEmbedding(nn.Module):
  97. def __init__(self, config):
  98. super().__init__()
  99. self.conv = nn.Conv1d(
  100. config.hidden_size,
  101. config.hidden_size,
  102. kernel_size=config.num_conv_pos_embeddings,
  103. padding=config.num_conv_pos_embeddings // 2,
  104. groups=config.num_conv_pos_embedding_groups,
  105. stride=config.squeeze_factor,
  106. )
  107. weight_norm = nn.utils.weight_norm
  108. if hasattr(nn.utils.parametrizations, "weight_norm"):
  109. weight_norm = nn.utils.parametrizations.weight_norm
  110. if is_deepspeed_zero3_enabled():
  111. import deepspeed
  112. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  113. self.conv = weight_norm(self.conv, name="weight", dim=2)
  114. if hasattr(self.conv, "parametrizations"):
  115. weight_g = self.conv.parametrizations.weight.original0
  116. weight_v = self.conv.parametrizations.weight.original1
  117. else:
  118. weight_g = self.conv.weight_g
  119. weight_v = self.conv.weight_v
  120. deepspeed.zero.register_external_parameter(self, weight_v)
  121. deepspeed.zero.register_external_parameter(self, weight_g)
  122. else:
  123. self.conv = weight_norm(self.conv, name="weight", dim=2)
  124. self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
  125. self.activation = ACT2FN[config.feat_extract_activation]
  126. def forward(self, hidden_states):
  127. hidden_states = self.conv(hidden_states)
  128. hidden_states = self.padding(hidden_states)
  129. hidden_states = self.activation(hidden_states)
  130. return hidden_states
  131. class SEWSamePadLayer(nn.Module):
  132. def __init__(self, num_conv_pos_embeddings):
  133. super().__init__()
  134. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  135. def forward(self, hidden_states):
  136. if self.num_pad_remove > 0:
  137. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  138. return hidden_states
  139. class SEWUpsampling(nn.Module):
  140. def __init__(self, config):
  141. super().__init__()
  142. self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
  143. self.activation = ACT2FN[config.feat_extract_activation]
  144. self.squeeze_factor = config.squeeze_factor
  145. def forward(self, hidden_states):
  146. hidden_states = self.projection(hidden_states)
  147. hidden_states = self.activation(hidden_states)
  148. if self.squeeze_factor > 1:
  149. # transform embedding channels to sequence length
  150. bsz, src_len, src_embed_dim = hidden_states.size()
  151. tgt_len = src_len * self.squeeze_factor
  152. tgt_embed_dim = src_embed_dim // self.squeeze_factor
  153. hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
  154. hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
  155. return hidden_states
  156. class SEWFeatureEncoder(nn.Module):
  157. """Construct the features from raw audio waveform"""
  158. def __init__(self, config):
  159. super().__init__()
  160. if config.feat_extract_norm == "group":
  161. conv_layers = [SEWGroupNormConvLayer(config, layer_id=0)] + [
  162. SEWNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  163. ]
  164. elif config.feat_extract_norm == "layer":
  165. conv_layers = [SEWLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
  166. else:
  167. raise ValueError(
  168. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  169. )
  170. self.conv_layers = nn.ModuleList(conv_layers)
  171. self.gradient_checkpointing = False
  172. self._requires_grad = True
  173. def _freeze_parameters(self):
  174. for param in self.parameters():
  175. param.requires_grad = False
  176. self._requires_grad = False
  177. def forward(self, input_values):
  178. hidden_states = input_values[:, None]
  179. # make sure hidden_states require grad for gradient_checkpointing
  180. if self._requires_grad and self.training:
  181. hidden_states.requires_grad = True
  182. for conv_layer in self.conv_layers:
  183. hidden_states = conv_layer(hidden_states)
  184. return hidden_states
  185. def eager_attention_forward(
  186. module: nn.Module,
  187. query: torch.Tensor,
  188. key: torch.Tensor,
  189. value: torch.Tensor,
  190. attention_mask: torch.Tensor | None,
  191. scaling: float | None = None,
  192. dropout: float = 0.0,
  193. **kwargs: Unpack[TransformersKwargs],
  194. ):
  195. if scaling is None:
  196. scaling = query.size(-1) ** -0.5
  197. # Take the dot product between "query" and "key" to get the raw attention scores.
  198. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  199. if attention_mask is not None:
  200. attn_weights = attn_weights + attention_mask
  201. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  202. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  203. attn_output = torch.matmul(attn_weights, value)
  204. attn_output = attn_output.transpose(1, 2).contiguous()
  205. return attn_output, attn_weights
  206. class SEWAttention(nn.Module):
  207. """Multi-headed attention from 'Attention Is All You Need' paper"""
  208. def __init__(
  209. self,
  210. embed_dim: int,
  211. num_heads: int,
  212. dropout: float = 0.0,
  213. is_decoder: bool = False,
  214. bias: bool = True,
  215. is_causal: bool = False,
  216. config: SEWConfig | None = None,
  217. ):
  218. super().__init__()
  219. self.embed_dim = embed_dim
  220. self.num_heads = num_heads
  221. self.dropout = dropout
  222. self.head_dim = embed_dim // num_heads
  223. self.config = config
  224. if (self.head_dim * num_heads) != self.embed_dim:
  225. raise ValueError(
  226. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  227. f" and `num_heads`: {num_heads})."
  228. )
  229. self.scaling = self.head_dim**-0.5
  230. self.is_decoder = is_decoder
  231. self.is_causal = is_causal
  232. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  233. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  234. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  235. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  236. def forward(
  237. self,
  238. hidden_states: torch.Tensor,
  239. key_value_states: torch.Tensor | None = None,
  240. attention_mask: torch.Tensor | None = None,
  241. output_attentions: bool | None = False,
  242. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  243. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  244. **kwargs: Unpack[FlashAttentionKwargs],
  245. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  246. """Input shape: Batch x Time x Channel"""
  247. # if key_value_states are provided this layer is used as a cross-attention layer
  248. # for the decoder
  249. is_cross_attention = key_value_states is not None
  250. # determine input shapes
  251. input_shape = hidden_states.shape[:-1]
  252. hidden_shape = (*input_shape, -1, self.head_dim)
  253. # get query proj
  254. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  255. current_states = key_value_states if is_cross_attention else hidden_states
  256. kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
  257. key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2)
  258. value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2)
  259. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  260. self.config._attn_implementation, eager_attention_forward
  261. )
  262. attn_output, attn_weights = attention_interface(
  263. self,
  264. query_states,
  265. key_states,
  266. value_states,
  267. attention_mask,
  268. dropout=0.0 if not self.training else self.dropout,
  269. scaling=self.scaling,
  270. output_attentions=output_attentions,
  271. **kwargs,
  272. )
  273. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  274. attn_output = self.out_proj(attn_output)
  275. return attn_output, attn_weights, None
  276. class SEWFeedForward(nn.Module):
  277. def __init__(self, config):
  278. super().__init__()
  279. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  280. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  281. if isinstance(config.hidden_act, str):
  282. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  283. else:
  284. self.intermediate_act_fn = config.hidden_act
  285. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  286. self.output_dropout = nn.Dropout(config.hidden_dropout)
  287. def forward(self, hidden_states):
  288. hidden_states = self.intermediate_dense(hidden_states)
  289. hidden_states = self.intermediate_act_fn(hidden_states)
  290. hidden_states = self.intermediate_dropout(hidden_states)
  291. hidden_states = self.output_dense(hidden_states)
  292. hidden_states = self.output_dropout(hidden_states)
  293. return hidden_states
  294. class SEWEncoderLayer(GradientCheckpointingLayer):
  295. def __init__(self, config):
  296. super().__init__()
  297. self.attention = SEWAttention(
  298. embed_dim=config.hidden_size,
  299. num_heads=config.num_attention_heads,
  300. dropout=config.attention_dropout,
  301. is_decoder=False,
  302. config=config,
  303. )
  304. self.dropout = nn.Dropout(config.hidden_dropout)
  305. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  306. self.feed_forward = SEWFeedForward(config)
  307. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  308. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  309. attn_residual = hidden_states
  310. hidden_states, attn_weights, _ = self.attention(
  311. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  312. )
  313. hidden_states = self.dropout(hidden_states)
  314. hidden_states = attn_residual + hidden_states
  315. hidden_states = self.layer_norm(hidden_states)
  316. hidden_states = hidden_states + self.feed_forward(hidden_states)
  317. hidden_states = self.final_layer_norm(hidden_states)
  318. outputs = (hidden_states,)
  319. if output_attentions:
  320. outputs += (attn_weights,)
  321. return outputs
  322. class SEWEncoder(nn.Module):
  323. def __init__(self, config):
  324. super().__init__()
  325. self.config = config
  326. self.pos_conv_embed = SEWPositionalConvEmbedding(config)
  327. self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
  328. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  329. self.dropout = nn.Dropout(config.hidden_dropout)
  330. self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  331. self.upsample = SEWUpsampling(config)
  332. self.gradient_checkpointing = False
  333. def forward(
  334. self,
  335. hidden_states,
  336. attention_mask=None,
  337. output_attentions=False,
  338. output_hidden_states=False,
  339. return_dict=True,
  340. ):
  341. all_hidden_states = () if output_hidden_states else None
  342. all_self_attentions = () if output_attentions else None
  343. if attention_mask is not None:
  344. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  345. if is_flash_attention_requested(self.config):
  346. # make sure padded tokens output 0
  347. hidden_states[~expand_attention_mask] = 0.0
  348. # 2d mask is passed through the layers
  349. attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
  350. else:
  351. # make sure padded tokens output 0
  352. hidden_states[~expand_attention_mask] = 0.0
  353. input_lengths = (attention_mask.long()).sum(-1)
  354. # apply pooling formula to get real output_lengths
  355. output_lengths = input_lengths // self.config.squeeze_factor
  356. max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
  357. attention_ids = (
  358. torch.arange(0, max_encoder_length, device=output_lengths.device)
  359. .view(1, -1)
  360. .expand(output_lengths.shape[0], -1)
  361. )
  362. attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
  363. # extend attention_mask
  364. attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
  365. attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
  366. attention_mask = attention_mask.expand(
  367. attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
  368. )
  369. n_input_timesteps = hidden_states.shape[1]
  370. hidden_states = hidden_states.transpose(1, 2)
  371. position_embeddings = self.pos_conv_embed(hidden_states)
  372. pooled_hidden_states = self.pool(hidden_states)
  373. min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
  374. hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
  375. hidden_states = hidden_states.transpose(1, 2)
  376. hidden_states = self.layer_norm(hidden_states)
  377. hidden_states = self.dropout(hidden_states)
  378. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  379. for layer in self.layers:
  380. if output_hidden_states:
  381. all_hidden_states = all_hidden_states + (hidden_states,)
  382. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  383. dropout_probability = torch.rand([])
  384. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  385. if not skip_the_layer or synced_gpus:
  386. # under fsdp or deepspeed zero3 all gpus must run in sync
  387. layer_outputs = layer(
  388. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  389. )
  390. hidden_states = layer_outputs[0]
  391. if skip_the_layer:
  392. layer_outputs = (None, None)
  393. if output_attentions:
  394. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  395. if output_hidden_states:
  396. all_hidden_states = all_hidden_states + (hidden_states,)
  397. hidden_states = self.upsample(hidden_states)
  398. if hidden_states.shape[1] < n_input_timesteps:
  399. hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
  400. if not return_dict:
  401. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  402. return BaseModelOutput(
  403. last_hidden_state=hidden_states,
  404. hidden_states=all_hidden_states,
  405. attentions=all_self_attentions,
  406. )
  407. @auto_docstring
  408. class SEWPreTrainedModel(PreTrainedModel):
  409. config: SEWConfig
  410. base_model_prefix = "sew"
  411. main_input_name = "input_values"
  412. input_modalities = "audio"
  413. supports_gradient_checkpointing = True
  414. _supports_flash_attn = True
  415. _supports_sdpa = True
  416. _supports_flex_attn = False # needs a proper look into the mask creation
  417. @torch.no_grad()
  418. def _init_weights(self, module):
  419. """Initialize the weights"""
  420. if isinstance(module, SEWPositionalConvEmbedding):
  421. init.normal_(
  422. module.conv.weight,
  423. mean=0,
  424. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  425. )
  426. init.constant_(module.conv.bias, 0)
  427. elif isinstance(module, nn.Linear):
  428. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  429. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  430. init.zeros_(module.bias)
  431. init.ones_(module.weight)
  432. elif isinstance(module, nn.Conv1d):
  433. if is_deepspeed_zero3_enabled():
  434. import deepspeed
  435. if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
  436. with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
  437. init.kaiming_normal_(module.weight)
  438. else:
  439. with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
  440. init.kaiming_normal_(module.weight)
  441. else:
  442. init.kaiming_normal_(module.weight)
  443. if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
  444. init.zeros_(module.bias)
  445. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
  446. """
  447. Computes the output length of the convolutional layers
  448. """
  449. def _conv_out_length(input_length, kernel_size, stride):
  450. # 1D convolutional layer output length formula taken
  451. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  452. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  453. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  454. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  455. return input_lengths
  456. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  457. output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  458. batch_size = attention_mask.shape[0]
  459. attention_mask = torch.zeros(
  460. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  461. )
  462. # these two operations makes sure that all values before the output lengths idxs are attended to
  463. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  464. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  465. return attention_mask
  466. def _compute_mask_indices(
  467. shape: tuple[int, int],
  468. mask_prob: float,
  469. mask_length: int,
  470. attention_mask: torch.LongTensor | None = None,
  471. min_masks: int = 0,
  472. ) -> np.ndarray:
  473. """
  474. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  475. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  476. CPU as part of the preprocessing during training.
  477. Args:
  478. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  479. the first element is the batch size and the second element is the length of the axis to span.
  480. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  481. independently generated mask spans of length `mask_length` is computed by
  482. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  483. actual percentage will be smaller.
  484. mask_length: size of the mask
  485. min_masks: minimum number of masked spans
  486. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  487. each batch dimension.
  488. """
  489. batch_size, sequence_length = shape
  490. if mask_length < 1:
  491. raise ValueError("`mask_length` has to be bigger than 0.")
  492. if mask_length > sequence_length:
  493. raise ValueError(
  494. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  495. f" and `sequence_length`: {sequence_length}`"
  496. )
  497. # epsilon is used for probabilistic rounding
  498. epsilon = np.random.rand(1).item()
  499. def compute_num_masked_span(input_length):
  500. """Given input length, compute how many spans should be masked"""
  501. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  502. num_masked_span = max(num_masked_span, min_masks)
  503. # make sure num masked span <= sequence_length
  504. if num_masked_span * mask_length > sequence_length:
  505. num_masked_span = sequence_length // mask_length
  506. # make sure num_masked span is also <= input_length - (mask_length - 1)
  507. if input_length - (mask_length - 1) < num_masked_span:
  508. num_masked_span = max(input_length - (mask_length - 1), 0)
  509. return num_masked_span
  510. # compute number of masked spans in batch
  511. input_lengths = (
  512. attention_mask.detach().sum(-1).tolist()
  513. if attention_mask is not None
  514. else [sequence_length for _ in range(batch_size)]
  515. )
  516. # SpecAugment mask to fill
  517. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  518. spec_aug_mask_idxs = []
  519. max_num_masked_span = compute_num_masked_span(sequence_length)
  520. if max_num_masked_span == 0:
  521. return spec_aug_mask
  522. for input_length in input_lengths:
  523. # compute num of masked spans for this input
  524. num_masked_span = compute_num_masked_span(input_length)
  525. # get random indices to mask
  526. spec_aug_mask_idx = np.random.choice(
  527. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  528. )
  529. # pick first sampled index that will serve as a dummy index to pad vector
  530. # to ensure same dimension for all batches due to probabilistic rounding
  531. # Picking first sample just pads those vectors twice.
  532. if len(spec_aug_mask_idx) == 0:
  533. # this case can only happen if `input_length` is strictly smaller then
  534. # `sequence_length` in which case the last token has to be a padding
  535. # token which we can use as a dummy mask id
  536. dummy_mask_idx = sequence_length - 1
  537. else:
  538. dummy_mask_idx = spec_aug_mask_idx[0]
  539. spec_aug_mask_idx = np.concatenate(
  540. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  541. )
  542. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  543. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  544. # expand masked indices to masked spans
  545. spec_aug_mask_idxs = np.broadcast_to(
  546. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  547. )
  548. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  549. # add offset to the starting indexes so that indexes now create a span
  550. offsets = np.arange(mask_length)[None, None, :]
  551. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  552. batch_size, max_num_masked_span * mask_length
  553. )
  554. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  555. # ensure that we cannot have indices larger than sequence_length
  556. if spec_aug_mask_idxs.max() > sequence_length - 1:
  557. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  558. # scatter indices to mask
  559. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  560. return spec_aug_mask
  561. @auto_docstring
  562. class SEWModel(SEWPreTrainedModel):
  563. def __init__(self, config: SEWConfig):
  564. super().__init__(config)
  565. self.config = config
  566. self.feature_extractor = SEWFeatureEncoder(config)
  567. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  568. self.project_features = config.conv_dim[-1] != config.hidden_size
  569. if self.project_features:
  570. self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  571. self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
  572. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  573. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  574. self.encoder = SEWEncoder(config)
  575. # Initialize weights and apply final processing
  576. self.post_init()
    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: torch.FloatTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).

        Args:
            hidden_states: features to mask, unpacked below as (batch_size, sequence_length, hidden_size).
                Masked positions are overwritten in place.
            mask_time_indices: optional precomputed time mask used directly as an index into `hidden_states`
                instead of sampling one (presumably boolean of shape (batch_size, sequence_length) — TODO confirm).
            attention_mask: optional mask forwarded to `_compute_mask_indices` when a time mask is sampled.

        Returns:
            `hidden_states`: the same tensor after in-place masking, or unchanged when
            `config.apply_spec_augment` is False.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            # NOTE: this branch does not check `self.training`, so explicit indices are applied in eval mode too.
            # NOTE(review): `masked_spec_embed` is only created in `__init__` when a mask probability is > 0.
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            # Masked time steps are replaced by the learned mask embedding.
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            # Broadcast the per-(batch, feature) mask across every time step, then zero those features.
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states
  618. @auto_docstring
  619. def forward(
  620. self,
  621. input_values: torch.Tensor | None,
  622. attention_mask: torch.Tensor | None = None,
  623. mask_time_indices: torch.FloatTensor | None = None,
  624. output_attentions: bool | None = None,
  625. output_hidden_states: bool | None = None,
  626. return_dict: bool | None = None,
  627. **kwargs,
  628. ) -> tuple | BaseModelOutput:
  629. r"""
  630. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  631. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  632. masked extracted features in *config.proj_codevector_dim* space.
  633. """
  634. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  635. output_hidden_states = (
  636. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  637. )
  638. return_dict = return_dict if return_dict is not None else self.config.return_dict
  639. extract_features = self.feature_extractor(input_values)
  640. extract_features = extract_features.transpose(1, 2)
  641. extract_features = self.layer_norm(extract_features)
  642. if self.project_features:
  643. extract_features = self.feature_projection(extract_features)
  644. hidden_states = self.feature_dropout(extract_features)
  645. if attention_mask is not None:
  646. # compute reduced attention_mask corresponding to feature vectors
  647. attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  648. hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
  649. encoder_outputs = self.encoder(
  650. hidden_states,
  651. attention_mask=attention_mask,
  652. output_attentions=output_attentions,
  653. output_hidden_states=output_hidden_states,
  654. return_dict=return_dict,
  655. )
  656. hidden_states = encoder_outputs[0]
  657. if not return_dict:
  658. return (hidden_states,) + encoder_outputs[1:]
  659. return BaseModelOutput(
  660. last_hidden_state=hidden_states,
  661. hidden_states=encoder_outputs.hidden_states,
  662. attentions=encoder_outputs.attentions,
  663. )
  664. _HIDDEN_STATES_START_POSITION = 1
  665. @auto_docstring(
  666. custom_intro="""
  667. SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
  668. """
  669. )
  670. class SEWForCTC(SEWPreTrainedModel):
  671. def __init__(self, config, target_lang: str | None = None):
  672. r"""
  673. target_lang (`str`, *optional*):
  674. Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
  675. adapter.<lang>.bin. Only relevant when using an instance of [`SEWForCTC`] with adapters. Uses 'eng' by
  676. default.
  677. """
  678. super().__init__(config)
  679. self.sew = SEWModel(config)
  680. self.dropout = nn.Dropout(config.final_dropout)
  681. self.target_lang = target_lang
  682. if config.vocab_size is None:
  683. raise ValueError(
  684. f"You are trying to instantiate {self.__class__} with a configuration that "
  685. "does not define the vocabulary size of the language model head. Please "
  686. "instantiate the model as follows: `SEWForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  687. "or define `vocab_size` of your model's configuration."
  688. )
  689. output_hidden_size = (
  690. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  691. )
  692. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  693. # Initialize weights and apply final processing
  694. self.post_init()
  695. def tie_weights(self, **kwargs):
  696. """
  697. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  698. passing `target_lang=...` to `from_pretrained(...)`.
  699. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  700. """
  701. if get_torch_context_manager_or_global_device() == torch.device("meta"):
  702. return
  703. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  704. # correctly load adapter layers for SEW so that we do not have to introduce a new API to
  705. # [`PreTrainedModel`]. While slightly hacky, SEW never has to tie input and output embeddings, so that it is
  706. # ok to repurpose this function here.
  707. target_lang = self.target_lang
  708. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  709. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  710. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  711. logger.info("By default `target_lang` is set to 'eng'.")
  712. elif target_lang is not None:
  713. self.load_adapter(target_lang, force_load=True)
  714. def freeze_feature_encoder(self):
  715. """
  716. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  717. not be updated during training.
  718. """
  719. self.sew.feature_extractor._freeze_parameters()
  720. def freeze_base_model(self):
  721. """
  722. Calling this function will disable the gradient computation for the base model so that its parameters will not
  723. be updated during training. Only the classification head will be updated.
  724. """
  725. for param in self.sew.parameters():
  726. param.requires_grad = False
  727. @auto_docstring
  728. def forward(
  729. self,
  730. input_values: torch.Tensor | None,
  731. attention_mask: torch.Tensor | None = None,
  732. output_attentions: bool | None = None,
  733. output_hidden_states: bool | None = None,
  734. return_dict: bool | None = None,
  735. labels: torch.Tensor | None = None,
  736. **kwargs,
  737. ) -> tuple | CausalLMOutput:
  738. r"""
  739. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  740. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  741. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  742. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  743. config.vocab_size - 1]`.
  744. """
  745. return_dict = return_dict if return_dict is not None else self.config.return_dict
  746. if labels is not None and labels.max() >= self.config.vocab_size:
  747. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  748. outputs = self.sew(
  749. input_values,
  750. attention_mask=attention_mask,
  751. output_attentions=output_attentions,
  752. output_hidden_states=output_hidden_states,
  753. return_dict=return_dict,
  754. )
  755. hidden_states = outputs[0]
  756. hidden_states = self.dropout(hidden_states)
  757. logits = self.lm_head(hidden_states)
  758. loss = None
  759. if labels is not None:
  760. # retrieve loss input_lengths from attention_mask
  761. attention_mask = (
  762. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  763. )
  764. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  765. # assuming that padded tokens are filled with -100
  766. # when not being attended to
  767. labels_mask = labels >= 0
  768. target_lengths = labels_mask.sum(-1)
  769. flattened_targets = labels.masked_select(labels_mask)
  770. # ctc_loss doesn't support fp16
  771. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  772. with torch.backends.cudnn.flags(enabled=False):
  773. loss = nn.functional.ctc_loss(
  774. log_probs,
  775. flattened_targets,
  776. input_lengths,
  777. target_lengths,
  778. blank=self.config.pad_token_id,
  779. reduction=self.config.ctc_loss_reduction,
  780. zero_infinity=self.config.ctc_zero_infinity,
  781. )
  782. if not return_dict:
  783. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  784. return ((loss,) + output) if loss is not None else output
  785. return CausalLMOutput(
  786. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  787. )
  788. @auto_docstring(
  789. custom_intro="""
  790. SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  791. SUPERB Keyword Spotting.
  792. """
  793. )
  794. class SEWForSequenceClassification(SEWPreTrainedModel):
  795. def __init__(self, config):
  796. super().__init__(config)
  797. if hasattr(config, "add_adapter") and config.add_adapter:
  798. raise ValueError(
  799. "Sequence classification does not support the use of SEW adapters (config.add_adapter=True)"
  800. )
  801. self.sew = SEWModel(config)
  802. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  803. if config.use_weighted_layer_sum:
  804. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  805. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  806. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  807. # Initialize weights and apply final processing
  808. self.post_init()
  809. def freeze_feature_encoder(self):
  810. """
  811. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  812. not be updated during training.
  813. """
  814. self.sew.feature_extractor._freeze_parameters()
  815. def freeze_base_model(self):
  816. """
  817. Calling this function will disable the gradient computation for the base model so that its parameters will not
  818. be updated during training. Only the classification head will be updated.
  819. """
  820. for param in self.sew.parameters():
  821. param.requires_grad = False
  822. @auto_docstring
  823. def forward(
  824. self,
  825. input_values: torch.Tensor | None,
  826. attention_mask: torch.Tensor | None = None,
  827. output_attentions: bool | None = None,
  828. output_hidden_states: bool | None = None,
  829. return_dict: bool | None = None,
  830. labels: torch.Tensor | None = None,
  831. **kwargs,
  832. ) -> tuple | SequenceClassifierOutput:
  833. r"""
  834. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  835. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  836. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  837. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  838. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  839. into a tensor of type `torch.FloatTensor`. See [`SEWProcessor.__call__`] for details.
  840. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  841. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  842. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  843. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  844. """
  845. return_dict = return_dict if return_dict is not None else self.config.return_dict
  846. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  847. outputs = self.sew(
  848. input_values,
  849. attention_mask=attention_mask,
  850. output_attentions=output_attentions,
  851. output_hidden_states=output_hidden_states,
  852. return_dict=return_dict,
  853. )
  854. if self.config.use_weighted_layer_sum:
  855. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  856. hidden_states = torch.stack(hidden_states, dim=1)
  857. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  858. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  859. else:
  860. hidden_states = outputs[0]
  861. hidden_states = self.projector(hidden_states)
  862. if attention_mask is None:
  863. pooled_output = hidden_states.mean(dim=1)
  864. else:
  865. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  866. expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  867. hidden_states[~expand_padding_mask] = 0.0
  868. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  869. logits = self.classifier(pooled_output)
  870. loss = None
  871. if labels is not None:
  872. loss_fct = CrossEntropyLoss()
  873. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  874. if not return_dict:
  875. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  876. return ((loss,) + output) if loss is not None else output
  877. return SequenceClassifierOutput(
  878. loss=loss,
  879. logits=logits,
  880. hidden_states=outputs.hidden_states,
  881. attentions=outputs.attentions,
  882. )
  883. __all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]