modular_hubert.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """PyTorch Hubert model."""
  15. import torch
  16. import torch.nn as nn
  17. from ... import initialization as init
  18. from ...activations import ACT2FN
  19. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  20. from ...modeling_outputs import BaseModelOutput
  21. from ...modeling_utils import PreTrainedModel
  22. from ...utils import auto_docstring
  23. from ..wav2vec2.modeling_wav2vec2 import (
  24. Wav2Vec2Encoder,
  25. Wav2Vec2EncoderStableLayerNorm,
  26. Wav2Vec2FeatureEncoder,
  27. Wav2Vec2ForCTC,
  28. Wav2Vec2ForSequenceClassification,
  29. Wav2Vec2Model,
  30. Wav2Vec2SamePadLayer,
  31. )
  32. from .configuration_hubert import HubertConfig
  33. _HIDDEN_STATES_START_POSITION = 1
  34. class HubertPositionalConvEmbedding(nn.Module):
  35. def __init__(self, config):
  36. super().__init__()
  37. self.conv = nn.Conv1d(
  38. config.hidden_size,
  39. config.hidden_size,
  40. kernel_size=config.num_conv_pos_embeddings,
  41. padding=config.num_conv_pos_embeddings // 2,
  42. groups=config.num_conv_pos_embedding_groups,
  43. )
  44. self.batch_norm = None
  45. if config.conv_pos_batch_norm:
  46. self.batch_norm = nn.BatchNorm1d(config.hidden_size)
  47. else:
  48. weight_norm = nn.utils.weight_norm
  49. if hasattr(nn.utils.parametrizations, "weight_norm"):
  50. weight_norm = nn.utils.parametrizations.weight_norm
  51. if is_deepspeed_zero3_enabled():
  52. import deepspeed
  53. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  54. self.conv = weight_norm(self.conv, name="weight", dim=2)
  55. if hasattr(self.conv, "parametrizations"):
  56. weight_g = self.conv.parametrizations.weight.original0
  57. weight_v = self.conv.parametrizations.weight.original1
  58. else:
  59. weight_g = self.conv.weight_g
  60. weight_v = self.conv.weight_v
  61. deepspeed.zero.register_external_parameter(self, weight_v)
  62. deepspeed.zero.register_external_parameter(self, weight_g)
  63. else:
  64. self.conv = weight_norm(self.conv, name="weight", dim=2)
  65. self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
  66. self.activation = ACT2FN[config.feat_extract_activation]
  67. def forward(self, hidden_states):
  68. hidden_states = hidden_states.transpose(1, 2)
  69. if self.batch_norm is not None:
  70. hidden_states = self.batch_norm(hidden_states)
  71. hidden_states = self.conv(hidden_states)
  72. hidden_states = self.padding(hidden_states)
  73. hidden_states = self.activation(hidden_states)
  74. hidden_states = hidden_states.transpose(1, 2)
  75. return hidden_states
  76. class HubertSamePadLayer(Wav2Vec2SamePadLayer):
  77. pass
  78. class HubertFeatureEncoder(Wav2Vec2FeatureEncoder):
  79. pass
  80. class HubertFeatureProjection(nn.Module):
  81. def __init__(self, config):
  82. super().__init__()
  83. self.feat_proj_layer_norm = config.feat_proj_layer_norm
  84. if self.feat_proj_layer_norm:
  85. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  86. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  87. self.dropout = nn.Dropout(config.feat_proj_dropout)
  88. def forward(self, hidden_states):
  89. # non-projected hidden states are needed for quantization
  90. if self.feat_proj_layer_norm:
  91. hidden_states = self.layer_norm(hidden_states)
  92. hidden_states = self.projection(hidden_states)
  93. hidden_states = self.dropout(hidden_states)
  94. return hidden_states
  95. class HubertEncoder(Wav2Vec2Encoder):
  96. pass
  97. class HubertEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
  98. pass
  99. @auto_docstring
  100. class HubertPreTrainedModel(PreTrainedModel):
  101. config: HubertConfig
  102. base_model_prefix = "hubert"
  103. main_input_name = "input_values"
  104. input_modalities = "audio"
  105. _no_split_modules = ["HubertEncoderLayer", "ParametrizedConv1d"]
  106. supports_gradient_checkpointing = True
  107. _supports_flash_attn = True
  108. _supports_sdpa = True
  109. _supports_flex_attn = True
  110. @torch.no_grad()
  111. def _init_weights(self, module):
  112. """Initialize the weights"""
  113. if isinstance(module, nn.Linear):
  114. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  115. if module.bias is not None:
  116. init.zeros_(module.bias)
  117. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
  118. init.zeros_(module.bias)
  119. init.ones_(module.weight)
  120. if getattr(module, "running_mean", None) is not None:
  121. init.zeros_(module.running_mean)
  122. init.ones_(module.running_var)
  123. init.zeros_(module.num_batches_tracked)
  124. elif isinstance(module, nn.Conv1d):
  125. if is_deepspeed_zero3_enabled():
  126. import deepspeed
  127. if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
  128. with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
  129. init.kaiming_normal_(module.weight)
  130. else:
  131. with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
  132. init.kaiming_normal_(module.weight)
  133. else:
  134. init.kaiming_normal_(module.weight)
  135. if module.bias is not None:
  136. init.zeros_(module.bias)
  137. elif isinstance(module, HubertModel):
  138. if hasattr(module, "masked_spec_embed"):
  139. init.uniform_(module.masked_spec_embed)
  140. elif isinstance(module, HubertForSequenceClassification):
  141. if hasattr(module, "layer_weights"):
  142. init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))
  143. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
  144. """
  145. Computes the output length of the convolutional layers
  146. """
  147. def _conv_out_length(input_length, kernel_size, stride):
  148. # 1D convolutional layer output length formula taken
  149. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  150. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  151. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  152. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  153. return input_lengths
  154. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  155. output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  156. batch_size = attention_mask.shape[0]
  157. attention_mask = torch.zeros(
  158. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  159. )
  160. # these two operations makes sure that all values before the output lengths idxs are attended to
  161. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  162. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  163. return attention_mask
  164. class HubertModel(Wav2Vec2Model, HubertPreTrainedModel):
  165. def __init__(self, config: HubertConfig):
  166. super().__init__(config)
  167. self.config = config
  168. self.feature_extractor = HubertFeatureEncoder(config)
  169. self.feature_projection = HubertFeatureProjection(config)
  170. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  171. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  172. if config.do_stable_layer_norm:
  173. self.encoder = HubertEncoderStableLayerNorm(config)
  174. else:
  175. self.encoder = HubertEncoder(config)
  176. # Initialize weights and apply final processing
  177. self.post_init()
  178. del self.adapter
  179. def freeze_feature_encoder(self):
  180. raise AttributeError("Not needed for Hubert")
  181. def forward(
  182. self,
  183. input_values: torch.Tensor | None,
  184. attention_mask: torch.Tensor | None = None,
  185. mask_time_indices: torch.FloatTensor | None = None,
  186. output_attentions: bool | None = None,
  187. output_hidden_states: bool | None = None,
  188. return_dict: bool | None = None,
  189. **kwargs,
  190. ) -> tuple | BaseModelOutput:
  191. r"""
  192. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  193. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  194. masked extracted features in *config.proj_codevector_dim* space.
  195. Example:
  196. ```python
  197. >>> from transformers import AutoProcessor, HubertModel
  198. >>> from datasets import load_dataset
  199. >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
  200. >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
  201. >>> def map_to_array(example):
  202. ... example["speech"] = example["audio"]["array"]
  203. ... return example
  204. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  205. >>> ds = ds.map(map_to_array)
  206. >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
  207. >>> hidden_states = model(input_values).last_hidden_state
  208. ```"""
  209. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  210. output_hidden_states = (
  211. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  212. )
  213. return_dict = return_dict if return_dict is not None else self.config.return_dict
  214. extract_features = self.feature_extractor(input_values)
  215. extract_features = extract_features.transpose(1, 2)
  216. if attention_mask is not None:
  217. # compute reduced attention_mask corresponding to feature vectors
  218. attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
  219. hidden_states = self.feature_projection(extract_features)
  220. hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
  221. encoder_outputs = self.encoder(
  222. hidden_states,
  223. attention_mask=attention_mask,
  224. output_attentions=output_attentions,
  225. output_hidden_states=output_hidden_states,
  226. return_dict=return_dict,
  227. )
  228. hidden_states = encoder_outputs[0]
  229. if not return_dict:
  230. return (hidden_states,) + encoder_outputs[1:]
  231. return BaseModelOutput(
  232. last_hidden_state=hidden_states,
  233. hidden_states=encoder_outputs.hidden_states,
  234. attentions=encoder_outputs.attentions,
  235. )
  236. class HubertForCTC(Wav2Vec2ForCTC):
  237. pass
  238. class HubertForSequenceClassification(Wav2Vec2ForSequenceClassification):
  239. pass
  240. __all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"]