modular_sew.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. # Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SEW model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  21. from ...integrations.fsdp import is_fsdp_managed_module
  22. from ...modeling_outputs import BaseModelOutput
  23. from ...modeling_utils import PreTrainedModel
  24. from ...utils import auto_docstring
  25. from ...utils.generic import is_flash_attention_requested
  26. from ..wav2vec2.modeling_wav2vec2 import (
  27. Wav2Vec2Attention,
  28. Wav2Vec2EncoderLayer,
  29. Wav2Vec2FeatureEncoder,
  30. Wav2Vec2FeedForward,
  31. Wav2Vec2ForCTC,
  32. Wav2Vec2ForSequenceClassification,
  33. Wav2Vec2GroupNormConvLayer,
  34. Wav2Vec2LayerNormConvLayer,
  35. Wav2Vec2NoLayerNormConvLayer,
  36. Wav2Vec2SamePadLayer,
  37. _compute_mask_indices,
  38. )
  39. from .configuration_sew import SEWConfig
# NOTE(review): not referenced anywhere in the visible code. Presumably the index at which
# per-layer hidden states start in a model's output tuple, consumed by the inherited
# task heads — confirm against the Wav2Vec2 parent classes.
_HIDDEN_STATES_START_POSITION = 1
class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer):
    # SEW-named subclass of `Wav2Vec2NoLayerNormConvLayer`; nothing is overridden here.
    pass
class SEWLayerNormConvLayer(Wav2Vec2LayerNormConvLayer):
    # SEW-named subclass of `Wav2Vec2LayerNormConvLayer`; nothing is overridden here.
    pass
class SEWGroupNormConvLayer(Wav2Vec2GroupNormConvLayer):
    # SEW-named subclass of `Wav2Vec2GroupNormConvLayer`; nothing is overridden here.
    pass
  47. class SEWPositionalConvEmbedding(nn.Module):
  48. def __init__(self, config):
  49. super().__init__()
  50. self.conv = nn.Conv1d(
  51. config.hidden_size,
  52. config.hidden_size,
  53. kernel_size=config.num_conv_pos_embeddings,
  54. padding=config.num_conv_pos_embeddings // 2,
  55. groups=config.num_conv_pos_embedding_groups,
  56. stride=config.squeeze_factor,
  57. )
  58. weight_norm = nn.utils.weight_norm
  59. if hasattr(nn.utils.parametrizations, "weight_norm"):
  60. weight_norm = nn.utils.parametrizations.weight_norm
  61. if is_deepspeed_zero3_enabled():
  62. import deepspeed
  63. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  64. self.conv = weight_norm(self.conv, name="weight", dim=2)
  65. if hasattr(self.conv, "parametrizations"):
  66. weight_g = self.conv.parametrizations.weight.original0
  67. weight_v = self.conv.parametrizations.weight.original1
  68. else:
  69. weight_g = self.conv.weight_g
  70. weight_v = self.conv.weight_v
  71. deepspeed.zero.register_external_parameter(self, weight_v)
  72. deepspeed.zero.register_external_parameter(self, weight_g)
  73. else:
  74. self.conv = weight_norm(self.conv, name="weight", dim=2)
  75. self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
  76. self.activation = ACT2FN[config.feat_extract_activation]
  77. def forward(self, hidden_states):
  78. hidden_states = self.conv(hidden_states)
  79. hidden_states = self.padding(hidden_states)
  80. hidden_states = self.activation(hidden_states)
  81. return hidden_states
class SEWSamePadLayer(Wav2Vec2SamePadLayer):
    # SEW-named subclass of `Wav2Vec2SamePadLayer`; nothing is overridden here.
    pass
  84. class SEWUpsampling(nn.Module):
  85. def __init__(self, config):
  86. super().__init__()
  87. self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
  88. self.activation = ACT2FN[config.feat_extract_activation]
  89. self.squeeze_factor = config.squeeze_factor
  90. def forward(self, hidden_states):
  91. hidden_states = self.projection(hidden_states)
  92. hidden_states = self.activation(hidden_states)
  93. if self.squeeze_factor > 1:
  94. # transform embedding channels to sequence length
  95. bsz, src_len, src_embed_dim = hidden_states.size()
  96. tgt_len = src_len * self.squeeze_factor
  97. tgt_embed_dim = src_embed_dim // self.squeeze_factor
  98. hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
  99. hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
  100. return hidden_states
class SEWFeatureEncoder(Wav2Vec2FeatureEncoder):
    # SEW-named subclass of `Wav2Vec2FeatureEncoder`; nothing is overridden here.
    pass
class SEWAttention(Wav2Vec2Attention):
    # SEW-named subclass of `Wav2Vec2Attention`; nothing is overridden here.
    pass
class SEWFeedForward(Wav2Vec2FeedForward):
    # SEW-named subclass of `Wav2Vec2FeedForward`; nothing is overridden here.
    pass
