| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_lasr.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import Optional
- import torch
- from torch import nn
- from ...activations import ACT2FN
- from ...integrations import use_kernel_func_from_hub, use_kernelized_func
- from ...masking_utils import create_bidirectional_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, CausalLMOutput
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import maybe_autocast, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig
class LasrEncoderSubsampling(nn.Module):
    """Input front-end: a linear projection, two strided 1D convolutions and a final linear projection.

    A ReLU follows the first projection and each convolution; the last projection has no activation.
    """

    def __init__(self, config: LasrEncoderConfig):
        super().__init__()
        hidden_size = config.hidden_size
        conv_channels = config.subsampling_conv_channels
        # Both convolutions use the same kernel size and stride.
        conv_kwargs = {
            "kernel_size": config.subsampling_conv_kernel_size,
            "stride": config.subsampling_conv_stride,
        }

        self.dense_0 = nn.Linear(config.num_mel_bins, hidden_size)
        self.conv_0 = nn.Conv1d(hidden_size, hidden_size, **conv_kwargs)
        self.conv_1 = nn.Conv1d(hidden_size, conv_channels, **conv_kwargs)
        self.dense_1 = nn.Linear(conv_channels, hidden_size)
        self.act_fn = nn.ReLU()

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        """Project `input_features`, run the two strided convolutions and project back to `hidden_size`."""
        projected = self.act_fn(self.dense_0(input_features))

        # `nn.Conv1d` convolves over the last axis, so dims 1 and 2 are swapped around the convolutions.
        conv_states = projected.transpose(1, 2)
        for conv in (self.conv_0, self.conv_1):
            conv_states = self.act_fn(conv(conv_states))

        return self.dense_1(conv_states.transpose(1, 2))
class LasrEncoderRotaryEmbedding(nn.Module):
    """Computes the cos/sin tables of the rotary position embedding for a set of position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: LasrEncoderConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        # The "default" variant is implemented on this class; every other variant comes from the registry.
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Both buffers are non-persistent, i.e. they are not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: LasrEncoderConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. Read for `rope_parameters["rope_theta"]` and for the rotary dimension
                (`head_dim` when set, otherwise `hidden_size // num_attention_heads`).
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE, always 1.0).
        """
        rope_theta = config.rope_parameters["rope_theta"]
        rotary_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # One exponent per pair of channels: 0/dim, 2/dim, 4/dim, ...
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / rotary_dim
        inv_freq = 1.0 / (rope_theta**exponents)
        return inv_freq, 1.0

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only read for its device and dtype.
        """
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled under the device type of `x`; "mps" and non-string types are mapped to "cpu".
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # The frequencies are duplicated so that the table spans the full rotary dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Split the last dimension at its midpoint, move the second part to the front and negate it."""
    midpoint = x.shape[-1] // 2
    leading, trailing = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-trailing, leading), dim=-1)
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension at which a singleton axis is inserted into `cos` and `sin` so that they broadcast against
            `q` and `k`. For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim], use
            `unsqueeze_dim=1` when `q` and `k` have the shape [batch_size, heads, seq_len, head_dim], and
            `unsqueeze_dim=2` when they have the shape [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    # The same rotation is applied to queries and keys.
    q_embed, k_embed = ((states * cos) + (rotate_half(states) * sin) for states in (q, k))
    return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along dim 1, like `torch.repeat_interleave(x, dim=1, repeats=n_rep)`.
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). The input is returned as-is when `n_rep == 1`.
    """
    batch, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain PyTorch scaled dot-product attention.

    `module` is read for `num_key_value_groups` (how many times each key/value head is repeated) and for
    `training` (whether dropout is active). `attention_mask`, when given, is added to the scaled scores before
    the softmax. Extra `kwargs` are accepted and ignored.

    Returns the attention output with dims 1 and 2 swapped (made contiguous) and the attention weights.
    """
    groups = module.num_key_value_groups
    key_states, value_states = repeat_kv(key, groups), repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # The softmax runs in float32 and the result is cast back to the dtype of the queries.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
