modular_lasr.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. # Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import itertools
  15. from collections.abc import Callable
  16. import torch
  17. from huggingface_hub.dataclasses import strict
  18. from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
  19. from tokenizers.models import Unigram
  20. from torch import nn
  21. from ...masking_utils import create_bidirectional_mask
  22. from ...modeling_outputs import BaseModelOutput
  23. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  24. from ...processing_utils import Unpack
  25. from ...tokenization_utils_tokenizers import TokenizersBackend
  26. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  27. from ...utils.generic import merge_with_config_defaults
  28. from ...utils.output_capturing import capture_outputs
  29. from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward
  30. from ..parakeet.configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
  31. from ..parakeet.modeling_parakeet import (
  32. ParakeetEncoderBlock,
  33. ParakeetEncoderConvolutionModule,
  34. ParakeetForCTC,
  35. ParakeetPreTrainedModel,
  36. )
  37. from ..parakeet.processing_parakeet import ParakeetProcessor
  38. from ..t5.tokenization_t5 import T5Tokenizer
  39. class LasrTokenizer(T5Tokenizer, TokenizersBackend):
  40. def __init__(
  41. self,
  42. eos_token="</s>",
  43. unk_token="<unk>",
  44. pad_token="<pad>",
  45. _spm_precompiled_charsmap=None,
  46. extra_ids=100,
  47. additional_special_tokens=None,
  48. vocab=None,
  49. vocab_file=None,
  50. **kwargs,
  51. ):
  52. self._extra_ids = extra_ids
  53. # Handle extra_ids and additional_special_tokens
  54. if additional_special_tokens is not None:
  55. extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
  56. if len(extra_tokens) < 1:
  57. additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
  58. elif extra_ids > 0 and extra_ids != len(extra_tokens):
  59. raise ValueError(
  60. f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
  61. " provided to LasrTokenizer. In this case the additional_special_tokens must include the extra_ids"
  62. " tokens"
  63. )
  64. else:
  65. extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
  66. additional_special_tokens = extra_tokens
  67. # LASR vocab structure: <pad>=0, </s>=1, <unk>=2, then regular vocab, then extra_ids in reverse
  68. if vocab is not None:
  69. self._vocab_scores = vocab
  70. else:
  71. self._vocab_scores = [
  72. (str(pad_token), 0.0),
  73. (str(eos_token), 0.0),
  74. (str(unk_token), 0.0),
  75. ("▁", -2.0), # Space token
  76. ]
  77. for i in range(extra_ids - 1, -1, -1):
  78. self._vocab_scores.append((f"<extra_id_{i}>", 0.0))
  79. self._tokenizer = Tokenizer(
  80. Unigram(
  81. self._vocab_scores,
  82. unk_id=3,
  83. byte_fallback=False,
  84. )
  85. )
  86. if _spm_precompiled_charsmap is not None:
  87. self._tokenizer.normalizer = normalizers.Precompiled(_spm_precompiled_charsmap)
  88. self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  89. [
  90. pre_tokenizers.WhitespaceSplit(),
  91. pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True),
  92. ]
  93. )
  94. self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
  95. TokenizersBackend.__init__(
  96. eos_token=eos_token,
  97. unk_token=unk_token,
  98. pad_token=pad_token,
  99. extra_ids=extra_ids,
  100. additional_special_tokens=additional_special_tokens,
  101. **kwargs,
  102. )
  103. self._tokenizer.post_processor = processors.TemplateProcessing(
  104. single=["$A", "</s>"],
  105. pair=["$A", "</s>", "$B", "</s>"],
  106. special_tokens=[
  107. ("</s>", self.eos_token_id),
  108. ],
  109. )
  110. def _decode(
  111. self,
  112. token_ids: int | list[int],
  113. skip_special_tokens: bool = False,
  114. clean_up_tokenization_spaces: bool | None = None,
  115. group_tokens: bool = True,
  116. **kwargs,
  117. ) -> str:
  118. if isinstance(token_ids, int):
  119. token_ids = [token_ids]
  120. if group_tokens:
  121. token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]
  122. # for CTC we filter out the blank token, which is the pad token
  123. token_ids = [token for token in token_ids if token != self.pad_token_id]
  124. return TokenizersBackend._decode(
  125. self,
  126. token_ids=token_ids,
  127. skip_special_tokens=skip_special_tokens,
  128. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  129. **kwargs,
  130. )
