| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523 |
- # Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import itertools
- from collections.abc import Callable
- import torch
- from huggingface_hub.dataclasses import strict
- from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
- from tokenizers.models import Unigram
- from torch import nn
- from ...masking_utils import create_bidirectional_mask
- from ...modeling_outputs import BaseModelOutput
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...tokenization_utils_tokenizers import TokenizersBackend
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward
- from ..parakeet.configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
- from ..parakeet.modeling_parakeet import (
- ParakeetEncoderBlock,
- ParakeetEncoderConvolutionModule,
- ParakeetForCTC,
- ParakeetPreTrainedModel,
- )
- from ..parakeet.processing_parakeet import ParakeetProcessor
- from ..t5.tokenization_t5 import T5Tokenizer
class LasrTokenizer(T5Tokenizer, TokenizersBackend):
    """
    Unigram tokenizer for LASR CTC models, built on the `tokenizers` backend.

    When no `vocab` is given, the vocabulary is laid out as `<pad>`=0, `</s>`=1, `<unk>`=2, `"▁"`=3, followed by
    the `<extra_id_*>` sentinel tokens in reverse order. Encoded sequences get the EOS token appended, and
    decoding (by default) collapses repeated ids and drops the pad id, which doubles as the CTC blank.
    """

    def __init__(
        self,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        _spm_precompiled_charsmap=None,
        extra_ids=100,
        additional_special_tokens=None,
        vocab=None,
        vocab_file=None,
        **kwargs,
    ):
        self._extra_ids = extra_ids

        # Reconcile `extra_ids` with any user-provided `additional_special_tokens`.
        if additional_special_tokens is not None:
            # Work on a copy so the caller's list is never mutated in place (and tuples are accepted).
            additional_special_tokens = list(additional_special_tokens)
            extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
            if not extra_tokens:
                additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
            elif extra_ids > 0 and extra_ids != len(extra_tokens):
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                    " provided to LasrTokenizer. In this case the additional_special_tokens must include the extra_ids"
                    " tokens"
                )
        else:
            additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]

        # LASR vocab structure: <pad>=0, </s>=1, <unk>=2, then regular vocab, then extra_ids in reverse
        if vocab is not None:
            self._vocab_scores = vocab
        else:
            self._vocab_scores = [
                (str(pad_token), 0.0),
                (str(eos_token), 0.0),
                (str(unk_token), 0.0),
                ("▁", -2.0),  # Space token
            ]
            self._vocab_scores.extend((f"<extra_id_{i}>", 0.0) for i in reversed(range(extra_ids)))

        # Resolve the unknown-token id from the vocabulary itself instead of hard-coding it: in the default
        # layout `<unk>` sits at index 2 (index 3 is the "▁" space piece). Fall back to the documented slot.
        unk_piece = str(unk_token)
        unk_id = next(
            (idx for idx, (piece, _) in enumerate(self._vocab_scores) if piece == unk_piece),
            2,
        )

        self._tokenizer = Tokenizer(
            Unigram(
                self._vocab_scores,
                unk_id=unk_id,
                byte_fallback=False,
            )
        )

        if _spm_precompiled_charsmap is not None:
            self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)

        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
            ]
        )
        self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

        # The unbound base-class initializer must receive `self` explicitly.
        TokenizersBackend.__init__(
            self,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        # Append the EOS token after each sequence. Use the configured EOS string so the template token text
        # matches the vocabulary entry that `eos_token_id` refers to.
        eos = str(self.eos_token)
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=["$A", eos],
            pair=["$A", eos, "$B", eos],
            special_tokens=[
                (eos, self.eos_token_id),
            ],
        )

    def _decode(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        group_tokens: bool = True,
        **kwargs,
    ) -> str:
        """
        Decode CTC output ids into text.

        Args:
            token_ids: A single id or a list of ids.
            skip_special_tokens: Forwarded to `TokenizersBackend._decode`.
            clean_up_tokenization_spaces: Forwarded to `TokenizersBackend._decode`.
            group_tokens: When True, collapse each run of identical consecutive ids into a single id.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if group_tokens:
            token_ids = [token_id for token_id, _ in itertools.groupby(token_ids)]
        # For CTC the blank token is the pad token; it is removed after grouping so that repeats separated by a
        # blank remain distinct.
        token_ids = [token_id for token_id in token_ids if token_id != self.pad_token_id]
        return TokenizersBackend._decode(
            self,
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
class LasrProcessor(ParakeetProcessor):
    """Processor for LASR models; inherits all behavior from `ParakeetProcessor` without changes."""
@auto_docstring(checkpoint="google/medasr")
@strict
class LasrEncoderConfig(ParakeetEncoderConfig):
    r"""
    convolution_bias (`bool`, *optional*, defaults to `False`):
        Whether to use bias in convolutions of the conformer's convolution module.
    conv_kernel_size (`int`, *optional*, defaults to 32):
        The kernel size of the convolution layers in the Conformer block.
    subsampling_conv_channels (`int`, *optional*, defaults to 256):
        The number of output channels of the second subsampling convolution layer.
    subsampling_conv_kernel_size (`int`, *optional*, defaults to 5):
        The kernel size of the subsampling convolution layers.
    subsampling_conv_stride (`int`, *optional*, defaults to 2):
        The stride of the subsampling convolution layers.
    dropout_positions (`float`, *optional*, defaults to 0.0):
        The dropout ratio for the positions in the input sequence.
    feed_forward_residual_weights (`tuple[float, float]`, *optional*, defaults to `(1.5, 0.5)`):
        The (residual, branch) weights used to combine each feed-forward module's input and output.
    conv_residual_weights (`tuple[float, float]`, *optional*, defaults to `(2.0, 1.0)`):
        The (residual, branch) weights used to combine the convolution module's input and output.
    batch_norm_momentum (`float`, *optional*, defaults to 0.01):
        The momentum for the batch normalization layers.

    Example:

    ```python
    >>> from transformers import LasrEncoder, LasrEncoderConfig

    >>> # Initializing a `LasrEncoder` configuration
    >>> configuration = LasrEncoderConfig()

    >>> # Initializing a model from the configuration
    >>> model = LasrEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
    and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
    """

    hidden_size: int = 512
    num_hidden_layers: int = 17
    intermediate_size: int = 2048
    attention_bias: bool = False
    convolution_bias: bool = False
    conv_kernel_size: int = 32
    subsampling_conv_kernel_size: int = 5
    num_mel_bins: int = 128
    max_position_embeddings: int = 10000
    layer_norm_eps: float = 1e-6
    # Tuples (not lists) keep the dataclass defaults immutable.
    feed_forward_residual_weights: list[float] | tuple[float, ...] = (1.5, 0.5)
    conv_residual_weights: list[float] | tuple[float, ...] = (2.0, 1.0)
    batch_norm_momentum: float = 0.01
    rope_parameters: dict | None = None

    # Attributes of the parent config that LASR does not use.
    subsampling_factor = AttributeError()
    scale_input = AttributeError()
@auto_docstring(checkpoint="google/medasr")
@strict
class LasrCTCConfig(ParakeetCTCConfig):
    r"""
    ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
        Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
        instance of [`LasrForCTC`].
    ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
        Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
        occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
        of [`LasrForCTC`].

    Example:

    ```python
    >>> from transformers import LasrForCTC, LasrCTCConfig

    >>> # Initializing a Lasr configuration
    >>> configuration = LasrCTCConfig()

    >>> # Initializing a model from the configuration
    >>> model = LasrForCTC(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
    and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
    """

    vocab_size: int = 512
    pad_token_id: int = 0

    @property
    def inputs_to_logits_ratio(self):
        # The encoder's subsampling front end applies two strided convolutions, so the frame-to-logit ratio is the
        # stride squared.
        conv_stride = self.encoder_config.subsampling_conv_stride
        return conv_stride**2
class LasrEncoderSubsampling(nn.Module):
    """
    Front end of the LASR encoder.

    Projects the mel features to `hidden_size` and downsamples the time axis with two unpadded, strided `Conv1d`
    layers (each followed by ReLU). It then projects back to `hidden_size`.
    """

    def __init__(self, config: LasrEncoderConfig):
        super().__init__()
        kernel = config.subsampling_conv_kernel_size
        stride = config.subsampling_conv_stride
        # Registration order is kept stable: dense_0, conv_0, conv_1, dense_1, act_fn.
        self.dense_0 = nn.Linear(config.num_mel_bins, config.hidden_size)
        self.conv_0 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=kernel, stride=stride)
        self.conv_1 = nn.Conv1d(
            config.hidden_size, config.subsampling_conv_channels, kernel_size=kernel, stride=stride
        )
        self.dense_1 = nn.Linear(config.subsampling_conv_channels, config.hidden_size)
        self.act_fn = nn.ReLU()

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        # `dense_0` acts on the last dim (num_mel_bins); Conv1d then needs channels on dim 1, hence the transposes.
        channels_first = self.act_fn(self.dense_0(input_features)).transpose(1, 2)
        for conv_layer in (self.conv_0, self.conv_1):
            channels_first = self.act_fn(conv_layer(channels_first))
        return self.dense_1(channels_first.transpose(1, 2))
class LasrEncoderRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary position embedding for the LASR encoder; reuses `LlamaRotaryEmbedding` unchanged."""
class LasrEncoderAttention(LlamaAttention):
    """Llama-style attention with rotary position embeddings, run non-causally for the encoder."""

    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Every position attends to the full sequence.
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        leading_dims = hidden_states.shape[:-1]
        split_heads_shape = (*leading_dims, -1, self.head_dim)

        def _to_heads(projection):
            # Project, split the last dim into heads, and move the head axis in front of the sequence axis.
            return projection(hidden_states).view(split_heads_shape).transpose(1, 2)

        query_states = _to_heads(self.q_proj)
        key_states = _to_heads(self.k_proj)
        value_states = _to_heads(self.v_proj)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            **kwargs,
        )

        merged_heads = attn_output.reshape(*leading_dims, -1).contiguous()
        return self.o_proj(merged_heads), attn_weights