@use_kernelized_func(apply_rotary_pos_emb)
class LasrEncoderAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Same resolution rule as `LasrEncoderRotaryEmbedding.compute_default_rope_parameters`: a `head_dim` that is
        # missing *or* set to `None` falls back to `hidden_size // num_attention_heads`. A plain
        # `getattr(config, "head_dim", fallback)` would return `None` in the second case and break `scaling` below.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Self-attention with rotary position embeddings applied to the queries and keys.

        Args:
            hidden_states (`torch.Tensor`):
                Input to the q/k/v projections; every dim except the last is preserved in the output.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
                The `(cos, sin)` rotary tables. Despite the `None` default this is required: it is unpacked
                unconditionally.
            attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to the selected attention implementation.

        Returns:
            The output of `o_proj` and the attention weights returned by the attention implementation.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the projections into heads of size `head_dim`, then swap dims 1 and 2.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Look up the implementation named by the config; `eager_attention_forward` is the fallback.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the heads back into a single feature axis before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class LasrEncoderConvolutionModule(nn.Module):
    def __init__(self, config: LasrEncoderConfig, module_config=None):
        """
        Args:
            config (LasrEncoderConfig): Configuration for the model.
            module_config (dict): Optional per-module settings. When given, `"kernel_size"` is required and
                `"activation"` is optional (defaults to `"silu"`); otherwise both are read from `config`.
        """
        super().__init__()
        channels = config.hidden_size

        # kernel_size should be an odd number for 'SAME' padding
        if module_config is not None:
            kernel_size = module_config["kernel_size"]
            activation_name = module_config.get("activation", "silu")
        else:
            kernel_size = config.conv_kernel_size
            activation_name = getattr(config, "hidden_act", "silu")
        self.activation = ACT2FN[activation_name]
        self.padding = "same"

        # The two pointwise (kernel size 1) convolutions only differ in their output width.
        pointwise_kwargs = {"kernel_size": 1, "stride": 1, "padding": 0, "bias": config.convolution_bias}
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, **pointwise_kwargs)
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=self.padding,
            groups=channels,
            bias=config.convolution_bias,
        )
        self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)
        self.pointwise_conv2 = nn.Conv1d(channels, channels, **pointwise_kwargs)

    def forward(self, hidden_states, attention_mask=None):
        """
        Compute convolution module.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
            attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.

        Returns:
            `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
        """
        # Move channels to dim 1, as expected by the 1D convolutions.
        features = hidden_states.transpose(1, 2)

        # The first pointwise conv doubles the channels; GLU over dim 1 halves them again.
        features = nn.functional.glu(self.pointwise_conv1(features), dim=1)

        # Zero the positions whose mask entries along dim 2 are all "masked" before the depthwise conv.
        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                # Boolean mask: `False` entries count as masked.
                masked_entries = ~attention_mask
            else:
                # Non-boolean mask: every entry different from 0.0 counts as masked.
                masked_entries = ~(attention_mask == 0.0)
            all_masked_rows = torch.all(masked_entries, dim=2)
            features = features.masked_fill(all_masked_rows, 0.0)

        features = self.depthwise_conv(features)
        features = self.activation(self.norm(features))
        features = self.pointwise_conv2(features)

        return features.transpose(1, 2)
