modeling_whisper.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359
  1. # Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Whisper model."""
  15. import math
  16. from collections.abc import Callable
  17. import numpy as np
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...masking_utils import create_causal_mask
  26. from ...modeling_flash_attention_utils import (
  27. FlashAttentionKwargs,
  28. )
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import (
  31. BaseModelOutput,
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. CausalLMOutputWithCrossAttentions,
  34. Seq2SeqLMOutput,
  35. Seq2SeqModelOutput,
  36. SequenceClassifierOutput,
  37. )
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  41. from ...utils.generic import merge_with_config_defaults
  42. from ...utils.output_capturing import OutputRecorder, capture_outputs
  43. from .configuration_whisper import WhisperConfig
  44. from .generation_whisper import WhisperGenerationMixin
# Module-level logger (used below, e.g. for the missing-`layer_idx` warning in `WhisperAttention.__init__`).
logger = logging.get_logger(__name__)
# NOTE(review): not referenced in the visible part of this file; presumably the index in a model-output tuple
# at which hidden states start (e.g. for an audio-classification head) — verify against its callers.
_HIDDEN_STATES_START_POSITION = 1
  47. def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
  48. """Returns sinusoids for positional embedding"""
  49. if channels % 2 != 0:
  50. raise ValueError(
  51. f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
  52. )
  53. log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
  54. inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
  55. scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
  56. return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
  57. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  58. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  59. """
  60. Shift input ids one token to the right.
  61. """
  62. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  63. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  64. shifted_input_ids[:, 0] = decoder_start_token_id
  65. if pad_token_id is None:
  66. raise ValueError("self.model.config.pad_token_id has to be defined.")
  67. # replace possible -100 values in labels by `pad_token_id`
  68. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  69. return shifted_input_ids