class LasrEncoderConvolutionModule(ParakeetEncoderConvolutionModule):
    """Parakeet convolution module with `"same"` padding and a BatchNorm that uses the configured momentum."""

    def __init__(self, config: LasrEncoderConfig, module_config=None):
        super().__init__(config, module_config)
        self.padding = "same"
        bn_momentum = config.batch_norm_momentum
        self.norm = nn.BatchNorm1d(config.hidden_size, momentum=bn_momentum)
class LasrEncoderBlock(ParakeetEncoderBlock):
    """
    One LASR Conformer block.

    The block runs, in order: a feed-forward module, self-attention, a convolution module, and a second
    feed-forward module. Each sub-module is pre-normalized with a bias-free LayerNorm, and the block ends with a
    final LayerNorm. The feed-forward and convolution branches are merged with weighted skip connections taken from
    the config. The attention branch uses a plain residual add.
    """

    def __init__(self, config: LasrEncoderConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.feed_forward_residual_weights = config.feed_forward_residual_weights
        self.conv_residual_weights = config.conv_residual_weights

        def _bias_free_layer_norm():
            return nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)

        self.norm_feed_forward1 = _bias_free_layer_norm()
        self.norm_self_att = _bias_free_layer_norm()
        self.norm_conv = _bias_free_layer_norm()
        self.norm_feed_forward2 = _bias_free_layer_norm()
        self.norm_out = _bias_free_layer_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Element 0 scales the skip path, element 1 scales the sub-module output.
        ff_skip_w = self.feed_forward_residual_weights[0]
        ff_branch_w = self.feed_forward_residual_weights[1]
        conv_skip_w = self.conv_residual_weights[0]
        conv_branch_w = self.conv_residual_weights[1]

        ff1_out = self.feed_forward1(self.norm_feed_forward1(hidden_states))
        hidden_states = ff_skip_w * hidden_states + ff_branch_w * ff1_out

        attn_out, _ = self.self_attn(
            hidden_states=self.norm_self_att(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        conv_out = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
        hidden_states = conv_skip_w * hidden_states + conv_branch_w * conv_out

        ff2_out = self.feed_forward2(self.norm_feed_forward2(hidden_states))
        hidden_states = ff_skip_w * hidden_states + ff_branch_w * ff2_out

        return self.norm_out(hidden_states)
class LasrPreTrainedModel(ParakeetPreTrainedModel):
    # padding is incompatible with flex attention as the resulting mask cannot be used to apply padding
    _supports_flex_attn = False

    def _init_weights(self, module):
        """Initialize `module` with `PreTrainedModel._init_weights` rather than the parent class's implementation."""
        # The base-class method is called unbound, so `self` must be passed explicitly; otherwise `module` would
        # be bound as `self` and the call fails with a missing-argument TypeError.
        PreTrainedModel._init_weights(self, module)

    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
        """
        Map input frame counts to sequence lengths after the subsampling front end.

        Args:
            input_lengths: Number of input frames per example.

        Returns:
            The corresponding lengths after two unpadded convolutions with the configured kernel size and stride.
        """
        encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
        kernel_size = encoder_config.subsampling_conv_kernel_size
        stride = encoder_config.subsampling_conv_stride
        # Mirrors the two unpadded, strided Conv1d layers of the subsampling module: L -> (L - k) // s + 1.
        num_layers = 2
        for _ in range(num_layers):
            input_lengths = (input_lengths - kernel_size) // stride + 1
        return input_lengths
@auto_docstring(
    custom_intro="""
    The LasrEncoder model, based on the [Conformer architecture](https://arxiv.org/abs/2005.08100).
    """
)
class LasrEncoder(LasrPreTrainedModel):
    config: LasrEncoderConfig
    base_model_prefix = "encoder"

    def __init__(self, config: LasrEncoderConfig):
        super().__init__(config)
        self.gradient_checkpointing = False

        self.dropout = config.dropout
        self.dropout_positions = config.dropout_positions
        self.layerdrop = config.layerdrop

        self.subsampler = LasrEncoderSubsampling(config)
        self.rotary_emb = LasrEncoderRotaryEmbedding(config)
        self.layers = nn.ModuleList(
            [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)

        self.post_init()

    @auto_docstring
    @merge_with_config_defaults
    @capture_outputs
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrEncoder
        >>> from datasets import load_dataset, Audio

        >>> model_id = "google/medasr"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> encoder = LasrEncoder.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> encoder_outputs = encoder(**inputs)

        >>> print(encoder_outputs.last_hidden_state.shape)
        ```
        """
        hidden_states = self.subsampler(input_features)
        # Rotary positions are computed for the subsampled sequence length (dim 1).
        cos, sin = self.rotary_emb(
            hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # Positional dropout is applied directly to the rotary cos/sin tables.
        cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
        sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)

        if attention_mask is not None:
            # Resize the frame-level mask to the subsampled length via the inherited helper.
            attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])

        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
        )

        for encoder_layer in self.layers:
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if not to_drop:
                hidden_states = encoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_embeddings=(cos, sin),
                    **kwargs,
                )

        hidden_states = self.out_norm(hidden_states)

        return BaseModelOutput(last_hidden_state=hidden_states)
class LasrForCTC(ParakeetForCTC):
    def generate(self, **super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "google/medasr"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = LasrForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        """
        # `self` is required: without it the bound call `model.generate(...)` raises TypeError, and zero-argument
        # `super()` has no first argument to resolve against.
        return super().generate(**super_kwargs)
# Public API of this module (order preserved).
__all__ = [
    # models
    "LasrForCTC",
    "LasrEncoder",
    "LasrPreTrainedModel",
    # processing
    "LasrProcessor",
    # configurations
    "LasrEncoderConfig",
    "LasrCTCConfig",
    # tokenization
    "LasrTokenizer",
]
|