# Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections.abc import Callable

import torch
from huggingface_hub.dataclasses import strict
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from torch import nn

from ...masking_utils import create_bidirectional_mask
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...tokenization_utils_tokenizers import TokenizersBackend
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward
from ..parakeet.configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
from ..parakeet.modeling_parakeet import (
    ParakeetEncoderBlock,
    ParakeetEncoderConvolutionModule,
    ParakeetForCTC,
    ParakeetPreTrainedModel,
)
from ..parakeet.processing_parakeet import ParakeetProcessor
from ..t5.tokenization_t5 import T5Tokenizer


class LasrTokenizer(T5Tokenizer, TokenizersBackend):
    """
    Unigram tokenizer for LASR CTC models.

    Follows the T5 layout (``<extra_id_{i}>`` sentinel tokens and an end-of-sequence token appended by the
    post-processor), but `_decode` performs CTC-style decoding: consecutive duplicate ids are collapsed and the
    pad id (used as the CTC blank) is removed before the ids are turned back into text.
    """

    def __init__(
        self,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        _spm_precompiled_charsmap=None,
        extra_ids=100,
        additional_special_tokens=None,
        vocab=None,
        vocab_file=None,
        **kwargs,
    ):
        """
        Args:
            eos_token: End-of-sequence token; appended to every encoded sequence by the post-processor.
            unk_token: Unknown token.
            pad_token: Padding token; also treated as the CTC blank when decoding.
            _spm_precompiled_charsmap: Optional SentencePiece precompiled charsmap used as the normalizer.
            extra_ids: Number of ``<extra_id_{i}>`` sentinel tokens.
            additional_special_tokens: Extra special tokens. If none of them is an ``<extra_id_*>`` token, the
                ``extra_ids`` sentinels are appended (to a copy; the caller's list is not modified).
            vocab: Optional list of ``(piece, score)`` pairs for the Unigram model. When omitted a minimal
                vocabulary (special tokens, the space piece and the sentinels) is built.
            vocab_file: Accepted for API compatibility; not used in this method.

        Raises:
            ValueError: If ``additional_special_tokens`` contains some ``<extra_id_*>`` tokens but their count does
                not match a positive ``extra_ids``.
        """
        self._extra_ids = extra_ids

        # Reconcile `extra_ids` with user-supplied sentinel tokens (T5 convention: `<extra_id_{i}>`).
        if additional_special_tokens is not None:
            extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
            if len(extra_tokens) < 1:
                # Build a new list rather than `+=` so the caller's list is not mutated in place.
                additional_special_tokens = list(additional_special_tokens) + [
                    f"<extra_id_{i}>" for i in range(extra_ids)
                ]
            elif extra_ids > 0 and extra_ids != len(extra_tokens):
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                    " provided to LasrTokenizer. In this case the additional_special_tokens must include the extra_ids"
                    " tokens"
                )
        else:
            extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
            additional_special_tokens = extra_tokens

        # LASR vocab structure: <pad>=0, </s>=1, <unk>=2, then regular vocab, then extra_ids in reverse
        if vocab is not None:
            self._vocab_scores = vocab
        else:
            self._vocab_scores = [
                (str(pad_token), 0.0),
                (str(eos_token), 0.0),
                (str(unk_token), 0.0),
                ("▁", -2.0),  # Space token
            ]
            self._vocab_scores.extend((f"<extra_id_{i}>", 0.0) for i in reversed(range(extra_ids)))

        self._tokenizer = Tokenizer(
            Unigram(
                self._vocab_scores,
                # NOTE(review): the comment above places <unk> at index 2, while index 3 of the default vocab is
                # the space piece "▁". Kept at 3 because loaded checkpoint vocabs may rely on it — confirm.
                unk_id=3,
                byte_fallback=False,
            )
        )

        if _spm_precompiled_charsmap is not None:
            self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)

        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
            ]
        )
        self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

        # Skip T5Tokenizer.__init__ (next in the MRO) and initialise the backend directly; the unbound call
        # needs `self` passed explicitly.
        TokenizersBackend.__init__(
            self,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        # Append the end-of-sequence token to single sequences and after each sequence of a pair.
        eos = str(eos_token)
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=["$A", eos],
            pair=["$A", eos, "$B", eos],
            special_tokens=[
                (eos, self.eos_token_id),
            ],
        )

    def _decode(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        group_tokens: bool = True,
        **kwargs,
    ) -> str:
        """
        Decode ids into text with CTC post-processing.

        Args:
            token_ids: A single id or a sequence of ids.
            skip_special_tokens: Forwarded to `TokenizersBackend._decode`.
            clean_up_tokenization_spaces: Forwarded to `TokenizersBackend._decode`.
            group_tokens: When True, runs of identical consecutive ids are collapsed to one id.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if group_tokens:
            token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]

        # for CTC we filter out the blank token, which is the pad token
        token_ids = [token for token in token_ids if token != self.pad_token_id]

        return TokenizersBackend._decode(
            self,
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )


class LasrProcessor(ParakeetProcessor):
    pass


@auto_docstring(checkpoint="google/medasr")
@strict
class LasrEncoderConfig(ParakeetEncoderConfig):
    r"""
    convolution_bias (`bool`, *optional*, defaults to `False`):
        Whether to use bias in convolutions of the conformer's convolution module.
    conv_kernel_size (`int`, *optional*, defaults to 32):
        The kernel size of the convolution layers in the Conformer block.
    subsampling_conv_channels (`int`, *optional*, defaults to 256):
        The number of channels in the subsampling convolution layers.
    subsampling_conv_kernel_size (`int`, *optional*, defaults to 5):
        The kernel size of the subsampling convolution layers.
    subsampling_conv_stride (`int`, *optional*, defaults to 2):
        The stride of the subsampling convolution layers.
    dropout_positions (`float`, *optional*, defaults to 0.0):
        The dropout ratio for the positions in the input sequence.
    feed_forward_residual_weights (`tuple[float, float]`, *optional*, defaults to `(1.5, 0.5)`):
        The residual weights for the feed forward layers.
    conv_residual_weights (`tuple[float, float]`, *optional*, defaults to `(2.0, 1.0)`):
        The residual weights for the convolution layers.
    batch_norm_momentum (`float`, *optional*, defaults to 0.01):
        The momentum for the batch normalization layers.

    Example:
        ```python
        >>> from transformers import LasrEncoder, LasrEncoderConfig

        >>> # Initializing a `LasrEncoder` configuration
        >>> configuration = LasrEncoderConfig()

        >>> # Initializing a model from the configuration
        >>> model = LasrEncoder(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```

    This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
    and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
    """

    hidden_size: int = 512
    num_hidden_layers: int = 17
    intermediate_size: int = 2048
    attention_bias: bool = False
    convolution_bias: bool = False
    conv_kernel_size: int = 32
    subsampling_conv_kernel_size: int = 5
    num_mel_bins: int = 128
    max_position_embeddings: int = 10000
    layer_norm_eps: float = 1e-6
    feed_forward_residual_weights: list[float] | tuple[float, ...] = (1.5, 0.5)
    conv_residual_weights: list[float] | tuple[float, ...] = (2.0, 1.0)
    batch_norm_momentum: float = 0.01
    rope_parameters: dict | None = None

    # Parent (Parakeet) attributes that LASR does not use.
    subsampling_factor = AttributeError()
    scale_input = AttributeError()


@auto_docstring(checkpoint="google/medasr")
@strict
class LasrCTCConfig(ParakeetCTCConfig):
    r"""
    ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
        Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
        instance of [`LasrForCTC`].
    ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
        Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
        occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
        of [`LasrForCTC`].

    Example:
        ```python
        >>> from transformers import LasrForCTC, LasrCTCConfig

        >>> # Initializing a Lasr configuration
        >>> configuration = LasrCTCConfig()

        >>> # Initializing a model from the configuration
        >>> model = LasrForCTC(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```

    This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
    and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
    """

    vocab_size: int = 512
    pad_token_id: int = 0

    @property
    def inputs_to_logits_ratio(self):
        # The subsampler applies two strided convolutions, each downsampling time by `subsampling_conv_stride`.
        return self.encoder_config.subsampling_conv_stride**2


class LasrEncoderSubsampling(nn.Module):
    """Dense -> two strided Conv1d layers -> dense projection, each followed (except the last) by ReLU."""

    def __init__(self, config: LasrEncoderConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.num_mel_bins, config.hidden_size)
        self.conv_0 = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.subsampling_conv_kernel_size,
            stride=config.subsampling_conv_stride,
        )
        self.conv_1 = nn.Conv1d(
            config.hidden_size,
            config.subsampling_conv_channels,
            kernel_size=config.subsampling_conv_kernel_size,
            stride=config.subsampling_conv_stride,
        )
        self.dense_1 = nn.Linear(config.subsampling_conv_channels, config.hidden_size)
        self.act_fn = nn.ReLU()

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        # assumes input_features is (batch, time, num_mel_bins) — TODO confirm against the feature extractor
        hidden_states = self.act_fn(self.dense_0(input_features))
        # Conv1d expects channels on dim 1, so swap the last two dims around the convolutions.
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.act_fn(self.conv_0(hidden_states))
        hidden_states = self.act_fn(self.conv_1(hidden_states))
        hidden_states = hidden_states.transpose(1, 2)
        return self.dense_1(hidden_states)


class LasrEncoderRotaryEmbedding(LlamaRotaryEmbedding): ...
class LasrEncoderAttention(LlamaAttention):
    """Llama-style multi-head attention with rotary embeddings, made non-causal for the encoder."""

    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: Input states; the last dim is projected to per-head queries/keys/values.
            position_embeddings: ``(cos, sin)`` rotary tables applied to queries and keys.
            attention_mask: Mask forwarded to the selected attention implementation.

        Returns:
            ``(attn_output, attn_weights)`` as produced by the attention interface, after the output projection.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class LasrEncoderConvolutionModule(ParakeetEncoderConvolutionModule):
    """Parakeet convolution module with "same" padding and a BatchNorm using the configured momentum."""

    def __init__(self, config: LasrEncoderConfig, module_config=None):
        super().__init__(config, module_config)
        self.padding = "same"
        self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)


class LasrEncoderBlock(ParakeetEncoderBlock):
    """Conformer block with weighted residuals and bias-free LayerNorms."""

    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.feed_forward_residual_weights = config.feed_forward_residual_weights
        self.conv_residual_weights = config.conv_residual_weights

        self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
        self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Feed-forward 1: residual combined as w0 * residual + w1 * ff_output.
        residual = hidden_states
        hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
        hidden_states = (
            self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
        )

        # Self-attention with a plain (unweighted) residual.
        normalized_hidden_states = self.norm_self_att(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=normalized_hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Convolution module with its own residual weights.
        conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
        hidden_states = self.conv_residual_weights[0] * hidden_states + self.conv_residual_weights[1] * conv_output

        # Feed-forward 2: same weighting as feed-forward 1.
        residual = hidden_states
        hidden_states = self.feed_forward2(self.norm_feed_forward2(hidden_states))
        hidden_states = (
            self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
        )

        hidden_states = self.norm_out(hidden_states)

        return hidden_states


class LasrPreTrainedModel(ParakeetPreTrainedModel):
    # padding is incompatible with flex attention as the resulting mask cannot be used to apply padding
    _supports_flex_attn = False

    def _init_weights(self, module):
        # Use the generic PreTrainedModel initialisation rather than Parakeet's override. This is an unbound call,
        # so `self` must be passed explicitly.
        PreTrainedModel._init_weights(self, module)

    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
        """
        Map input frame lengths to lengths after subsampling.

        Applies ``(length - kernel_size) // stride + 1`` twice, matching the two unpadded strided Conv1d layers of
        `LasrEncoderSubsampling`.
        """
        encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config

        kernel_size = encoder_config.subsampling_conv_kernel_size
        stride = encoder_config.subsampling_conv_stride
        num_layers = 2

        for _ in range(num_layers):
            input_lengths = (input_lengths - kernel_size) // stride + 1

        return input_lengths


@auto_docstring(
    custom_intro="""
    The LasrEncoder model, based on the [Conformer architecture](https://arxiv.org/abs/2005.08100).
    """
)
class LasrEncoder(LasrPreTrainedModel):
    config: LasrEncoderConfig
    base_model_prefix = "encoder"

    def __init__(self, config: LasrEncoderConfig):
        super().__init__(config)
        self.gradient_checkpointing = False

        self.dropout = config.dropout
        self.dropout_positions = config.dropout_positions
        self.layerdrop = config.layerdrop

        self.subsampler = LasrEncoderSubsampling(config)
        self.rotary_emb = LasrEncoderRotaryEmbedding(config)
        self.layers = nn.ModuleList(
            [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)

        self.post_init()

    @auto_docstring
    @merge_with_config_defaults
    @capture_outputs
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrEncoder
        >>> from datasets import load_dataset, Audio

        >>> model_id = "google/medasr"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> encoder = LasrEncoder.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> encoder_outputs = encoder(**inputs)

        >>> print(encoder_outputs.last_hidden_state.shape)
        ```
        """
        hidden_states = self.subsampler(input_features)

        # Rotary tables for positions 0..seq_len-1 of the subsampled sequence.
        cos, sin = self.rotary_emb(
            hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
        sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)

        if attention_mask is not None:
            # Bring the frame-level mask down to the subsampled sequence length.
            attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])

        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
        )

        for encoder_layer in self.layers:
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            if self.training and torch.rand([]) < self.layerdrop:
                continue  # skip the layer

            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=(cos, sin),
                **kwargs,
            )

        hidden_states = self.out_norm(hidden_states)

        return BaseModelOutput(last_hidden_state=hidden_states)


class LasrForCTC(ParakeetForCTC):
    # NOTE(review): `generate(**super_kwargs)` without `self` appears to follow the modular-converter
    # `**super_kwargs` convention (the parent signature is expanded at generation time); it is not callable
    # as plain Python in this file.
    def generate(**super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "google/medasr"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = LasrForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        """
        return super().generate(**super_kwargs)


__all__ = [
    "LasrForCTC",
    "LasrEncoder",
    "LasrPreTrainedModel",
    "LasrProcessor",
    "LasrEncoderConfig",
    "LasrCTCConfig",
    "LasrTokenizer",
]