class LasrEncoderFeedForward(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> dropout -> linear."""

    def __init__(self, config: LasrEncoderConfig):
        super().__init__()
        # Both linear layers take their bias setting from `config.attention_bias`.
        use_bias = config.attention_bias
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=use_bias)
        self.activation = ACT2FN[config.hidden_act]
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=use_bias)
        self.activation_dropout = config.activation_dropout

    def forward(self, hidden_states):
        """Expand to `intermediate_size`, apply activation and dropout, and project back to `hidden_size`."""
        expanded = self.activation(self.linear1(hidden_states))
        expanded = nn.functional.dropout(expanded, p=self.activation_dropout, training=self.training)
        return self.linear2(expanded)
class LasrEncoderBlock(GradientCheckpointingLayer):
    """One encoder layer: feed-forward, self-attention, convolution, feed-forward, then a final LayerNorm.

    Every sub-module is applied to a LayerNorm-ed copy of the running hidden states (pre-norm). The two
    feed-forward residuals and the convolution residual are weighted sums using pairs read from the config.
    """

    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__()
        self.gradient_checkpointing = False

        self.feed_forward1 = LasrEncoderFeedForward(config)
        self.self_attn = LasrEncoderAttention(config, layer_idx)
        self.conv = LasrEncoderConvolutionModule(config)
        self.feed_forward2 = LasrEncoderFeedForward(config)

        def make_norm():
            # All layer norms of the block share the same configuration and have no bias.
            return nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)

        self.norm_feed_forward1 = make_norm()
        self.norm_self_att = make_norm()
        self.norm_conv = make_norm()
        self.norm_feed_forward2 = make_norm()
        self.norm_out = make_norm()

        # Index 0 weighs the residual branch, index 1 the sub-module output (see `forward`).
        self.feed_forward_residual_weights = config.feed_forward_residual_weights
        self.conv_residual_weights = config.conv_residual_weights

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run the block on `hidden_states`.

        `attention_mask` is passed to both the attention and the convolution module; `position_embeddings` and
        `kwargs` are passed to the attention module only.
        """
        ff_weights = self.feed_forward_residual_weights
        conv_weights = self.conv_residual_weights

        # First feed-forward module, weighted residual.
        ff1_output = self.feed_forward1(self.norm_feed_forward1(hidden_states))
        hidden_states = ff_weights[0] * hidden_states + ff_weights[1] * ff1_output

        # Self-attention, plain (unweighted) residual.
        attn_output, _ = self.self_attn(
            hidden_states=self.norm_self_att(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Convolution module, weighted residual.
        conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
        hidden_states = conv_weights[0] * hidden_states + conv_weights[1] * conv_output

        # Second feed-forward module, weighted residual.
        ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
        hidden_states = ff_weights[0] * hidden_states + ff_weights[1] * ff2_output

        return self.norm_out(hidden_states)
@auto_docstring
class LasrPreTrainedModel(PreTrainedModel):
    config: LasrCTCConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    input_modalities = "audio"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LasrEncoderBlock"]
    _supports_flat_attention_mask = True
    _supports_sdpa = True
    # padding is incompatible with flex attention as the resulting mask cannot be used to apply padding
    _supports_flex_attn = False

    # TODO: @eustlb, add support when flash attention supports custom attention bias
    _supports_flash_attn = False

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": LasrEncoderBlock,
        "attentions": LasrEncoderAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        # No model-specific initialization: defer entirely to the parent implementation.
        super()._init_weights(module)

    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
        """Map input lengths to the lengths left after the two strided subsampling convolutions."""
        # The subsampling settings live on the encoder config, which is nested inside a `LasrCTCConfig`.
        if isinstance(self.config, LasrCTCConfig):
            encoder_config = self.config.encoder_config
        else:
            encoder_config = self.config

        kernel_size = encoder_config.subsampling_conv_kernel_size
        stride = encoder_config.subsampling_conv_stride

        # Unpadded convolution output length, applied once per subsampling convolution.
        lengths = input_lengths
        for _ in range(2):
            lengths = (lengths - kernel_size) // stride + 1
        return lengths

    def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: int | None = None):
        """
        Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
        when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
        """
        output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))

        # Without an explicit target, the mask is as long as the longest subsampled sequence in the batch.
        if target_length is None:
            max_length = output_lengths.max()
        else:
            max_length = target_length

        # A position is valid when its index is smaller than the subsampled length of its sequence.
        positions = torch.arange(max_length, device=attention_mask.device)
        return positions < output_lengths[:, None]