# LASR processor: a plain subclass of `ParakeetProcessor` that adds and overrides nothing.
class LasrProcessor(ParakeetProcessor):
    pass
@auto_docstring(checkpoint="google/medasr")
@strict
class LasrEncoderConfig(ParakeetEncoderConfig):
    r"""
    convolution_bias (`bool`, *optional*, defaults to `False`):
        Whether to use bias in convolutions of the conformer's convolution module.
    conv_kernel_size (`int`, *optional*, defaults to 32):
        The kernel size of the convolution layers in the Conformer block.
    subsampling_conv_channels (`int`, *optional*, defaults to 256):
        The number of channels in the subsampling convolution layers.
    subsampling_conv_kernel_size (`int`, *optional*, defaults to 5):
        The kernel size of the subsampling convolution layers.
    subsampling_conv_stride (`int`, *optional*, defaults to 2):
        The stride of the subsampling convolution layers.
    dropout_positions (`float`, *optional*, defaults to 0.0):
        The dropout ratio for the positions in the input sequence.
    feed_forward_residual_weights (`tuple[float, float]`, *optional*, defaults to `[1.5, 0.5]`):
        The residual weights for the feed forward layers.
    conv_residual_weights (`tuple[float, float]`, *optional*, defaults to `[2.0, 1.0]`):
        The residual weights for the convolution layers.
    batch_norm_momentum (`float`, *optional*, defaults to 0.01):
        The momentum for the batch normalization layers.

    Example:
        ```python
        >>> from transformers import LasrEncoder, LasrEncoderConfig

        >>> # Initializing a `LasrEncoder` configuration
        >>> configuration = LasrEncoderConfig()

        >>> # Initializing a model from the configuration
        >>> model = LasrEncoder(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```

    This configuration class is based on the LasrEncoder architecture from Google Health AI. You can find more details
    and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
    """

    hidden_size: int = 512
    num_hidden_layers: int = 17
    intermediate_size: int = 2048
    attention_bias: bool = False
    convolution_bias: bool = False
    conv_kernel_size: int = 32
    subsampling_conv_kernel_size: int = 5
    num_mel_bins: int = 128
    max_position_embeddings: int = 10000
    layer_norm_eps: float = 1e-6
    # Indexed as [0] (weight of the residual) and [1] (weight of the sub-layer output) in `LasrEncoderBlock`.
    feed_forward_residual_weights: list[float] | tuple[float, ...] = (1.5, 0.5)
    conv_residual_weights: list[float] | tuple[float, ...] = (2.0, 1.0)
    batch_norm_momentum: float = 0.01
    rope_parameters: dict | None = None
    # NOTE(review): the `AttributeError()` assignments presumably mark inherited attributes that this
    # config drops - confirm against the modular tooling before touching them.
    subsampling_factor = AttributeError()
    scale_input = AttributeError()
@auto_docstring(checkpoint="google/medasr")
@strict
class LasrCTCConfig(ParakeetCTCConfig):
    r"""
    ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
        Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
        instance of [`LasrForCTC`].
    ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
        Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
        occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
        of [`LasrForCTC`].

    Example:
        ```python
        >>> from transformers import LasrForCTC, LasrCTCConfig

        >>> # Initializing a Lasr configuration
        >>> configuration = LasrCTCConfig()

        >>> # Initializing a model from the configuration
        >>> model = LasrForCTC(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```

    This configuration class is based on the Lasr CTC architecture from Google Health AI. You can find more details
    and pre-trained models at [google/medasr](https://huggingface.co/google/medasr).
    """

    vocab_size: int = 512
    pad_token_id: int = 0

    @property
    def inputs_to_logits_ratio(self):
        # The subsampler stacks two strided convolutions, hence the stride squared.
        return self.encoder_config.subsampling_conv_stride**2
  213. class LasrEncoderSubsampling(nn.Module):
  214. def __init__(self, config: LasrEncoderConfig):
  215. super().__init__()
  216. self.dense_0 = nn.Linear(config.num_mel_bins, config.hidden_size)
  217. self.conv_0 = nn.Conv1d(
  218. config.hidden_size,
  219. config.hidden_size,
  220. kernel_size=config.subsampling_conv_kernel_size,
  221. stride=config.subsampling_conv_stride,
  222. )
  223. self.conv_1 = nn.Conv1d(
  224. config.hidden_size,
  225. config.subsampling_conv_channels,
  226. kernel_size=config.subsampling_conv_kernel_size,
  227. stride=config.subsampling_conv_stride,
  228. )
  229. self.dense_1 = nn.Linear(config.subsampling_conv_channels, config.hidden_size)
  230. self.act_fn = nn.ReLU()
  231. def forward(self, input_features: torch.Tensor) -> torch.Tensor:
  232. hidden_states = self.act_fn(self.dense_0(input_features))
  233. hidden_states = hidden_states.transpose(1, 2)
  234. hidden_states = self.act_fn(self.conv_0(hidden_states))
  235. hidden_states = self.act_fn(self.conv_1(hidden_states))
  236. hidden_states = hidden_states.transpose(1, 2)
  237. return self.dense_1(hidden_states)
