# modular_wavlm.py (23 KB)
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from ... import initialization as init
  6. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  7. from ...integrations.fsdp import is_fsdp_managed_module
  8. from ...modeling_layers import GradientCheckpointingLayer
  9. from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput
  10. from ...modeling_utils import PreTrainedModel
  11. from ...utils import logging
  12. from ..wav2vec2.modeling_wav2vec2 import (
  13. Wav2Vec2FeatureProjection,
  14. Wav2Vec2FeedForward,
  15. Wav2Vec2ForAudioFrameClassification,
  16. Wav2Vec2ForCTC,
  17. Wav2Vec2ForSequenceClassification,
  18. Wav2Vec2ForXVector,
  19. Wav2Vec2Model,
  20. Wav2Vec2PositionalConvEmbedding,
  21. Wav2Vec2PreTrainedModel,
  22. )
  23. from .configuration_wavlm import WavLMConfig
# Module-level logger obtained from the project's logging utilities, keyed by this module's name.
logger = logging.get_logger(__name__)
# WavLM name for the Wav2Vec2 positional convolutional embedding; nothing is overridden here.
class WavLMPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
    pass
# WavLM name for the Wav2Vec2 feature projection; nothing is overridden here.
class WavLMFeatureProjection(Wav2Vec2FeatureProjection):
    pass
  29. class WavLMAttention(nn.Module):
  30. """Multi-headed attention from 'Attention Is All You Need' paper"""
  31. def __init__(
  32. self,
  33. embed_dim: int,
  34. num_heads: int,
  35. dropout: float | int = 0.0,
  36. num_buckets: int = 320,
  37. max_distance: int = 800,
  38. has_relative_position_bias: bool = True,
  39. ):
  40. super().__init__()
  41. self.embed_dim = embed_dim
  42. self.num_heads = num_heads
  43. self.dropout = dropout
  44. self.head_dim = embed_dim // num_heads
  45. if (self.head_dim * num_heads) != self.embed_dim:
  46. raise ValueError(
  47. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  48. f" and `num_heads`: {num_heads})."
  49. )
  50. self.scaling = self.head_dim**-0.5
  51. self.k_proj = nn.Linear(embed_dim, embed_dim)
  52. self.v_proj = nn.Linear(embed_dim, embed_dim)
  53. self.q_proj = nn.Linear(embed_dim, embed_dim)
  54. self.out_proj = nn.Linear(embed_dim, embed_dim)
  55. self.num_buckets = num_buckets
  56. self.max_distance = max_distance
  57. self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
  58. self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
  59. if has_relative_position_bias:
  60. self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
  61. def forward(
  62. self,
  63. hidden_states: torch.Tensor,
  64. attention_mask: torch.Tensor | None = None,
  65. position_bias: torch.Tensor | None = None,
  66. output_attentions: bool = False,
  67. index=0,
  68. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  69. """Attention layer with relative attention"""
  70. bsz, tgt_len, _ = hidden_states.size()
  71. # first pass of attention layer creates position bias
  72. if position_bias is None:
  73. position_bias = self.compute_bias(tgt_len, tgt_len)
  74. position_bias = (
  75. position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
  76. )
  77. # Compute relative position bias:
  78. # 1) get reshape hidden_states
  79. gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
  80. gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)
  81. # 2) project hidden states
  82. relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
  83. relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)
  84. # 3) compute gate for position bias from projected hidden states
  85. gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
  86. gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
  87. # 4) apply gate to position bias to compute gated position_bias
  88. gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
  89. gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))
  90. attn_output, attn_weights = self.torch_multi_head_self_attention(
  91. hidden_states, attention_mask, gated_position_bias, output_attentions
  92. )
  93. return attn_output, attn_weights, position_bias
  94. def torch_multi_head_self_attention(
  95. self,
  96. hidden_states: torch.FloatTensor,
  97. attention_mask: torch.LongTensor | torch.BoolTensor,
  98. gated_position_bias: torch.FloatTensor,
  99. output_attentions: bool,
  100. ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
  101. """simple wrapper around torch's multi_head_attention_forward function"""
  102. # self-attention assumes q = k = v
  103. query = key = value = hidden_states.transpose(0, 1)
  104. key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None
  105. # disable bias and add_zero_attn
  106. bias_k = bias_v = None
  107. add_zero_attn = False
  108. # PyTorch 1.3.0 has F.multi_head_attention_forward defined
  109. # so no problem with backwards compatibility
  110. attn_output, attn_weights = F.multi_head_attention_forward(
  111. query,
  112. key,
  113. value,
  114. self.embed_dim,
  115. self.num_heads,
  116. torch.empty([0]),
  117. torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
  118. bias_k,
  119. bias_v,
  120. add_zero_attn,
  121. self.dropout,
  122. self.out_proj.weight,
  123. self.out_proj.bias,
  124. self.training,
  125. key_padding_mask,
  126. output_attentions,
  127. gated_position_bias,
  128. use_separate_proj_weight=True,
  129. q_proj_weight=self.q_proj.weight,
  130. k_proj_weight=self.k_proj.weight,
  131. v_proj_weight=self.v_proj.weight,
  132. )
  133. # [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...]
  134. attn_output = attn_output.transpose(0, 1)
  135. if attn_weights is not None:
  136. # IMPORTANT: Attention weights are averaged weights
  137. # here which should not be the case. This is an open issue
  138. # on PyTorch: https://github.com/pytorch/pytorch/issues/32590
  139. attn_weights = attn_weights[:, None].broadcast_to(
  140. attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
  141. )
  142. return attn_output, attn_weights
  143. def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
  144. context_position = torch.arange(query_length, dtype=torch.long)[:, None]
  145. memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
  146. relative_position = memory_position - context_position
  147. relative_position_bucket = self._relative_positions_bucket(relative_position)
  148. relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
  149. values = self.rel_attn_embed(relative_position_bucket)
  150. values = values.permute([2, 0, 1])
  151. return values
  152. def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor:
  153. num_buckets = self.num_buckets // 2
  154. relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
  155. relative_positions = torch.abs(relative_positions)
  156. max_exact = num_buckets // 2
  157. is_small = relative_positions < max_exact
  158. relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
  159. relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
  160. relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
  161. relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
  162. relative_position_if_large = torch.min(
  163. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  164. )
  165. relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
  166. return relative_buckets
# WavLM name for the Wav2Vec2 feed-forward block; nothing is overridden here.
class WavLMFeedForward(Wav2Vec2FeedForward):
    pass
  169. class WavLMEncoderLayer(GradientCheckpointingLayer):
  170. def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
  171. super().__init__()
  172. self.attention = WavLMAttention(
  173. embed_dim=config.hidden_size,
  174. num_heads=config.num_attention_heads,
  175. dropout=config.attention_dropout,
  176. num_buckets=config.num_buckets,
  177. max_distance=config.max_bucket_distance,
  178. has_relative_position_bias=has_relative_position_bias,
  179. )
  180. self.dropout = nn.Dropout(config.hidden_dropout)
  181. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  182. self.feed_forward = WavLMFeedForward(config)
  183. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  184. def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
  185. attn_residual = hidden_states
  186. hidden_states, attn_weights, position_bias = self.attention(
  187. hidden_states,
  188. attention_mask=attention_mask,
  189. position_bias=position_bias,
  190. output_attentions=output_attentions,
  191. index=index,
  192. )
  193. hidden_states = self.dropout(hidden_states)
  194. hidden_states = attn_residual + hidden_states
  195. hidden_states = self.layer_norm(hidden_states)
  196. hidden_states = hidden_states + self.feed_forward(hidden_states)
  197. hidden_states = self.final_layer_norm(hidden_states)
  198. outputs = (hidden_states, position_bias)
  199. if output_attentions:
  200. outputs += (attn_weights,)
  201. return outputs
  202. class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
  203. def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
  204. super().__init__()
  205. self.attention = WavLMAttention(
  206. embed_dim=config.hidden_size,
  207. num_heads=config.num_attention_heads,
  208. dropout=config.attention_dropout,
  209. num_buckets=config.num_buckets,
  210. max_distance=config.max_bucket_distance,
  211. has_relative_position_bias=has_relative_position_bias,
  212. )
  213. self.dropout = nn.Dropout(config.hidden_dropout)
  214. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  215. self.feed_forward = WavLMFeedForward(config)
  216. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  217. def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
  218. attn_residual = hidden_states
  219. hidden_states = self.layer_norm(hidden_states)
  220. hidden_states, attn_weights, position_bias = self.attention(
  221. hidden_states,
  222. attention_mask=attention_mask,
  223. position_bias=position_bias,
  224. output_attentions=output_attentions,
  225. )
  226. hidden_states = self.dropout(hidden_states)
  227. hidden_states = attn_residual + hidden_states
  228. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  229. outputs = (hidden_states, position_bias)
  230. if output_attentions:
  231. outputs += (attn_weights,)
  232. return outputs
  233. class WavLMEncoder(nn.Module):
  234. def __init__(self, config):
  235. super().__init__()
  236. self.config = config
  237. self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
  238. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  239. self.dropout = nn.Dropout(config.hidden_dropout)
  240. self.layers = nn.ModuleList(
  241. [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
  242. )
  243. self.gradient_checkpointing = False
  244. def forward(
  245. self,
  246. hidden_states,
  247. attention_mask=None,
  248. output_attentions=False,
  249. output_hidden_states=False,
  250. return_dict=True,
  251. ):
  252. all_hidden_states = () if output_hidden_states else None
  253. all_self_attentions = () if output_attentions else None
  254. if attention_mask is not None:
  255. # make sure padded tokens output 0
  256. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  257. hidden_states[~expand_attention_mask] = 0
  258. position_embeddings = self.pos_conv_embed(hidden_states)
  259. hidden_states = hidden_states + position_embeddings
  260. hidden_states = self.layer_norm(hidden_states)
  261. hidden_states = self.dropout(hidden_states)
  262. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  263. position_bias = None
  264. for i, layer in enumerate(self.layers):
  265. if output_hidden_states:
  266. all_hidden_states = all_hidden_states + (hidden_states,)
  267. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  268. dropout_probability = torch.rand([])
  269. skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
  270. if not skip_the_layer or synced_gpus:
  271. # under fsdp or deepspeed zero3 all gpus must run in sync
  272. layer_outputs = layer(
  273. hidden_states,
  274. attention_mask=attention_mask,
  275. position_bias=position_bias,
  276. output_attentions=output_attentions,
  277. index=i,
  278. )
  279. hidden_states, position_bias = layer_outputs[:2]
  280. if skip_the_layer:
  281. layer_outputs = (None, None, None)
  282. if output_attentions:
  283. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  284. if output_hidden_states:
  285. all_hidden_states = all_hidden_states + (hidden_states,)
  286. if not return_dict:
  287. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  288. return BaseModelOutput(
  289. last_hidden_state=hidden_states,
  290. hidden_states=all_hidden_states,
  291. attentions=all_self_attentions,
  292. )
  293. class WavLMEncoderStableLayerNorm(nn.Module):
  294. def __init__(self, config):
  295. super().__init__()
  296. self.config = config
  297. self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
  298. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  299. self.dropout = nn.Dropout(config.hidden_dropout)
  300. self.layers = nn.ModuleList(
  301. [
  302. WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
  303. for i in range(config.num_hidden_layers)
  304. ]
  305. )
  306. self.gradient_checkpointing = False
  307. def forward(
  308. self,
  309. hidden_states,
  310. attention_mask=None,
  311. output_attentions=False,
  312. output_hidden_states=False,
  313. return_dict=True,
  314. ):
  315. all_hidden_states = () if output_hidden_states else None
  316. all_self_attentions = () if output_attentions else None
  317. if attention_mask is not None:
  318. # make sure padded tokens are not attended to
  319. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  320. hidden_states[~expand_attention_mask] = 0
  321. position_embeddings = self.pos_conv_embed(hidden_states)
  322. hidden_states = hidden_states + position_embeddings
  323. hidden_states = self.dropout(hidden_states)
  324. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  325. position_bias = None
  326. for i, layer in enumerate(self.layers):
  327. if output_hidden_states:
  328. all_hidden_states = all_hidden_states + (hidden_states,)
  329. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  330. dropout_probability = torch.rand([])
  331. skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
  332. if not skip_the_layer or synced_gpus:
  333. # under fsdp or deepspeed zero3 all gpus must run in sync
  334. # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
  335. layer_outputs = layer(
  336. hidden_states,
  337. attention_mask=attention_mask,
  338. output_attentions=output_attentions,
  339. position_bias=position_bias,
  340. )
  341. hidden_states, position_bias = layer_outputs[:2]
  342. if skip_the_layer:
  343. layer_outputs = (None, None, None)
  344. if output_attentions:
  345. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  346. hidden_states = self.layer_norm(hidden_states)
  347. if output_hidden_states:
  348. all_hidden_states = all_hidden_states + (hidden_states,)
  349. if not return_dict:
  350. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  351. return BaseModelOutput(
  352. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
  353. )
  354. class WavLMGumbelVectorQuantizer(nn.Module):
  355. """
  356. Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
  357. GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
  358. """
  359. def __init__(self, config):
  360. super().__init__()
  361. self.num_groups = config.num_codevector_groups
  362. self.num_vars = config.num_codevectors_per_group
  363. if config.codevector_dim % self.num_groups != 0:
  364. raise ValueError(
  365. f"`config.codevector_dim {config.codevector_dim} must be divisible"
  366. f" by `config.num_codevector_groups` {self.num_groups} "
  367. "for concatenation."
  368. )
  369. # storage for codebook variables (codewords)
  370. self.codevectors = nn.Parameter(
  371. torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
  372. )
  373. self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
  374. # can be decayed for training
  375. self.temperature = 2
  376. @staticmethod
  377. def _compute_perplexity(probs):
  378. marginal_probs = probs.mean(dim=0)
  379. perplexity = torch.exp(-torch.sum(torch.xlogy(marginal_probs, marginal_probs), dim=-1)).sum()
  380. return perplexity
  381. def forward(self, hidden_states):
  382. batch_size, sequence_length, hidden_size = hidden_states.shape
  383. # project to codevector dim
  384. hidden_states = self.weight_proj(hidden_states)
  385. hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
  386. if self.training:
  387. # sample code vector probs via gumbel in differentiateable way
  388. codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True)
  389. codevector_probs = codevector_probs.type_as(hidden_states)
  390. # compute perplexity
  391. codevector_soft_dist = torch.softmax(
  392. hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
  393. )
  394. perplexity = self._compute_perplexity(codevector_soft_dist)
  395. else:
  396. # take argmax in non-differentiable way
  397. # comptute hard codevector distribution (one hot)
  398. codevector_idx = hidden_states.argmax(dim=-1)
  399. codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
  400. -1, codevector_idx.view(-1, 1), 1.0
  401. )
  402. codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
  403. perplexity = self._compute_perplexity(codevector_probs)
  404. codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
  405. # use probs to retrieve codevectors
  406. codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
  407. codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
  408. codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
  409. return codevectors, perplexity
  410. class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
  411. config: WavLMConfig
  412. base_model_prefix = "wavlm"
  413. main_input_name = "input_values"
  414. input_modalities = "audio"
  415. supports_gradient_checkpointing = True
  416. _supports_flash_attn = False
  417. _supports_sdpa = False
  418. _supports_flex_attn = False
  419. @torch.no_grad()
  420. def _init_weights(self, module):
  421. """Initialize the weights"""
  422. # gumbel softmax requires special init
  423. if isinstance(module, WavLMGumbelVectorQuantizer):
  424. init.normal_(module.weight_proj.weight, mean=0.0, std=1)
  425. init.zeros_(module.weight_proj.bias)
  426. init.uniform_(module.codevectors)
  427. elif isinstance(module, WavLMPositionalConvEmbedding):
  428. init.normal_(
  429. module.conv.weight,
  430. mean=0,
  431. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  432. )
  433. init.constant_(module.conv.bias, 0)
  434. elif isinstance(module, WavLMFeatureProjection):
  435. k = math.sqrt(1 / module.projection.in_features)
  436. init.uniform_(module.projection.weight, a=-k, b=k)
  437. init.uniform_(module.projection.bias, a=-k, b=k)
  438. elif isinstance(module, nn.Linear):
  439. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  440. if module.bias is not None:
  441. init.zeros_(module.bias)
  442. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  443. init.zeros_(module.bias)
  444. init.ones_(module.weight)
  445. elif isinstance(module, nn.Conv1d):
  446. init.kaiming_normal_(module.weight)
  447. if module.bias is not None:
  448. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  449. init.uniform_(module.bias, a=-k, b=k)
  450. def _get_adapters(self):
  451. raise AttributeError("Not needed for WavLM")
  452. def init_adapter_layers(self):
  453. raise AttributeError("Not needed for WavLM")
  454. def load_adapter(self):
  455. raise AttributeError("Not needed for WavLM")
# Alias: the WavLM base model output is the Wav2Vec2 base model output class itself.
WavLMBaseModelOutput = Wav2Vec2BaseModelOutput
# WavLM name for Wav2Vec2Model; nothing is overridden here.
class WavLMModel(Wav2Vec2Model):
    pass
# WavLM name for Wav2Vec2ForCTC; nothing is overridden here.
class WavLMForCTC(Wav2Vec2ForCTC):
    pass
# WavLM name for Wav2Vec2ForSequenceClassification; nothing is overridden here.
class WavLMForSequenceClassification(Wav2Vec2ForSequenceClassification):
    pass
# WavLM name for Wav2Vec2ForAudioFrameClassification; nothing is overridden here.
class WavLMForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
    pass
# WavLM name for Wav2Vec2ForXVector; nothing is overridden here.
class WavLMForXVector(Wav2Vec2ForXVector):
    pass
# Names exported by `from <this module> import *`: the task heads, the base model and the
# pretrained base class. The encoder, attention and quantizer classes are not listed.
__all__ = [
    "WavLMForAudioFrameClassification",
    "WavLMForCTC",
    "WavLMForSequenceClassification",
    "WavLMForXVector",
    "WavLMModel",
    "WavLMPreTrainedModel",
]