| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359 |
- # Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Whisper model."""
- import math
- from collections.abc import Callable
- import numpy as np
- import torch
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...generation import GenerationMixin
- from ...masking_utils import create_causal_mask
- from ...modeling_flash_attention_utils import (
- FlashAttentionKwargs,
- )
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- CausalLMOutputWithCrossAttentions,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- SequenceClassifierOutput,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from .configuration_whisper import WhisperConfig
- from .generation_whisper import WhisperGenerationMixin
# Module-level logger (used e.g. by `WhisperAttention.__init__` for its `layer_idx` warning).
logger = logging.get_logger(__name__)
# NOTE(review): presumably the index of the hidden states inside a tuple-style model output; it is not
# referenced in the visible part of this file — confirm against its users further down.
_HIDDEN_STATES_START_POSITION = 1
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Returns sinusoids for positional embedding.

    Builds a ``(length, channels)`` table whose first ``channels // 2`` columns are ``sin(t * w_i)``
    and whose last ``channels // 2`` columns are ``cos(t * w_i)``. Here ``t`` is the row (position)
    index and the inverse timescales ``w_i = exp(-i * log(max_timescale) / (channels // 2 - 1))``
    decay geometrically from ``1`` to ``1 / max_timescale``.

    Args:
        length: Number of positions (rows of the table).
        channels: Embedding width (columns); must be even.
        max_timescale: Largest timescale of the geometric progression.

    Returns:
        Float tensor of shape ``(length, channels)``.

    Raises:
        ValueError: If ``channels`` is odd.
    """
    if channels % 2 != 0:
        raise ValueError(
            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
        )
    num_timescales = channels // 2
    # With a single timescale (channels == 2) the spacing denominator would be 0 and the original
    # expression raised ZeroDivisionError; clamping it to 1 yields the lone frequency w_0 = 1.
    # For every other channel count (including 0) the result is unchanged.
    log_timescale_increment = math.log(max_timescale) / max(num_timescales - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(num_timescales))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
- # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    Returns a new tensor (same shape, dtype and device as ``input_ids``). Its first column is
    ``decoder_start_token_id`` and its remaining columns are ``input_ids[:, :-1]``. Any ``-100``
    entries (the label ignore-index) are replaced by ``pad_token_id``. ``input_ids`` is not modified.

    Raises:
        ValueError: If ``pad_token_id`` is ``None``.
    """
    result = input_ids.new_zeros(input_ids.shape)
    # Columns 1.. take the previous token; column 0 becomes the decoder start token.
    result[:, 1:] = input_ids[:, :-1].clone()
    result[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # Labels mark ignored positions with -100; turn those into real pad tokens.
    ignore_positions = result == -100
    result.masked_fill_(ignore_positions, pad_token_id)
    return result
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
    shape: tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: torch.LongTensor | None = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Randomness comes from NumPy's global RNG (`np.random.rand` once, then `np.random.choice` once per batch row),
    so the result depends on, and advances, the global NumPy seed state.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.

    Returns:
        Boolean `np.ndarray` of shape `(batch_size, sequence_length)` in which `True` marks masked positions.

    Raises:
        ValueError: If `mask_length < 1` or `mask_length > sequence_length`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding (one draw shared by every row of the batch)
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch (per-row lengths come from the attention mask, if any)
    input_lengths = (
        attention_mask.detach().sum(-1).tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    # Upper bound on spans per row; every row is padded to this count so the rows can be stacked.
    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask (span start positions, sampled without replacement)
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller then
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask (duplicate indices from the padding above are harmless)
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask
- class WhisperPositionalEmbedding(nn.Embedding):
- def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None):
- super().__init__(num_positions, embedding_dim)
- def forward(self, input_ids, past_key_values_length=0, position_ids=None):
- if position_ids is None:
- return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
- else:
- return self.weight[position_ids]
- def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float | None = None,
- dropout: float = 0.0,
- **kwargs,
- ):
- if scaling is None:
- scaling = query.size(-1) ** -0.5
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
class WhisperAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Within this file it is used for encoder self-attention, decoder (causal) self-attention and
    decoder cross-attention. Cross-attention is selected at call time by passing
    ``key_value_states``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        layer_idx: int | None = None,
        config: WhisperConfig | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        if layer_idx is None and is_decoder:
            # Cache reads/writes in `forward` are keyed by `self.layer_idx`.
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.layer_idx = layer_idx

        # The key projection is always bias-free; the other projections follow `bias`.
        # (Registration order k, v, q, out is kept as-is: it fixes parameter/state-dict ordering.)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool = False,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query-side input, ``(batch, time, embed_dim)``.
            key_value_states: Encoder output for cross-attention; ``None`` means self-attention.
            past_key_values: Optional cache. For an ``EncoderDecoderCache``, the self- or
                cross-attention sub-cache is selected automatically.
            attention_mask: Mask forwarded unchanged to the attention backend.
            output_attentions: Forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the same leading shape as
            ``hidden_states``. ``attn_weights`` is whatever the selected backend returns
            (possibly ``None``).
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Scaling is susceptible to floating point arithmetics' inprecisions
        # which can lead to different results (this is dependent from model
        # to model, e.g. whisper is one such case). We therefore keep the
        # original order of scaling to follow the original implementation
        # and enforce no scaling (1.0) in the attention call below.
        query_states = (self.q_proj(hidden_states) * self.scaling).view(hidden_shape).transpose(1, 2).contiguous()

        # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
        # NOTE(review): `is_updated` is only bound inside this branch. A truthy cache of another type
        # reaching the cross-attention check below would raise UnboundLocalError.
        if isinstance(past_key_values, EncoderDecoderCache):
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                past_key_values.is_updated[self.layer_idx] = True
                past_key_values = past_key_values.cross_attention_cache
            else:
                past_key_values = past_key_values.self_attention_cache

        # use key_value_states if cross attention
        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values and is_updated:
            # reuse k,v, cross_attentions
            key_states = past_key_values.layers[self.layer_idx].keys
            value_states = past_key_values.layers[self.layer_idx].values
        else:
            # Use the query's batch dimension for kv view so that a different-batch
            # encoder output (e.g. in tests) gets absorbed into the sequence axis,
            # preserving backward-compatible behaviour.
            kv_shape = (input_shape[0], -1, self.num_heads, self.head_dim)
            key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
            value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
            if past_key_values is not None:
                key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=1.0,
            output_attentions=output_attentions,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
- # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
class WhisperEncoderLayer(GradientCheckpointingLayer):
    # Pre-LayerNorm encoder block: self-attention then a two-layer feed-forward network,
    # each wrapped in a residual connection.

    def __init__(self, config: WhisperConfig):
        super().__init__()
        width = config.d_model
        self.embed_dim = width
        # Submodules are registered in the same order as before (affects parameter ordering).
        self.self_attn = WhisperAttention(
            embed_dim=width,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(width, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.

        Returns:
            `torch.FloatTensor` with the same shape as `hidden_states`; float16 outputs are clamped to
            `finfo(float16).max - 1000` in magnitude.
        """
        # Self-attention sub-block.
        skip = hidden_states
        normed = self.self_attn_layer_norm(hidden_states)
        attn_out, _ = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = skip + nn.functional.dropout(attn_out, p=self.dropout, training=self.training)

        # Feed-forward sub-block.
        skip = hidden_states
        ffn = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn = nn.functional.dropout(ffn, p=self.activation_dropout, training=self.training)
        ffn = nn.functional.dropout(self.fc2(ffn), p=self.dropout, training=self.training)
        hidden_states = skip + ffn

        # Keep float16 activations away from inf.
        if hidden_states.dtype == torch.float16:
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)
        return hidden_states
class WhisperDecoderLayer(GradientCheckpointingLayer):
    # Pre-LayerNorm decoder block: causal self-attention, optional cross-attention over the
    # encoder output, then a two-layer feed-forward network; each sub-block is residual.

    def __init__(self, config: WhisperConfig, layer_idx: int | None = None):
        super().__init__()
        width = config.d_model
        heads = config.decoder_attention_heads
        self.embed_dim = width
        # Submodules are registered in the same order as before (affects parameter ordering).
        self.self_attn = WhisperAttention(
            embed_dim=width,
            num_heads=heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            layer_idx=layer_idx,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.encoder_attn = WhisperAttention(
            embed_dim=width,
            num_heads=heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            layer_idx=layer_idx,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(width)
        self.fc1 = nn.Linear(width, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        use_cache: bool | None = True,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            past_key_values (`Cache`): cached past key and value projection states, shared by the
                self- and cross-attention modules.
            use_cache: accepted for interface compatibility; not read by this method.

        Returns:
            `torch.FloatTensor` with the same shape as `hidden_states`.
        """
        # Causal self-attention sub-block.
        skip = hidden_states
        attn_out, _ = self.self_attn(
            self.self_attn_layer_norm(hidden_states),
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = skip + nn.functional.dropout(attn_out, p=self.dropout, training=self.training)

        # Cross-attention sub-block, only when an encoder output is supplied.
        if encoder_hidden_states is not None:
            skip = hidden_states
            cross_out, _ = self.encoder_attn(
                self.encoder_attn_layer_norm(hidden_states),
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                **kwargs,
            )
            hidden_states = skip + nn.functional.dropout(cross_out, p=self.dropout, training=self.training)

        # Feed-forward sub-block.
        skip = hidden_states
        ffn = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn = nn.functional.dropout(ffn, p=self.activation_dropout, training=self.training)
        ffn = nn.functional.dropout(self.fc2(ffn), p=self.dropout, training=self.training)
        return skip + ffn
@auto_docstring
class WhisperPreTrainedModel(PreTrainedModel):
    # Shared base for the Whisper encoder/decoder/heads: config binding, capability flags
    # and Whisper-specific weight initialisation.
    config: WhisperConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    input_modalities = ("audio", "text")

    # Capability flags consulted by PreTrainedModel.
    supports_gradient_checkpointing = True
    _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the generic initialisation, then apply Whisper-specific overrides.

        - Encoder: the position table is overwritten with `sinusoids(num_positions, dim)`.
        - Audio-classification head with `use_weighted_layer_sum`: layer weights are set to a
          uniform `1 / (num_hidden_layers + 1)`.
        """
        super()._init_weights(module)
        if isinstance(module, WhisperEncoder):
            table = sinusoids(*module.embed_positions.weight.shape)
            init.copy_(module.embed_positions.weight, table)
        elif isinstance(module, WhisperForAudioClassification) and self.config.use_weighted_layer_sum:
            uniform_weight = 1.0 / (self.config.num_hidden_layers + 1)
            init.constant_(module.layer_weights, uniform_weight)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers

        Applies `(L - 1) // 2 + 1`, i.e. the length reduction of a single stride-2 step.
        """
        return (input_lengths - 1) // 2 + 1
class WhisperEncoder(WhisperPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`WhisperEncoderLayer`].

    Args:
        config: WhisperConfig
    """

    _can_record_outputs = {
        "hidden_states": WhisperEncoderLayer,
        "attentions": WhisperAttention,
    }
    input_modalities = ("audio",)

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        # NOTE(review): `embed_scale` is not used by `forward`; kept for attribute compatibility.
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # Convolutional front-end over the mel spectrogram; conv2 halves the time resolution.
        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

        # Fixed (non-trainable) position table, one row per encoder position.
        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
        self.embed_positions.requires_grad_(False)

        self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _freeze_parameters(self):
        """Disable gradients for every parameter of the encoder and record it in `_requires_grad`."""
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def get_input_embeddings(self) -> nn.Module:
        return self.conv1

    def set_input_embeddings(self, value: nn.Module):
        self.conv1 = value

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_features,
        attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
                `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
                the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`, *optional*):
                Whisper does not support masking of the `input_features`. This argument is preserved for
                compatibility but is not used: every encoder layer is called with a `None` mask.

        Raises:
            ValueError: If the last dimension of `input_features` is not
                `max_source_positions * conv1.stride * conv2.stride`.
        """
        expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
        if input_features.shape[-1] != expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        # Conv front-end produces (batch, embed_dim, time); permute to (batch, time, embed_dim).
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        inputs_embeds = inputs_embeds.permute(0, 2, 1)

        # The length check above guarantees the time axis matches the full position table.
        all_positions = torch.arange(self.embed_positions.num_embeddings, device=inputs_embeds.device)
        hidden_states = inputs_embeds + self.embed_positions(all_positions)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        for encoder_layer in self.layers:
            # LayerDrop (see https://huggingface.co/papers/1909.11556): while training, skip each
            # layer with probability `self.layerdrop`. No random number is drawn in eval mode.
            if self.training and torch.rand([]) < self.layerdrop:
                continue
            hidden_states = encoder_layer(
                hidden_states,
                None,
                **kwargs,
            )

        hidden_states = self.layer_norm(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
        )
- class WhisperDecoder(WhisperPreTrainedModel):
- """
- Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`]
- Args:
- config: WhisperConfig
- """
- _can_record_outputs = {
- "hidden_states": WhisperDecoderLayer,
- "attentions": OutputRecorder(WhisperAttention, index=1, layer_name="self_attn"),
- "cross_attentions": OutputRecorder(WhisperAttention, index=1, layer_name="encoder_attn"),
- }
- main_input_name = "input_ids"
- input_modalities = ("text",)
- def __init__(self, config: WhisperConfig):
- super().__init__(config)
- self.dropout = config.dropout
- self.layerdrop = config.decoder_layerdrop
- self.padding_idx = config.pad_token_id
- self.max_target_positions = config.max_target_positions
- self.max_source_positions = config.max_source_positions
- self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
- self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
- self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
- self.layers = nn.ModuleList(
- [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
- )
- self.layer_norm = nn.LayerNorm(config.d_model)
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
- @merge_with_config_defaults
@capture_outputs
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    encoder_hidden_states=None,
    past_key_values=None,
    inputs_embeds=None,
    position_ids=None,
    use_cache=None,
    **kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPastAndCrossAttentions:
    r"""
    Run the Whisper decoder stack over token ids (or pre-computed embeddings).

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
            provide it. Exactly one of `input_ids` and `inputs_embeds` must be given.

            Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            of the decoder.
        past_key_values ([`~cache_utils.Cache`], *optional*):
            Cache instance (e.g. `EncoderDecoderCache` or `DynamicCache`) holding previously computed key/value
            states. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Position of each token. When omitted, positions are numbered consecutively starting right after the
            tokens already stored in `past_key_values`.
        use_cache (`bool`, *optional*):
            Whether to store key/value states in `past_key_values`. When `True` and no cache is passed, a new one
            is created.

    Returns:
        [`BaseModelOutputWithPastAndCrossAttentions`] with the final (layer-normed) hidden states and the cache.

    Raises:
        ValueError: if both or neither of `input_ids` / `inputs_embeds` are provided.
    """
    # The XOR is True when both inputs are given *and* when neither is given.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if use_cache and past_key_values is None:
        # Two DynamicCaches wrapped in an EncoderDecoderCache when there are encoder states to attend to,
        # otherwise a single DynamicCache.
        past_key_values = (
            EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
            if encoder_hidden_states is not None or self.config.is_encoder_decoder
            else DynamicCache(config=self.config)
        )

    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

    if position_ids is None:
        # Continue numbering after the tokens already held in the cache, one row per batch element.
        position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_key_values_length
        position_ids = position_ids.unsqueeze(0).repeat(inputs_embeds.shape[0], 1)

    # embed positions from whichever token representation the caller supplied
    positions = self.embed_positions(
        input_ids if input_ids is not None else inputs_embeds,
        past_key_values_length=past_key_values_length,
        position_ids=position_ids,
    )

    hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    causal_mask = create_causal_mask(
        config=self.config,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )

    for decoder_layer in self.layers:
        # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
        if self.training:
            dropout_probability = torch.rand([])
            if dropout_probability < self.layerdrop:
                continue

        hidden_states = decoder_layer(
            hidden_states,
            causal_mask,
            encoder_hidden_states,
            encoder_attention_mask=None,
            past_key_values=past_key_values if use_cache else None,
            use_cache=use_cache,
            **kwargs,
        )

    hidden_states = self.layer_norm(hidden_states)

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
    )
@auto_docstring
class WhisperModel(WhisperPreTrainedModel):
    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        self.encoder = WhisperEncoder(config)
        self.decoder = WhisperDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.decoder.embed_tokens = value

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training.
        """
        self.encoder._freeze_parameters()

    def _mask_input_features(
        self,
        input_features: torch.FloatTensor,
        attention_mask: torch.LongTensor | None = None,
    ) -> torch.FloatTensor:
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).

        Masking only happens in training mode and when the corresponding `config.mask_time_prob` /
        `config.mask_feature_prob` is positive. Masked values are written into a copy, so the tensor passed by the
        caller is never modified in place.

        Args:
            input_features (`torch.FloatTensor`):
                Features unpacked as `(batch_size, hidden_size, sequence_length)`; time masks cover the last axis and
                feature masks the middle one.
            attention_mask (`torch.LongTensor`, *optional*):
                Forwarded to `_compute_mask_indices` when drawing the time masks.

        Returns:
            `torch.FloatTensor`: the (possibly masked) features.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return input_features

        batch_size, hidden_size, sequence_length = input_features.size()

        apply_time_mask = self.config.mask_time_prob > 0 and self.training
        apply_feature_mask = self.config.mask_feature_prob > 0 and self.training
        if not (apply_time_mask or apply_feature_mask):
            return input_features

        # Mask a copy so the caller's tensor is left untouched (an in-place write on a leaf tensor that requires
        # grad would also raise).
        input_features = input_features.clone()

        if apply_time_mask:
            # generate indices & apply SpecAugment along time axis
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
            # broadcast the (batch, time) mask across the feature axis
            mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
            input_features[mask_time_indices] = 0

        if apply_feature_mask:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
            # a (batch, feature) boolean index zeroes the whole time series of each selected feature
            input_features[mask_feature_indices] = 0

        return input_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_features: torch.FloatTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: tuple[tuple[torch.FloatTensor]] | BaseModelOutput | None = None,
        past_key_values: Cache | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | Seq2SeqModelOutput:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read
            [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
            paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_target_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)

        Example:
         ```python
         >>> import torch
         >>> from transformers import AutoFeatureExtractor, WhisperModel
         >>> from datasets import load_dataset

         >>> model = WhisperModel.from_pretrained("openai/whisper-base")
         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
         >>> input_features = inputs.input_features
         >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
         >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
         >>> list(last_hidden_state.shape)
         [1, 2, 512]
         ```"""

        if encoder_outputs is None:
            # SpecAugment (training only) is applied before the features reach the encoder.
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            encoder_outputs = self.encoder(
                input_features,
                **kwargs,
            )
        elif not isinstance(encoder_outputs, BaseModelOutput):
            # Normalize user-supplied tuples so the attribute access below works.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            **kwargs,
        )

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
    """
)
class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
    base_model_prefix = "model"
    _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"}

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.model = WhisperModel(config)
        self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Upper bound on the label sequence length accepted by `forward`.
        self.max_target_positions = config.max_target_positions

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        self.proj_out = new_embeddings

    def get_input_embeddings(self) -> nn.Module:
        return self.model.get_input_embeddings()

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training.
        """
        self.model.encoder._freeze_parameters()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_features: torch.FloatTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: tuple[tuple[torch.FloatTensor]] | BaseModelOutput | None = None,
        past_key_values: Cache | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | Seq2SeqLMOutput:
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read
            [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
            paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_target_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in
            `[0, ..., config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100`
            are ignored (masked), the loss is only computed for the tokens with labels in
            `[0, ..., config.vocab_size - 1]`. `sequence_length` should be smaller than or equal to
            `config.max_target_positions`. When neither `decoder_input_ids` nor `decoder_inputs_embeds` is given,
            the decoder inputs are built by shifting `labels` one position to the right.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features

        >>> generated_ids = model.generate(inputs=input_features)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```"""

        if labels is not None:
            if labels.shape[1] > self.max_target_positions:
                raise ValueError(
                    f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
                )
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Teacher forcing: the decoder sees the labels shifted right behind the start token.
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs: Seq2SeqModelOutput = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            **kwargs,
        )
        lm_logits = self.proj_out(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
class WhisperDecoderWrapper(WhisperPreTrainedModel):
    """
    Minimal container that exposes a [`WhisperDecoder`] under the `decoder` attribute.

    It exists so that pretrained checkpoints load with the expected parameter names when the causal language model
    is used in combination with the [`EncoderDecoderModel`] framework. All calls are forwarded to the decoder.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE: this flips the flag on the config object that was passed in, before the decoder is built.
        config.is_encoder_decoder = False
        self.decoder = WhisperDecoder(config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        decoder = self.decoder
        return decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Swap in a new token embedding module on the decoder."""
        decoder = self.decoder
        decoder.embed_tokens = value

    def forward(self, *args, **kwargs):
        """Delegate every argument unchanged to the wrapped decoder."""
        decoder = self.decoder
        return decoder(*args, **kwargs)
@auto_docstring(
    custom_intro="""
    Whisper decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
    """
)
class WhisperForCausalLM(WhisperPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"}
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)
        # The decoder is used stand-alone here; note this updates the passed-in config object.
        config.is_encoder_decoder = False
        self.model = WhisperDecoderWrapper(config)

        self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        self.proj_out = new_embeddings

    def get_input_embeddings(self) -> nn.Module:
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_outputs: torch.FloatTensor | tuple | BaseModelOutput | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs,
    ) -> tuple | CausalLMOutputWithCrossAttentions:
        r"""
        encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder. A `tuple`, `list` or [`BaseModelOutput`] is also accepted, in
            which case only its first element is used.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in
            `[0, ..., config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100`
            are ignored (masked), the loss is only computed for the tokens with labels in
            `[0, ..., config.vocab_size - 1]`. Labels are compared position-by-position with the logits; no shifting
            is performed inside this method.

        Example:

        ```python
        >>> from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor
        >>> import torch
        >>> from datasets import load_dataset

        >>> processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

        >>> assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> sample = ds[0]["audio"]
        >>> input_features = processor(
        ...     sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
        ... ).input_features

        >>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)

        >>> # decode token ids to text
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        >>> transcription
        ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
        ```"""

        # If the user passed a tuple or `BaseModelOutput` for encoder_outputs, we extract only the hidden states
        if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)):
            encoder_outputs = encoder_outputs[0]

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        logits = self.proj_out(outputs[0])

        loss = None
        if labels is not None:
            # move labels to correct device to enable PP
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # `reshape` also accepts non-contiguous label tensors (`view` would raise), matching
            # WhisperForConditionalGeneration.
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.reshape(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
    like SUPERB Keyword Spotting.
    """
)
class WhisperForAudioClassification(WhisperPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.encoder = WhisperEncoder(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # One learnable weight per hidden-state tensor, initialised uniformly.
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training. Only the projection layers and classification head will be updated.
        """
        self.encoder._freeze_parameters()

    def get_input_embeddings(self) -> nn.Module:
        return self.encoder.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module):
        self.encoder.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_features: torch.FloatTensor | None = None,
        encoder_outputs: tuple[tuple[torch.FloatTensor]] | BaseModelOutput | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. The loss is always computed with `CrossEntropyLoss`, regardless of
            `config.num_labels`.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
        >>> from datasets import load_dataset

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
        >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")

        >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
        >>> sample = next(iter(ds))

        >>> inputs = feature_extractor(
        ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
        ... )
        >>> input_features = inputs.input_features

        >>> with torch.no_grad():
        ...     logits = model(input_features).logits

        >>> predicted_class_ids = torch.argmax(logits).item()
        >>> predicted_label = model.config.id2label[predicted_class_ids]
        >>> predicted_label
        'Afrikaans'
        ```"""

        if self.config.use_weighted_layer_sum:
            # The weighted sum needs every layer's output, so ask the encoder to return them.
            kwargs["output_hidden_states"] = True

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                **kwargs,
            )
        elif not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        if self.config.use_weighted_layer_sum:
            # Read the field by name: positional indexing of a ModelOutput skips `None` fields, so a missing
            # `hidden_states` would otherwise silently resolve to the attentions.
            layer_states = encoder_outputs.hidden_states
            if layer_states is None:
                raise ValueError(
                    "`config.use_weighted_layer_sum=True` requires the hidden states of every encoder layer, but the "
                    "provided `encoder_outputs` do not contain `hidden_states`."
                )
            hidden_states = torch.stack(layer_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = encoder_outputs[0]

        hidden_states = self.projector(hidden_states)
        # mean-pool over the sequence dimension
        pooled_output = hidden_states.mean(dim=1)

        logits = self.classifier(pooled_output)

        loss = None

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # move labels to correct device to enable PP
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
- __all__ = [
- "WhisperForCausalLM",
- "WhisperForConditionalGeneration",
- "WhisperModel",
- "WhisperPreTrainedModel",
- "WhisperForAudioClassification",
- ]
|