class SEWEncoderLayer(Wav2Vec2EncoderLayer):
    # SEW-named subclass of `Wav2Vec2EncoderLayer`; nothing is overridden here.
    pass
  109. class SEWEncoder(nn.Module):
  110. def __init__(self, config):
  111. super().__init__()
  112. self.config = config
  113. self.pos_conv_embed = SEWPositionalConvEmbedding(config)
  114. self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
  115. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  116. self.dropout = nn.Dropout(config.hidden_dropout)
  117. self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  118. self.upsample = SEWUpsampling(config)
  119. self.gradient_checkpointing = False
  120. def forward(
  121. self,
  122. hidden_states,
  123. attention_mask=None,
  124. output_attentions=False,
  125. output_hidden_states=False,
  126. return_dict=True,
  127. ):
  128. all_hidden_states = () if output_hidden_states else None
  129. all_self_attentions = () if output_attentions else None
  130. if attention_mask is not None:
  131. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  132. if is_flash_attention_requested(self.config):
  133. # make sure padded tokens output 0
  134. hidden_states[~expand_attention_mask] = 0.0
  135. # 2d mask is passed through the layers
  136. attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
  137. else:
  138. # make sure padded tokens output 0
  139. hidden_states[~expand_attention_mask] = 0.0
  140. input_lengths = (attention_mask.long()).sum(-1)
  141. # apply pooling formula to get real output_lengths
  142. output_lengths = input_lengths // self.config.squeeze_factor
  143. max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
  144. attention_ids = (
  145. torch.arange(0, max_encoder_length, device=output_lengths.device)
  146. .view(1, -1)
  147. .expand(output_lengths.shape[0], -1)
  148. )
  149. attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
  150. # extend attention_mask
  151. attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
  152. attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
  153. attention_mask = attention_mask.expand(
  154. attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
  155. )
  156. n_input_timesteps = hidden_states.shape[1]
  157. hidden_states = hidden_states.transpose(1, 2)
  158. position_embeddings = self.pos_conv_embed(hidden_states)
  159. pooled_hidden_states = self.pool(hidden_states)
  160. min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
  161. hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
  162. hidden_states = hidden_states.transpose(1, 2)
  163. hidden_states = self.layer_norm(hidden_states)
  164. hidden_states = self.dropout(hidden_states)
  165. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  166. for layer in self.layers:
  167. if output_hidden_states:
  168. all_hidden_states = all_hidden_states + (hidden_states,)
  169. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  170. dropout_probability = torch.rand([])
  171. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  172. if not skip_the_layer or synced_gpus:
  173. # under fsdp or deepspeed zero3 all gpus must run in sync
  174. layer_outputs = layer(
  175. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  176. )
  177. hidden_states = layer_outputs[0]
  178. if skip_the_layer:
  179. layer_outputs = (None, None)
  180. if output_attentions:
  181. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  182. if output_hidden_states:
  183. all_hidden_states = all_hidden_states + (hidden_states,)
  184. hidden_states = self.upsample(hidden_states)
  185. if hidden_states.shape[1] < n_input_timesteps:
  186. hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
  187. if not return_dict:
  188. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  189. return BaseModelOutput(
  190. last_hidden_state=hidden_states,
  191. hidden_states=all_hidden_states,
  192. attentions=all_self_attentions,
  193. )
  194. @auto_docstring
  195. class SEWPreTrainedModel(PreTrainedModel):
  196. config: SEWConfig
  197. base_model_prefix = "sew"
  198. main_input_name = "input_values"
  199. input_modalities = "audio"
  200. supports_gradient_checkpointing = True
  201. _supports_flash_attn = True
  202. _supports_sdpa = True
  203. _supports_flex_attn = False # needs a proper look into the mask creation
  204. @torch.no_grad()
  205. def _init_weights(self, module):
  206. """Initialize the weights"""
  207. if isinstance(module, SEWPositionalConvEmbedding):
  208. init.normal_(
  209. module.conv.weight,
  210. mean=0,
  211. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  212. )
  213. init.constant_(module.conv.bias, 0)
  214. elif isinstance(module, nn.Linear):
  215. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  216. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  217. init.zeros_(module.bias)
  218. init.ones_(module.weight)
  219. elif isinstance(module, nn.Conv1d):
  220. if is_deepspeed_zero3_enabled():
  221. import deepspeed
  222. if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
  223. with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
  224. init.kaiming_normal_(module.weight)
  225. else:
  226. with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
  227. init.kaiming_normal_(module.weight)
  228. else:
  229. init.kaiming_normal_(module.weight)
  230. if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
  231. init.zeros_(module.bias)
  232. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
  233. """
  234. Computes the output length of the convolutional layers
  235. """
  236. def _conv_out_length(input_length, kernel_size, stride):
  237. # 1D convolutional layer output length formula taken
  238. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  239. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  240. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  241. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  242. return input_lengths
  243. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  244. output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  245. batch_size = attention_mask.shape[0]
  246. attention_mask = torch.zeros(
  247. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  248. )
  249. # these two operations makes sure that all values before the output lengths idxs are attended to
  250. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  251. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  252. return attention_mask
  253. @auto_docstring
  254. class SEWModel(SEWPreTrainedModel):
  255. def __init__(self, config: SEWConfig):
  256. super().__init__(config)
  257. self.config = config
  258. self.feature_extractor = SEWFeatureEncoder(config)
  259. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  260. self.project_features = config.conv_dim[-1] != config.hidden_size
  261. if self.project_features:
  262. self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  263. self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
  264. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  265. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  266. self.encoder = SEWEncoder(config)
  267. # Initialize weights and apply final processing
  268. self.post_init()
  269. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
  270. def _mask_hidden_states(
  271. self,
  272. hidden_states: torch.FloatTensor,
  273. mask_time_indices: torch.FloatTensor | None = None,
  274. attention_mask: torch.LongTensor | None = None,
  275. ):
  276. """
  277. Masks extracted features along time axis and/or along feature axis according to
  278. [SpecAugment](https://huggingface.co/papers/1904.08779).
  279. """
  280. # `config.apply_spec_augment` can set masking to False
  281. if not getattr(self.config, "apply_spec_augment", True):
  282. return hidden_states
  283. # generate indices & apply SpecAugment along time axis
  284. batch_size, sequence_length, hidden_size = hidden_states.size()
  285. if mask_time_indices is not None:
  286. # apply SpecAugment along time axis with given mask_time_indices
  287. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  288. elif self.config.mask_time_prob > 0 and self.training:
  289. mask_time_indices = _compute_mask_indices(
  290. (batch_size, sequence_length),
  291. mask_prob=self.config.mask_time_prob,
  292. mask_length=self.config.mask_time_length,
  293. attention_mask=attention_mask,
  294. min_masks=self.config.mask_time_min_masks,
  295. )
  296. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  297. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  298. if self.config.mask_feature_prob > 0 and self.training:
  299. # generate indices & apply SpecAugment along feature axis
  300. mask_feature_indices = _compute_mask_indices(
  301. (batch_size, hidden_size),
  302. mask_prob=self.config.mask_feature_prob,
  303. mask_length=self.config.mask_feature_length,
  304. min_masks=self.config.mask_feature_min_masks,
  305. )
  306. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  307. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  308. hidden_states[mask_feature_indices] = 0
  309. return hidden_states
  310. @auto_docstring
  311. def forward(
  312. self,
  313. input_values: torch.Tensor | None,
  314. attention_mask: torch.Tensor | None = None,
  315. mask_time_indices: torch.FloatTensor | None = None,
  316. output_attentions: bool | None = None,
  317. output_hidden_states: bool | None = None,
  318. return_dict: bool | None = None,
  319. **kwargs,
  320. ) -> tuple | BaseModelOutput:
  321. r"""
  322. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  323. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  324. masked extracted features in *config.proj_codevector_dim* space.
  325. """
  326. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  327. output_hidden_states = (
  328. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  329. )
  330. return_dict = return_dict if return_dict is not None else self.config.return_dict
  331. extract_features = self.feature_extractor(input_values)
  332. extract_features = extract_features.transpose(1, 2)
  333. extract_features = self.layer_norm(extract_features)
  334. if self.project_features:
  335. extract_features = self.feature_projection(extract_features)
  336. hidden_states = self.feature_dropout(extract_features)
  337. if attention_mask is not None:
  338. # compute reduced attention_mask corresponding to feature vectors
  339. attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  340. hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
  341. encoder_outputs = self.encoder(
  342. hidden_states,
  343. attention_mask=attention_mask,
  344. output_attentions=output_attentions,
  345. output_hidden_states=output_hidden_states,
  346. return_dict=return_dict,
  347. )
  348. hidden_states = encoder_outputs[0]
  349. if not return_dict:
  350. return (hidden_states,) + encoder_outputs[1:]
  351. return BaseModelOutput(
  352. last_hidden_state=hidden_states,
  353. hidden_states=encoder_outputs.hidden_states,
  354. attentions=encoder_outputs.attentions,
  355. )
class SEWForCTC(Wav2Vec2ForCTC):
    # SEW-named subclass of `Wav2Vec2ForCTC`; nothing is overridden here.
    pass
class SEWForSequenceClassification(Wav2Vec2ForSequenceClassification):
    # SEW-named subclass of `Wav2Vec2ForSequenceClassification`; nothing is overridden here.
    pass
# Names exported by `from <this module> import *`: the two task heads, the bare model
# and the shared pretrained base class.
__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]