# Originally copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices (comments extended here)
def _compute_mask_indices(
    shape: tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: torch.LongTensor | None = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.

    Returns:
        `np.ndarray` of dtype `bool` and shape `shape`, where `True` marks a masked position. All randomness
        comes from the global NumPy RNG (`np.random`), so seed it for reproducible masks.

    Raises:
        ValueError: If `mask_length` is smaller than 1 or larger than `shape[1]`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding: for non-negative x, int(x + U[0, 1)) rounds x up with
    # probability equal to its fractional part. A single draw is shared by every row of the batch.
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.detach().sum(-1).tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    # Upper bound on spans per row; every row's index list is padded up to this length below.
    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        # (distinct span start positions, each leaving room for a full span inside `input_length`)
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    # Shape: (batch_size, max_num_masked_span) span start indices.
    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    # (only the `sequence_length - 1` dummy index plus its offsets can overflow)
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask
  166. class WhisperPositionalEmbedding(nn.Embedding):
  167. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None):
  168. super().__init__(num_positions, embedding_dim)
  169. def forward(self, input_ids, past_key_values_length=0, position_ids=None):
  170. if position_ids is None:
  171. return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
  172. else:
  173. return self.weight[position_ids]
  174. def eager_attention_forward(
  175. module: nn.Module,
  176. query: torch.Tensor,
  177. key: torch.Tensor,
  178. value: torch.Tensor,
  179. attention_mask: torch.Tensor | None,
  180. scaling: float | None = None,
  181. dropout: float = 0.0,
  182. **kwargs,
  183. ):
  184. if scaling is None:
  185. scaling = query.size(-1) ** -0.5
  186. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  187. if attention_mask is not None:
  188. attn_weights = attn_weights + attention_mask
  189. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  190. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  191. attn_output = torch.matmul(attn_weights, value)
  192. attn_output = attn_output.transpose(1, 2).contiguous()
  193. return attn_output, attn_weights
  194. class WhisperAttention(nn.Module):
  195. """Multi-headed attention from 'Attention Is All You Need' paper"""
  196. def __init__(
  197. self,
  198. embed_dim: int,
  199. num_heads: int,
  200. dropout: float = 0.0,
  201. is_decoder: bool = False,
  202. bias: bool = True,
  203. is_causal: bool = False,
  204. layer_idx: int | None = None,
  205. config: WhisperConfig | None = None,
  206. ):
  207. super().__init__()
  208. self.embed_dim = embed_dim
  209. self.num_heads = num_heads
  210. self.dropout = dropout
  211. self.head_dim = embed_dim // num_heads
  212. self.config = config
  213. if (self.head_dim * num_heads) != self.embed_dim:
  214. raise ValueError(
  215. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  216. f" and `num_heads`: {num_heads})."
  217. )
  218. self.scaling = self.head_dim**-0.5
  219. self.is_decoder = is_decoder
  220. self.is_causal = is_causal
  221. if layer_idx is None and is_decoder:
  222. logger.warning_once(
  223. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  224. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  225. "when creating this class."
  226. )
  227. self.layer_idx = layer_idx
  228. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
  229. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  230. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  231. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  232. def forward(
  233. self,
  234. hidden_states: torch.Tensor,
  235. key_value_states: torch.Tensor | None = None,
  236. past_key_values: Cache | None = None,
  237. attention_mask: torch.Tensor | None = None,
  238. output_attentions: bool = False,
  239. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  240. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  241. **kwargs: Unpack[FlashAttentionKwargs],
  242. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  243. """Input shape: Batch x Time x Channel"""
  244. # if key_value_states are provided this layer is used as a cross-attention layer
  245. # for the decoder
  246. is_cross_attention = key_value_states is not None
  247. input_shape = hidden_states.shape[:-1]
  248. hidden_shape = (*input_shape, -1, self.head_dim)
  249. # Scaling is susceptible to floating point arithmetics' inprecisions
  250. # which can lead to different results (this is dependent from model
  251. # to model, e.g. whisper is one such case). We therefore keep the
  252. # original order of scaling to follow the original implementation
  253. # and enforce no scaling (1.0) in the attention call below.
  254. query_states = (self.q_proj(hidden_states) * self.scaling).view(hidden_shape).transpose(1, 2).contiguous()
  255. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  256. if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
  257. is_updated = past_key_values.is_updated.get(self.layer_idx)
  258. if is_cross_attention:
  259. # after the first generated id, we can subsequently re-use all key/value_states from cache
  260. past_key_values.is_updated[self.layer_idx] = True
  261. past_key_values = past_key_values.cross_attention_cache
  262. else:
  263. past_key_values = past_key_values.self_attention_cache
  264. # use key_value_states if cross attention
  265. current_states = key_value_states if key_value_states is not None else hidden_states
  266. if is_cross_attention and past_key_values and is_updated:
  267. # reuse k,v, cross_attentions
  268. key_states = past_key_values.layers[self.layer_idx].keys
  269. value_states = past_key_values.layers[self.layer_idx].values
  270. else:
  271. # Use the query's batch dimension for kv view so that a different-batch
  272. # encoder output (e.g. in tests) gets absorbed into the sequence axis,
  273. # preserving backward-compatible behaviour.
  274. kv_shape = (input_shape[0], -1, self.num_heads, self.head_dim)
  275. key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
  276. value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
  277. if past_key_values is not None:
  278. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  279. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  280. self.config._attn_implementation, eager_attention_forward
  281. )
  282. attn_output, attn_weights = attention_interface(
  283. self,
  284. query_states,
  285. key_states,
  286. value_states,
  287. attention_mask,
  288. dropout=0.0 if not self.training else self.dropout,
  289. scaling=1.0,
  290. output_attentions=output_attentions,
  291. **kwargs,
  292. )
  293. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  294. attn_output = self.out_proj(attn_output)
  295. return attn_output, attn_weights
  296. # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
  297. class WhisperEncoderLayer(GradientCheckpointingLayer):
  298. def __init__(self, config: WhisperConfig):
  299. super().__init__()
  300. self.embed_dim = config.d_model
  301. self.self_attn = WhisperAttention(
  302. embed_dim=self.embed_dim,
  303. num_heads=config.encoder_attention_heads,
  304. dropout=config.attention_dropout,
  305. config=config,
  306. )
  307. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  308. self.dropout = config.dropout
  309. self.activation_fn = ACT2FN[config.activation_function]
  310. self.activation_dropout = config.activation_dropout
  311. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  312. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  313. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  314. def forward(
  315. self,
  316. hidden_states: torch.Tensor,
  317. attention_mask: torch.Tensor,
  318. **kwargs: Unpack[TransformersKwargs],
  319. ) -> torch.Tensor:
  320. """
  321. Args:
  322. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  323. attention_mask (`torch.FloatTensor`): attention mask of size
  324. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  325. """
  326. residual = hidden_states
  327. hidden_states = self.self_attn_layer_norm(hidden_states)
  328. hidden_states, _ = self.self_attn(
  329. hidden_states=hidden_states,
  330. attention_mask=attention_mask,
  331. **kwargs,
  332. )
  333. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  334. hidden_states = residual + hidden_states
  335. residual = hidden_states
  336. hidden_states = self.final_layer_norm(hidden_states)
  337. hidden_states = self.activation_fn(self.fc1(hidden_states))
  338. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  339. hidden_states = self.fc2(hidden_states)
  340. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  341. hidden_states = residual + hidden_states
  342. if hidden_states.dtype == torch.float16:
  343. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  344. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  345. return hidden_states
  346. class WhisperDecoderLayer(GradientCheckpointingLayer):
  347. def __init__(self, config: WhisperConfig, layer_idx: int | None = None):
  348. super().__init__()
  349. self.embed_dim = config.d_model
  350. self.self_attn = WhisperAttention(
  351. embed_dim=self.embed_dim,
  352. num_heads=config.decoder_attention_heads,
  353. dropout=config.attention_dropout,
  354. is_decoder=True,
  355. is_causal=True,
  356. layer_idx=layer_idx,
  357. config=config,
  358. )
  359. self.dropout = config.dropout
  360. self.activation_fn = ACT2FN[config.activation_function]
  361. self.activation_dropout = config.activation_dropout
  362. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  363. self.encoder_attn = WhisperAttention(
  364. self.embed_dim,
  365. config.decoder_attention_heads,
  366. dropout=config.attention_dropout,
  367. is_decoder=True,
  368. layer_idx=layer_idx,
  369. config=config,
  370. )
  371. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  372. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  373. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  374. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  375. def forward(
  376. self,
  377. hidden_states: torch.Tensor,
  378. attention_mask: torch.Tensor | None = None,
  379. encoder_hidden_states: torch.Tensor | None = None,
  380. encoder_attention_mask: torch.Tensor | None = None,
  381. past_key_values: EncoderDecoderCache | None = None,
  382. use_cache: bool | None = True,
  383. **kwargs: Unpack[TransformersKwargs],
  384. ) -> torch.Tensor:
  385. """
  386. Args:
  387. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  388. attention_mask (`torch.FloatTensor`): attention mask of size
  389. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  390. encoder_hidden_states (`torch.FloatTensor`):
  391. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  392. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  393. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  394. past_key_values (`Cache`): cached past key and value projection states
  395. """
  396. residual = hidden_states
  397. hidden_states = self.self_attn_layer_norm(hidden_states)
  398. # Self Attention
  399. hidden_states, _ = self.self_attn(
  400. hidden_states,
  401. past_key_values=past_key_values,
  402. attention_mask=attention_mask,
  403. **kwargs,
  404. )
  405. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  406. hidden_states = residual + hidden_states
  407. # Cross-Attention Block
  408. if encoder_hidden_states is not None:
  409. residual = hidden_states
  410. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  411. hidden_states, _ = self.encoder_attn(
  412. hidden_states,
  413. key_value_states=encoder_hidden_states,
  414. attention_mask=encoder_attention_mask,
  415. past_key_values=past_key_values,
  416. **kwargs,
  417. )
  418. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  419. hidden_states = residual + hidden_states
  420. # Fully Connected
  421. residual = hidden_states
  422. hidden_states = self.final_layer_norm(hidden_states)
  423. hidden_states = self.activation_fn(self.fc1(hidden_states))
  424. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  425. hidden_states = self.fc2(hidden_states)
  426. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  427. hidden_states = residual + hidden_states
  428. return hidden_states
  429. @auto_docstring
  430. class WhisperPreTrainedModel(PreTrainedModel):
  431. config: WhisperConfig
  432. base_model_prefix = "model"
  433. main_input_name = "input_features"
  434. input_modalities = ("audio", "text")
  435. supports_gradient_checkpointing = True
  436. _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
  437. _supports_flash_attn = True
  438. _supports_sdpa = True
  439. _supports_flex_attn = True
  440. _can_compile_fullgraph = True
  441. @torch.no_grad()
  442. def _init_weights(self, module):
  443. super()._init_weights(module)
  444. if isinstance(module, WhisperEncoder):
  445. init.copy_(module.embed_positions.weight, sinusoids(*module.embed_positions.weight.shape))
  446. elif isinstance(module, WhisperForAudioClassification):
  447. if self.config.use_weighted_layer_sum:
  448. init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))
  449. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
  450. """
  451. Computes the output length of the convolutional layers
  452. """
  453. input_lengths = (input_lengths - 1) // 2 + 1
  454. return input_lengths
  455. class WhisperEncoder(WhisperPreTrainedModel):
  456. """
  457. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  458. [`WhisperEncoderLayer`].
  459. Args:
  460. config: WhisperConfig
  461. """
  462. _can_record_outputs = {
  463. "hidden_states": WhisperEncoderLayer,
  464. "attentions": WhisperAttention,
  465. }
  466. input_modalities = ("audio",)
  467. def __init__(self, config: WhisperConfig):
  468. super().__init__(config)
  469. self.dropout = config.dropout
  470. self.layerdrop = config.encoder_layerdrop
  471. embed_dim = config.d_model
  472. self.num_mel_bins = config.num_mel_bins
  473. self.padding_idx = config.pad_token_id
  474. self.max_source_positions = config.max_source_positions
  475. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  476. self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
  477. self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
  478. self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
  479. self.embed_positions.requires_grad_(False)
  480. self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
  481. self.layer_norm = nn.LayerNorm(config.d_model)
  482. self.gradient_checkpointing = False
  483. # Initialize weights and apply final processing
  484. self.post_init()
  485. def _freeze_parameters(self):
  486. for param in self.parameters():
  487. param.requires_grad = False
  488. self._requires_grad = False
  489. def get_input_embeddings(self) -> nn.Module:
  490. return self.conv1
  491. def set_input_embeddings(self, value: nn.Module):
  492. self.conv1 = value
  493. @merge_with_config_defaults
  494. @capture_outputs
  495. def forward(
  496. self,
  497. input_features,
  498. attention_mask=None,
  499. **kwargs: Unpack[TransformersKwargs],
  500. ) -> BaseModelOutput:
  501. r"""
  502. Args:
  503. input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
  504. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  505. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  506. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  507. the soundfile library (`pip install soundfile`). To prepare the array into
  508. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  509. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  510. attention_mask (`torch.Tensor`)`, *optional*):
  511. Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
  512. but it is not used. By default the silence in the input log mel spectrogram are ignored.
  513. """
  514. expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
  515. if input_features.shape[-1] != expected_seq_length:
  516. raise ValueError(
  517. f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
  518. )
  519. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  520. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  521. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  522. all_positions = torch.arange(self.embed_positions.num_embeddings, device=inputs_embeds.device)
  523. hidden_states = inputs_embeds + self.embed_positions(all_positions)
  524. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  525. for idx, encoder_layer in enumerate(self.layers):
  526. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  527. to_drop = False
  528. if self.training:
  529. dropout_probability = torch.rand([])
  530. if dropout_probability < self.layerdrop: # skip the layer
  531. to_drop = True
  532. if not to_drop:
  533. hidden_states = encoder_layer(
  534. hidden_states,
  535. None,
  536. **kwargs,
  537. )
  538. hidden_states = self.layer_norm(hidden_states)
  539. return BaseModelOutput(
  540. last_hidden_state=hidden_states,
  541. )
  542. class WhisperDecoder(WhisperPreTrainedModel):
  543. """
  544. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`]
  545. Args:
  546. config: WhisperConfig
  547. """
  548. _can_record_outputs = {
  549. "hidden_states": WhisperDecoderLayer,
  550. "attentions": OutputRecorder(WhisperAttention, index=1, layer_name="self_attn"),
  551. "cross_attentions": OutputRecorder(WhisperAttention, index=1, layer_name="encoder_attn"),
  552. }
  553. main_input_name = "input_ids"
  554. input_modalities = ("text",)
  555. def __init__(self, config: WhisperConfig):
  556. super().__init__(config)
  557. self.dropout = config.dropout
  558. self.layerdrop = config.decoder_layerdrop
  559. self.padding_idx = config.pad_token_id
  560. self.max_target_positions = config.max_target_positions
  561. self.max_source_positions = config.max_source_positions
  562. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  563. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
  564. self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
  565. self.layers = nn.ModuleList(
  566. [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
  567. )
  568. self.layer_norm = nn.LayerNorm(config.d_model)
  569. self.gradient_checkpointing = False
  570. # Initialize weights and apply final processing
  571. self.post_init()
  572. @merge_with_config_defaults
  573. @capture_outputs
  574. def forward(
  575. self,
  576. input_ids=None,
  577. attention_mask=None,
  578. encoder_hidden_states=None,
  579. past_key_values=None,
  580. inputs_embeds=None,
  581. position_ids=None,
  582. use_cache=None,
  583. **kwargs: Unpack[TransformersKwargs],
  584. ) -> BaseModelOutputWithPastAndCrossAttentions:
  585. r"""
  586. Args:
  587. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  588. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  589. provide it.
  590. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  591. [`PreTrainedTokenizer.__call__`] for details.
  592. [What are input IDs?](../glossary#input-ids)
  593. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  594. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  595. - 1 for tokens that are **not masked**,
  596. - 0 for tokens that are **masked**.
  597. [What are attention masks?](../glossary#attention-mask)
  598. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  599. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  600. of the decoder.
  601. past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  602. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  603. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  604. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  605. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  606. inputs_embeds (`torch.FloatTensor` of
  607. shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
  608. `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
  609. control over how to convert `input_ids` indices into associated vectors than the model's internal
  610. embedding lookup matrix.
  611. """
  612. if (input_ids is None) ^ (inputs_embeds is not None):
  613. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  614. if inputs_embeds is None:
  615. inputs_embeds = self.embed_tokens(input_ids)
  616. if use_cache and past_key_values is None:
  617. past_key_values = (
  618. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  619. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  620. else DynamicCache(config=self.config)
  621. )
  622. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  623. if position_ids is None:
  624. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_key_values_length
  625. position_ids = position_ids.unsqueeze(0).repeat(inputs_embeds.shape[0], 1)
  626. # embed positions
  627. if input_ids is not None:
  628. positions = self.embed_positions(
  629. input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
  630. )
  631. else:
  632. positions = self.embed_positions(
  633. inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
  634. )
  635. hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
  636. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  637. causal_mask = create_causal_mask(
  638. config=self.config,
  639. inputs_embeds=inputs_embeds,
  640. attention_mask=attention_mask,
  641. past_key_values=past_key_values,
  642. position_ids=position_ids,
  643. )
  644. for idx, decoder_layer in enumerate(self.layers):
  645. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  646. if self.training:
  647. dropout_probability = torch.rand([])
  648. if dropout_probability < self.layerdrop:
  649. continue
  650. hidden_states = decoder_layer(
  651. hidden_states,
  652. causal_mask,
  653. encoder_hidden_states,
  654. encoder_attention_mask=None,
  655. past_key_values=past_key_values if use_cache else None,
  656. use_cache=use_cache,
  657. **kwargs,
  658. )
  659. hidden_states = self.layer_norm(hidden_states)
  660. return BaseModelOutputWithPastAndCrossAttentions(
  661. last_hidden_state=hidden_states,
  662. past_key_values=past_key_values,
  663. )
  664. @auto_docstring
  665. class WhisperModel(WhisperPreTrainedModel):
  666. def __init__(self, config: WhisperConfig):
  667. super().__init__(config)
  668. self.encoder = WhisperEncoder(config)
  669. self.decoder = WhisperDecoder(config)
  670. # Initialize weights and apply final processing
  671. self.post_init()
  672. def get_input_embeddings(self):
  673. return self.decoder.embed_tokens
  674. def set_input_embeddings(self, value):
  675. self.decoder.embed_tokens = value
  676. def freeze_encoder(self):
  677. """
  678. Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
  679. not be updated during training.
  680. """
  681. self.encoder._freeze_parameters()
  682. def _mask_input_features(
  683. self,
  684. input_features: torch.FloatTensor,
  685. attention_mask: torch.LongTensor | None = None,
  686. ):
  687. """
  688. Masks extracted features along time axis and/or along feature axis according to
  689. [SpecAugment](https://huggingface.co/papers/1904.08779).
  690. """
  691. # `config.apply_spec_augment` can set masking to False
  692. if not getattr(self.config, "apply_spec_augment", True):
  693. return input_features
  694. # generate indices & apply SpecAugment along time axis
  695. batch_size, hidden_size, sequence_length = input_features.size()
  696. if self.config.mask_time_prob > 0 and self.training:
  697. # generate indices & apply SpecAugment along time axis
  698. mask_time_indices = _compute_mask_indices(
  699. (batch_size, sequence_length),
  700. mask_prob=self.config.mask_time_prob,
  701. mask_length=self.config.mask_time_length,
  702. attention_mask=attention_mask,
  703. min_masks=self.config.mask_time_min_masks,
  704. )
  705. mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
  706. mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
  707. input_features[mask_time_indices] = 0
  708. if self.config.mask_feature_prob > 0 and self.training:
  709. # generate indices & apply SpecAugment along feature axis
  710. mask_feature_indices = _compute_mask_indices(
  711. (batch_size, hidden_size),
  712. mask_prob=self.config.mask_feature_prob,
  713. mask_length=self.config.mask_feature_length,
  714. min_masks=self.config.mask_feature_min_masks,
  715. )
  716. mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
  717. input_features[mask_feature_indices] = 0
  718. return input_features
  719. @can_return_tuple
  720. @auto_docstring
  721. def forward(
  722. self,
  723. input_features: torch.FloatTensor | None = None,
  724. attention_mask: torch.LongTensor | None = None,
  725. decoder_input_ids: torch.LongTensor | None = None,
  726. decoder_attention_mask: torch.LongTensor | None = None,
  727. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  728. past_key_values: Cache | None = None,
  729. decoder_inputs_embeds: tuple[torch.FloatTensor] | None = None,
  730. decoder_position_ids: tuple[torch.LongTensor] | None = None,
  731. use_cache: bool | None = None,
  732. **kwargs,
  733. ) -> tuple[torch.Tensor] | Seq2SeqModelOutput:
  734. r"""
  735. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  736. Indices of decoder input sequence tokens in the vocabulary.
  737. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  738. [`PreTrainedTokenizer.__call__`] for details.
  739. [What are decoder input IDs?](../glossary#decoder-input-ids)
  740. Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
  741. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  742. `past_key_values`).
  743. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  744. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  745. be used by default.
  746. If you want to change padding behavior, you should read
  747. [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
  748. paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
  749. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  750. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  751. config.n_positions - 1]`.
  752. [What are position IDs?](../glossary#position-ids)
  753. Example:
  754. ```python
  755. >>> import torch
  756. >>> from transformers import AutoFeatureExtractor, WhisperModel
  757. >>> from datasets import load_dataset
  758. >>> model = WhisperModel.from_pretrained("openai/whisper-base")
  759. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
  760. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  761. >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
  762. >>> input_features = inputs.input_features
  763. >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
  764. >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
  765. >>> list(last_hidden_state.shape)
  766. [1, 2, 512]
  767. ```"""
  768. if encoder_outputs is None:
  769. input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
  770. encoder_outputs = self.encoder(
  771. input_features,
  772. **kwargs,
  773. )
  774. elif not isinstance(encoder_outputs, BaseModelOutput):
  775. encoder_outputs = BaseModelOutput(
  776. last_hidden_state=encoder_outputs[0],
  777. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  778. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  779. )
  780. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  781. decoder_outputs = self.decoder(
  782. input_ids=decoder_input_ids,
  783. attention_mask=decoder_attention_mask,
  784. encoder_hidden_states=encoder_outputs[0],
  785. past_key_values=past_key_values,
  786. inputs_embeds=decoder_inputs_embeds,
  787. position_ids=decoder_position_ids,
  788. use_cache=use_cache,
  789. **kwargs,
  790. )
  791. return Seq2SeqModelOutput(
  792. last_hidden_state=decoder_outputs.last_hidden_state,
  793. past_key_values=decoder_outputs.past_key_values,
  794. decoder_hidden_states=decoder_outputs.hidden_states,
  795. decoder_attentions=decoder_outputs.attentions,
  796. cross_attentions=decoder_outputs.cross_attentions,
  797. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  798. encoder_hidden_states=encoder_outputs.hidden_states,
  799. encoder_attentions=encoder_outputs.attentions,
  800. )
  801. @auto_docstring(
  802. custom_intro="""
  803. The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
  804. """
  805. )
  806. class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
  807. base_model_prefix = "model"
  808. _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"}
  809. def __init__(self, config: WhisperConfig):
  810. super().__init__(config)
  811. self.model = WhisperModel(config)
  812. self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
  813. self.max_target_positions = config.max_target_positions
  814. # Initialize weights and apply final processing
  815. self.post_init()
  816. def get_output_embeddings(self):
  817. return self.proj_out
  818. def set_output_embeddings(self, new_embeddings):
  819. self.proj_out = new_embeddings
  820. def get_input_embeddings(self) -> nn.Module:
  821. return self.model.get_input_embeddings()
  822. def freeze_encoder(self):
  823. """
  824. Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
  825. not be updated during training.
  826. """
  827. self.model.encoder._freeze_parameters()
  828. @can_return_tuple
  829. @auto_docstring
  830. def forward(
  831. self,
  832. input_features: torch.FloatTensor | None = None,
  833. attention_mask: torch.LongTensor | None = None,
  834. decoder_input_ids: torch.LongTensor | None = None,
  835. decoder_attention_mask: torch.LongTensor | None = None,
  836. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  837. past_key_values: Cache | None = None,
  838. decoder_inputs_embeds: tuple[torch.FloatTensor] | None = None,
  839. decoder_position_ids: tuple[torch.LongTensor] | None = None,
  840. labels: torch.LongTensor | None = None,
  841. use_cache: bool | None = None,
  842. **kwargs,
  843. ) -> tuple[torch.Tensor] | Seq2SeqLMOutput:
  844. r"""
  845. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  846. Indices of decoder input sequence tokens in the vocabulary.
  847. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  848. [`PreTrainedTokenizer.__call__`] for details.
  849. [What are decoder input IDs?](../glossary#decoder-input-ids)
  850. Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
  851. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  852. `past_key_values`).
  853. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  854. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  855. be used by default.
  856. If you want to change padding behavior, you should read
  857. [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
  858. paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
  859. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  860. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  861. config.n_positions - 1]`.
  862. [What are position IDs?](../glossary#position-ids)
  863. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  864. Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
  865. or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
  866. only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`.
  867. Example:
  868. ```python
  869. >>> import torch
  870. >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
  871. >>> from datasets import load_dataset
  872. >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
  873. >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
  874. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  875. >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
  876. >>> input_features = inputs.input_features
  877. >>> generated_ids = model.generate(inputs=input_features)
  878. >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  879. >>> transcription
  880. ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
  881. ```"""
  882. if labels is not None:
  883. if labels.shape[1] > self.max_target_positions:
  884. raise ValueError(
  885. f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
  886. )
  887. if decoder_input_ids is None and decoder_inputs_embeds is None:
  888. decoder_input_ids = shift_tokens_right(
  889. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  890. )
  891. outputs: Seq2SeqModelOutput = self.model(
  892. input_features,
  893. attention_mask=attention_mask,
  894. decoder_input_ids=decoder_input_ids,
  895. encoder_outputs=encoder_outputs,
  896. decoder_attention_mask=decoder_attention_mask,
  897. past_key_values=past_key_values,
  898. decoder_inputs_embeds=decoder_inputs_embeds,
  899. decoder_position_ids=decoder_position_ids,
  900. use_cache=use_cache,
  901. **kwargs,
  902. )
  903. lm_logits = self.proj_out(outputs.last_hidden_state)
  904. loss = None
  905. if labels is not None:
  906. loss_fct = CrossEntropyLoss()
  907. # move labels to correct device to enable PP
  908. labels = labels.to(lm_logits.device)
  909. loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
  910. return Seq2SeqLMOutput(
  911. loss=loss,
  912. logits=lm_logits,
  913. past_key_values=outputs.past_key_values,
  914. decoder_hidden_states=outputs.decoder_hidden_states,
  915. decoder_attentions=outputs.decoder_attentions,
  916. cross_attentions=outputs.cross_attentions,
  917. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  918. encoder_hidden_states=outputs.encoder_hidden_states,
  919. encoder_attentions=outputs.encoder_attentions,
  920. )
  921. class WhisperDecoderWrapper(WhisperPreTrainedModel):
  922. """
  923. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  924. used in combination with the [`EncoderDecoderModel`] framework.
  925. """
  926. def __init__(self, config):
  927. super().__init__(config)
  928. config.is_encoder_decoder = False
  929. self.decoder = WhisperDecoder(config)
  930. self.post_init()
  931. def get_input_embeddings(self):
  932. return self.decoder.embed_tokens
  933. def set_input_embeddings(self, value):
  934. self.decoder.embed_tokens = value
  935. def forward(self, *args, **kwargs):
  936. return self.decoder(*args, **kwargs)
  937. @auto_docstring(
  938. custom_intro="""
  939. Whisper decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
  940. """
  941. )
  942. class WhisperForCausalLM(WhisperPreTrainedModel, GenerationMixin):
  943. _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"}
  944. main_input_name = "input_ids"
  945. def __init__(self, config):
  946. super().__init__(config)
  947. config.is_encoder_decoder = False
  948. self.model = WhisperDecoderWrapper(config)
  949. self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  950. # Initialize weights and apply final processing
  951. self.post_init()
  952. def get_output_embeddings(self):
  953. return self.proj_out
  954. def set_output_embeddings(self, new_embeddings):
  955. self.proj_out = new_embeddings
  956. def get_input_embeddings(self) -> nn.Module:
  957. return self.model.get_input_embeddings()
  958. def set_input_embeddings(self, value):
  959. self.model.set_input_embeddings(value)
  960. @can_return_tuple
  961. @auto_docstring
  962. def forward(
  963. self,
  964. input_ids: torch.LongTensor | None = None,
  965. attention_mask: torch.Tensor | None = None,
  966. encoder_outputs: tuple[torch.FloatTensor] | None = None,
  967. past_key_values: Cache | None = None,
  968. inputs_embeds: torch.FloatTensor | None = None,
  969. labels: torch.LongTensor | None = None,
  970. use_cache: bool | None = None,
  971. **kwargs,
  972. ) -> tuple | CausalLMOutputWithCrossAttentions:
  973. r"""
  974. encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  975. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  976. if the model is configured as a decoder.
  977. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  978. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  979. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  980. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  981. Example:
  982. ```python
  983. >>> from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor
  984. >>> import torch
  985. >>> from datasets import load_dataset
  986. >>> processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
  987. >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
  988. >>> assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2")
  989. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  990. >>> sample = ds[0]["audio"]
  991. >>> input_features = processor(
  992. ... sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
  993. ... ).input_features
  994. >>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)
  995. >>> # decode token ids to text
  996. >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
  997. >>> transcription
  998. ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
  999. ```"""
  1000. # If the user passed a tuple or `BaseModelOutput` for encoder_outputs, we extract only the hidden states
  1001. if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)):
  1002. encoder_outputs = encoder_outputs[0]
  1003. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1004. outputs = self.model.decoder(
  1005. input_ids=input_ids,
  1006. attention_mask=attention_mask,
  1007. encoder_hidden_states=encoder_outputs,
  1008. past_key_values=past_key_values,
  1009. inputs_embeds=inputs_embeds,
  1010. use_cache=use_cache,
  1011. **kwargs,
  1012. )
  1013. logits = self.proj_out(outputs[0])
  1014. loss = None
  1015. if labels is not None:
  1016. labels = labels.to(logits.device)
  1017. loss_fct = CrossEntropyLoss()
  1018. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1019. return CausalLMOutputWithCrossAttentions(
  1020. loss=loss,
  1021. logits=logits,
  1022. past_key_values=outputs.past_key_values,
  1023. hidden_states=outputs.hidden_states,
  1024. attentions=outputs.attentions,
  1025. cross_attentions=outputs.cross_attentions,
  1026. )
  1027. @auto_docstring(
  1028. custom_intro="""
  1029. Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
  1030. like SUPERB Keyword Spotting.
  1031. """
  1032. )
  1033. class WhisperForAudioClassification(WhisperPreTrainedModel):
  1034. def __init__(self, config):
  1035. super().__init__(config)
  1036. self.encoder = WhisperEncoder(config)
  1037. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1038. if config.use_weighted_layer_sum:
  1039. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1040. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1041. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1042. # Initialize weights and apply final processing
  1043. self.post_init()
  1044. def freeze_encoder(self):
  1045. """
  1046. Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
  1047. not be updated during training. Only the projection layers and classification head will be updated.
  1048. """
  1049. self.encoder._freeze_parameters()
  1050. def get_input_embeddings(self) -> nn.Module:
  1051. return self.encoder.get_input_embeddings()
  1052. def set_input_embeddings(self, value: nn.Module):
  1053. self.encoder.set_input_embeddings(value)
  1054. @can_return_tuple
  1055. @auto_docstring
  1056. def forward(
  1057. self,
  1058. input_features: torch.LongTensor | None = None,
  1059. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  1060. labels: torch.LongTensor | None = None,
  1061. **kwargs,
  1062. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  1063. r"""
  1064. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1065. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1066. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1067. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1068. Example:
  1069. ```python
  1070. >>> import torch
  1071. >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
  1072. >>> from datasets import load_dataset
  1073. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
  1074. >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
  1075. >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
  1076. >>> sample = next(iter(ds))
  1077. >>> inputs = feature_extractor(
  1078. ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
  1079. ... )
  1080. >>> input_features = inputs.input_features
  1081. >>> with torch.no_grad():
  1082. ... logits = model(input_features).logits
  1083. >>> predicted_class_ids = torch.argmax(logits).item()
  1084. >>> predicted_label = model.config.id2label[predicted_class_ids]
  1085. >>> predicted_label
  1086. 'Afrikaans'
  1087. ```"""
  1088. if self.config.use_weighted_layer_sum:
  1089. kwargs["output_hidden_states"] = True
  1090. if encoder_outputs is None:
  1091. encoder_outputs = self.encoder(
  1092. input_features,
  1093. **kwargs,
  1094. )
  1095. elif not isinstance(encoder_outputs, BaseModelOutput):
  1096. encoder_outputs = BaseModelOutput(
  1097. last_hidden_state=encoder_outputs[0],
  1098. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1099. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1100. )
  1101. if self.config.use_weighted_layer_sum:
  1102. hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
  1103. hidden_states = torch.stack(hidden_states, dim=1)
  1104. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1105. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1106. else:
  1107. hidden_states = encoder_outputs[0]
  1108. hidden_states = self.projector(hidden_states)
  1109. pooled_output = hidden_states.mean(dim=1)
  1110. logits = self.classifier(pooled_output)
  1111. loss = None
  1112. if labels is not None:
  1113. loss_fct = CrossEntropyLoss()
  1114. # move labels to correct device to enable PP
  1115. labels = labels.to(logits.device)
  1116. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1117. return SequenceClassifierOutput(
  1118. loss=loss,
  1119. logits=logits,
  1120. hidden_states=encoder_outputs.hidden_states,
  1121. attentions=encoder_outputs.attentions,
  1122. )
  1123. __all__ = [
  1124. "WhisperForCausalLM",
  1125. "WhisperForConditionalGeneration",
  1126. "WhisperModel",
  1127. "WhisperPreTrainedModel",
  1128. "WhisperForAudioClassification",
  1129. ]