# Rotary embedding used by `LasrEncoder`: a plain subclass of `LlamaRotaryEmbedding` with no overrides.
class LasrEncoderRotaryEmbedding(LlamaRotaryEmbedding): ...
  239. class LasrEncoderAttention(LlamaAttention):
  240. def __init__(self, config: LasrEncoderConfig, layer_idx: int):
  241. super().__init__(config, layer_idx)
  242. self.is_causal = False
  243. def forward(
  244. self,
  245. hidden_states: torch.Tensor,
  246. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  247. attention_mask: torch.Tensor | None = None,
  248. **kwargs: Unpack[TransformersKwargs],
  249. ) -> tuple[torch.Tensor, torch.Tensor]:
  250. input_shape = hidden_states.shape[:-1]
  251. hidden_shape = (*input_shape, -1, self.head_dim)
  252. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  253. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  254. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  255. cos, sin = position_embeddings
  256. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  257. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  258. self.config._attn_implementation, eager_attention_forward
  259. )
  260. attn_output, attn_weights = attention_interface(
  261. self,
  262. query_states,
  263. key_states,
  264. value_states,
  265. attention_mask,
  266. dropout=0.0 if not self.training else self.attention_dropout,
  267. scaling=self.scaling,
  268. **kwargs,
  269. )
  270. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  271. attn_output = self.o_proj(attn_output)
  272. return attn_output, attn_weights
  273. class LasrEncoderConvolutionModule(ParakeetEncoderConvolutionModule):
  274. def __init__(self, config: LasrEncoderConfig, module_config=None):
  275. super().__init__(config, module_config)
  276. self.padding = "same"
  277. self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)
  278. class LasrEncoderBlock(ParakeetEncoderBlock):
  279. def __init__(self, config: LasrEncoderConfig, layer_idx: int):
  280. super().__init__(config, layer_idx)
  281. self.feed_forward_residual_weights = config.feed_forward_residual_weights
  282. self.conv_residual_weights = config.conv_residual_weights
  283. self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  284. self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  285. self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  286. self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  287. self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  288. def forward(
  289. self,
  290. hidden_states: torch.Tensor,
  291. attention_mask: torch.Tensor | None = None,
  292. position_embeddings: torch.Tensor | None = None,
  293. **kwargs: Unpack[TransformersKwargs],
  294. ) -> torch.Tensor:
  295. residual = hidden_states
  296. hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
  297. hidden_states = (
  298. self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
  299. )
  300. normalized_hidden_states = self.norm_self_att(hidden_states)
  301. attn_output, _ = self.self_attn(
  302. hidden_states=normalized_hidden_states,
  303. attention_mask=attention_mask,
  304. position_embeddings=position_embeddings,
  305. **kwargs,
  306. )
  307. hidden_states = hidden_states + attn_output
  308. conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
  309. hidden_states = self.conv_residual_weights[0] * hidden_states + self.conv_residual_weights[1] * conv_output
  310. residual = hidden_states
  311. hidden_states = self.feed_forward2(self.norm_feed_forward2(hidden_states))
  312. hidden_states = (
  313. self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
  314. )
  315. hidden_states = self.norm_out(hidden_states)
  316. return hidden_states
  317. class LasrPreTrainedModel(ParakeetPreTrainedModel):
  318. # padding is incompatible with flex attention as the resulting mask cannot be used to apply padding
  319. _supports_flex_attn = False
  320. def _init_weights(self, module):
  321. PreTrainedModel._init_weights(module)
  322. def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
  323. encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
  324. kernel_size = encoder_config.subsampling_conv_kernel_size
  325. stride = encoder_config.subsampling_conv_stride
  326. num_layers = 2
  327. for _ in range(num_layers):
  328. input_lengths = (input_lengths - kernel_size) // stride + 1
  329. return input_lengths
  330. @auto_docstring(
  331. custom_intro="""
  332. The LasrEncoder model, based on the Conformer architecture](https://arxiv.org/abs/2005.08100).
  333. """
  334. )
  335. class LasrEncoder(LasrPreTrainedModel):
  336. config: LasrEncoderConfig
  337. base_model_prefix = "encoder"
  338. def __init__(self, config: LasrEncoderConfig):
  339. super().__init__(config)
  340. self.gradient_checkpointing = False
  341. self.dropout = config.dropout
  342. self.dropout_positions = config.dropout_positions
  343. self.layerdrop = config.layerdrop
  344. self.subsampler = LasrEncoderSubsampling(config)
  345. self.rotary_emb = LasrEncoderRotaryEmbedding(config)
  346. self.layers = nn.ModuleList(
  347. [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  348. )
  349. self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
  350. self.post_init()
  351. @auto_docstring
  352. @merge_with_config_defaults
  353. @capture_outputs
  354. @can_return_tuple
  355. def forward(
  356. self,
  357. input_features: torch.Tensor,
  358. attention_mask: torch.Tensor | None = None,
  359. **kwargs: Unpack[TransformersKwargs],
  360. ) -> BaseModelOutput:
  361. r"""
  362. Example:
  363. ```python
  364. >>> from transformers import AutoProcessor, LasrEncoder
  365. >>> from datasets import load_dataset, Audio
  366. >>> model_id = TODO
  367. >>> processor = AutoProcessor.from_pretrained(model_id)
  368. >>> encoder = ParakeetEncoder.from_pretrained(model_id)
  369. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  370. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  371. >>> inputs = processor(ds[0]["audio"]["array"])
  372. >>> encoder_outputs = encoder(**inputs)
  373. >>> print(encoder_outputs.last_hidden_state.shape)
  374. ```
  375. """
  376. hidden_states = self.subsampler(input_features)
  377. cos, sin = self.rotary_emb(
  378. hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
  379. )
  380. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  381. cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
  382. sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
  383. if attention_mask is not None:
  384. attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
  385. attention_mask = create_bidirectional_mask(
  386. config=self.config,
  387. inputs_embeds=hidden_states,
  388. attention_mask=attention_mask,
  389. )
  390. for encoder_layer in self.layers:
  391. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  392. to_drop = False
  393. if self.training:
  394. dropout_probability = torch.rand([])
  395. if dropout_probability < self.layerdrop: # skip the layer
  396. to_drop = True
  397. if not to_drop:
  398. hidden_states = encoder_layer(
  399. hidden_states,
  400. attention_mask=attention_mask,
  401. position_embeddings=(cos, sin),
  402. **kwargs,
  403. )
  404. hidden_states = self.out_norm(hidden_states)
  405. return BaseModelOutput(last_hidden_state=hidden_states)
class LasrForCTC(ParakeetForCTC):
    # NOTE(review): `generate` is declared without `self` and only forwards `**super_kwargs` to the
    # parent. This looks like a deliberate convention of the modular tooling (used to override the
    # docstring only) - confirm before "fixing" the signature.
    def generate(**super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, LasrForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "google/medasr"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = LasrForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        """
        return super().generate(**super_kwargs)
# Public names exported by this module; the internal building blocks (subsampling, attention,
# convolution module, encoder block, rotary embedding) are intentionally not listed.
__all__ = [
    "LasrForCTC",
    "LasrEncoder",
    "LasrPreTrainedModel",
    "LasrProcessor",
    "LasrEncoderConfig",
    "LasrCTCConfig",
    "LasrTokenizer",
]