@auto_docstring(
    custom_intro="""
    The LasrEncoder model, based on the [Conformer architecture](https://arxiv.org/abs/2005.08100).
    """
)
class LasrEncoder(LasrPreTrainedModel):
    config: LasrEncoderConfig
    base_model_prefix = "encoder"

    def __init__(self, config: LasrEncoderConfig):
        super().__init__(config)
        self.gradient_checkpointing = False

        self.dropout = config.dropout
        self.dropout_positions = config.dropout_positions
        self.layerdrop = config.layerdrop

        self.subsampler = LasrEncoderSubsampling(config)
        self.rotary_emb = LasrEncoderRotaryEmbedding(config)
        self.layers = nn.ModuleList(
            [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)

        self.post_init()

    @auto_docstring
    @merge_with_config_defaults
    @capture_outputs
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrEncoder
        >>> from datasets import load_dataset, Audio

        >>> model_id = TODO
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> encoder = LasrEncoder.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> encoder_outputs = encoder(**inputs)

        >>> print(encoder_outputs.last_hidden_state.shape)
        ```
        """
        hidden_states = self.subsampler(input_features)

        # Rotary tables for positions 0..T-1 of the subsampled sequence; the leading dim of size 1 is shared
        # by the whole batch through broadcasting.
        cos, sin = self.rotary_emb(
            hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
        sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)

        if attention_mask is not None:
            # Bring the input-level mask down to the subsampled time axis before building the attention mask.
            attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
            attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
            )

        for encoder_layer in self.layers:
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description):
            # during training each layer is skipped with probability `self.layerdrop`.
            if self.training and torch.rand([]) < self.layerdrop:
                continue

            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=(cos, sin),
                **kwargs,
            )

        hidden_states = self.out_norm(hidden_states)

        return BaseModelOutput(last_hidden_state=hidden_states)
@dataclass
class LasrGenerateOutput(ModelOutput):
    """
    Container returned by `LasrForCTC.generate` when `return_dict_in_generate=True`.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The predicted token ids: the argmax of the logits over the last dimension, with padded positions
            set to `config.pad_token_id` when an attention mask was provided.
        logits (`torch.FloatTensor`, *optional*):
            The logits returned by the model forward pass that `sequences` were decoded from.
        attentions (*optional*):
            The `attentions` field of the model forward output, passed through unchanged.
        hidden_states (*optional*):
            The `hidden_states` field of the model forward output, passed through unchanged.
    """

    sequences: torch.LongTensor
    logits: tuple[torch.FloatTensor] | None = None
    attentions: tuple[tuple[torch.FloatTensor]] | None = None
    hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
@auto_docstring(
    custom_intro="""
    Lasr Encoder with a Connectionist Temporal Classification (CTC) head.
    """
)
class LasrForCTC(LasrPreTrainedModel):
    config: LasrCTCConfig

    def __init__(self, config: LasrCTCConfig):
        super().__init__(config)
        self.encoder = LasrEncoder(config.encoder_config)
        # Conv rather than linear to be consistent with NeMO decoding layer
        self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)

        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutput:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/lasr-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = LasrForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> outputs = model(**inputs)

        >>> print(outputs.loss)
        ```"""
        encoder_outputs = self.encoder(
            input_features=input_features,
            attention_mask=attention_mask,
            **kwargs,
        )

        hidden_states = encoder_outputs.last_hidden_state
        # The head is a Conv1d, so the feature axis is moved to dim 1 for the call and back afterwards.
        logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            if attention_mask is None:
                # `input_features` is (batch, time, num_mel_bins). The fallback mask must cover (batch, time)
                # only: a mask shaped like the features would make `sum(-1)` count mel bins, not frames.
                attention_mask = torch.ones(
                    input_features.shape[:2], dtype=torch.long, device=input_features.device
                )
            input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))

            # label positions equal to `config.pad_token_id` are treated as padding and
            # excluded from the targets
            labels_mask = labels != self.config.pad_token_id
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        return_dict_in_generate: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> LasrGenerateOutput | torch.LongTensor:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = TODO
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = LasrForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        """
        # Always request a dict-like output so that `.logits` can be accessed below.
        kwargs["return_dict"] = True
        outputs: CausalLMOutput = self.forward(
            input_features=input_features,
            attention_mask=attention_mask,
            **kwargs,
        )

        # greedy decoding
        sequences = outputs.logits.argmax(dim=-1)

        # mask out padded tokens
        if attention_mask is not None:
            attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
            sequences[~attention_mask] = self.config.pad_token_id

        if return_dict_in_generate:
            return LasrGenerateOutput(
                sequences=sequences,
                logits=outputs.logits,
                attentions=outputs.attentions,
                hidden_states=outputs.hidden_states,
            )

        return sequences
# Names exported by `from <module> import *`.
__all__ = [
    "LasrForCTC",
    "LasrEncoder",
    "LasrPreTrainedModel",
]
|