modular_parakeet.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Parakeet model."""
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, CausalLMOutput
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
  27. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  28. from ...utils.output_capturing import capture_outputs
  29. from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule
  30. from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
  31. from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
@dataclass
@auto_docstring(
    custom_intro="""
    Extends [~modeling_outputs.BaseModelOutput] to include the output attention mask since sequence length is not preserved in the model's forward.
    """
)
class ParakeetEncoderModelOutput(BaseModelOutput):
    # Mask over the subsampled (encoder output) time axis. `ParakeetEncoder.forward` fills it only when an input
    # `attention_mask` is given and `output_attention_mask=True`; otherwise it stays `None`.
    attention_mask: torch.Tensor | None = None
  40. class ParakeetEncoderRelPositionalEncoding(nn.Module):
  41. """Relative positional encoding for Parakeet."""
  42. inv_freq: torch.Tensor # fix linting for `register_buffer`
  43. def __init__(self, config: ParakeetEncoderConfig, device=None):
  44. super().__init__()
  45. self.max_position_embeddings = config.max_position_embeddings
  46. base = 10000.0
  47. inv_freq = 1.0 / (
  48. base
  49. ** (
  50. torch.arange(0, config.hidden_size, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
  51. / config.hidden_size
  52. )
  53. )
  54. self.register_buffer("inv_freq", inv_freq, persistent=False)
  55. @torch.no_grad()
  56. def forward(self, hidden_states: torch.Tensor):
  57. seq_length = hidden_states.shape[1]
  58. if seq_length > self.max_position_embeddings:
  59. raise ValueError(
  60. f"Sequence Length: {seq_length} has to be less or equal than "
  61. f"config.max_position_embeddings {self.max_position_embeddings}."
  62. )
  63. position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
  64. inv_freq_expanded = (
  65. self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
  66. )
  67. position_ids_expanded = position_ids[None, None, :].float()
  68. device_type = (
  69. hidden_states.device.type
  70. if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
  71. else "cpu"
  72. )
  73. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  74. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  75. sin = freqs.sin()
  76. cos = freqs.cos()
  77. # interleave sin and cos
  78. pos_embed = torch.stack([sin, cos], dim=-1)
  79. pos_embed = pos_embed.reshape(*pos_embed.shape[:-2], -1)
  80. return pos_embed.to(dtype=hidden_states.dtype)
  81. class ParakeetEncoderFeedForward(nn.Module):
  82. def __init__(self, config: ParakeetEncoderConfig):
  83. super().__init__()
  84. self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
  85. self.activation = ACT2FN[config.hidden_act]
  86. self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
  87. self.activation_dropout = config.activation_dropout
  88. def forward(self, hidden_states):
  89. hidden_states = self.activation(self.linear1(hidden_states))
  90. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  91. hidden_states = self.linear2(hidden_states)
  92. return hidden_states
class ParakeetEncoderConvolutionModule(FastSpeech2ConformerConvolutionModule):
    """Conformer convolution module reused from FastSpeech2-Conformer.

    This subclass adds no behavior of its own: `__init__` forwards both arguments unchanged to the parent class.
    """

    def __init__(self, config: ParakeetEncoderConfig, module_config=None):
        super().__init__(config, module_config)
class ParakeetEncoderAttention(LlamaAttention):
    """Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860.

    The content projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and `head_dim` / `scaling` / `attention_dropout`
    come from `LlamaAttention`; this class adds a projection for the relative position embeddings and two learned
    per-head global biases (`bias_u` for the content term, `bias_v` for the positional term).
    """

    def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        # encoder self-attention is bidirectional
        self.is_causal = False
        # W_{k,R} projection
        self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        # global content bias
        self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
        # global positional bias
        self.bias_v = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute relative-position self-attention.

        Args:
            hidden_states: input of shape (batch_size, seq_length, hidden_size).
            position_embeddings: relative position embeddings with a leading batch dimension. The relative shift
                below assumes `2 * seq_length - 1` positions ordered from offset `seq_length - 1` down to
                `-(seq_length - 1)` (as produced by `ParakeetEncoderRelPositionalEncoding`) — confirm for other callers.
            attention_mask: boolean mask (True = attend) broadcastable to (batch_size, num_heads, seq_length, seq_length).

        Returns:
            The projected attention output of shape (batch_size, seq_length, hidden_size) and the attention weights
            returned by the selected attention interface.
        """
        input_shape = hidden_states.shape[:-1]
        batch_size, seq_length = input_shape
        hidden_shape = (batch_size, seq_length, -1, self.head_dim)

        # split into heads: (batch, num_heads, seq_length, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        # queries biased with u are paired with the content keys, queries biased with v with the relative keys
        query_states_with_bias_u = query_states + self.bias_u.view(
            1, self.config.num_attention_heads, 1, self.head_dim
        )
        query_states_with_bias_v = query_states + self.bias_v.view(
            1, self.config.num_attention_heads, 1, self.head_dim
        )

        # (batch, num_positions, num_heads, head_dim)
        relative_key_states = self.relative_k_proj(position_embeddings)
        relative_key_states = relative_key_states.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)

        # terms (b) and (d)
        # (batch, heads, seq, head_dim) @ (batch, heads, head_dim, num_positions) -> (batch, heads, seq, num_positions)
        matrix_bd = query_states_with_bias_v @ relative_key_states.permute(0, 2, 3, 1)
        # after the shift, column j of row i holds the score for relative offset (i - j); keep the first seq_length columns
        matrix_bd = self._rel_shift(matrix_bd)
        matrix_bd = matrix_bd[..., :seq_length]
        matrix_bd = matrix_bd * self.scaling

        if attention_mask is not None:
            # here the original codebase uses -10000.0 rather than float("-inf") and then manual masked fill with 0.0s
            # see: https://github.com/NVIDIA-NeMo/NeMo/blob/8cfedd7203462cb251a914e700e5605444277561/nemo/collections/asr/parts/submodules/multi_head_attention.py#L320-L340
            # we rather went for a straight-forward approach with float("-inf")
            # NOTE(review): a row that is entirely masked (e.g. a padded query position) becomes all -inf here; a plain
            # softmax over such a row yields NaN — confirm the attention backend / downstream masking handles this.
            matrix_bd = matrix_bd.masked_fill_(attention_mask.logical_not(), float("-inf"))

        # will compute matrix_ac - terms (a) and (c) - and add matrix_bd
        # (matrix_bd is passed through the `attention_mask` slot, i.e. as an additive bias on the content scores)
        attn_output, attn_weights = attention_interface(
            self,
            query=query_states_with_bias_u,
            key=key_states,
            value=value_states,
            attention_mask=matrix_bd,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # merge the heads back into the last dimension, then project
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def _rel_shift(self, attention_scores):
        """Relative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860.

        Takes scores of shape (batch, heads, query_length, position_length) and returns a tensor of the same shape.
        It left-pads one zero column, reinterprets the buffer, and drops the first row, which shifts row i left by
        (query_length - 1 - i) columns. With positions ordered from offset +(L-1) down to -(L-1), column j < L of
        row i then corresponds to relative offset (i - j).
        """
        batch_size, num_heads, query_length, position_length = attention_scores.shape
        attention_scores = nn.functional.pad(attention_scores, pad=(1, 0))
        attention_scores = attention_scores.view(batch_size, num_heads, -1, query_length)
        attention_scores = attention_scores[:, :, 1:].view(batch_size, num_heads, query_length, position_length)
        return attention_scores
  162. class ParakeetEncoderSubsamplingConv2D(nn.Module):
  163. def __init__(self, config: ParakeetEncoderConfig):
  164. super().__init__()
  165. self.kernel_size = config.subsampling_conv_kernel_size
  166. self.stride = config.subsampling_conv_stride
  167. self.channels = config.subsampling_conv_channels
  168. self.padding = (self.kernel_size - 1) // 2
  169. self.num_layers = int(math.log2(config.subsampling_factor))
  170. # define layers
  171. self.layers = nn.ModuleList()
  172. self.layers.append(
  173. nn.Conv2d(1, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  174. )
  175. self.layers.append(nn.ReLU())
  176. for i in range(self.num_layers - 1):
  177. # depthwise conv
  178. self.layers.append(
  179. nn.Conv2d(
  180. self.channels,
  181. self.channels,
  182. kernel_size=self.kernel_size,
  183. stride=self.stride,
  184. padding=self.padding,
  185. groups=self.channels,
  186. )
  187. )
  188. # pointwise conv
  189. self.layers.append(nn.Conv2d(self.channels, self.channels, kernel_size=1))
  190. # activation
  191. self.layers.append(nn.ReLU())
  192. out_length = config.num_mel_bins // (self.stride**self.num_layers)
  193. self.linear = nn.Linear(config.subsampling_conv_channels * out_length, config.hidden_size, bias=True)
  194. def _get_output_length(self, input_lengths: torch.Tensor, conv_layer: nn.Conv2d):
  195. if hasattr(conv_layer, "stride") and conv_layer.stride != (1, 1):
  196. padding = conv_layer.padding
  197. kernel_size = conv_layer.kernel_size[0]
  198. stride = conv_layer.stride[0]
  199. output_lengths = (input_lengths + padding[0] + padding[1] - kernel_size) // stride + 1
  200. return output_lengths
  201. return input_lengths
  202. def forward(self, input_features: torch.Tensor, attention_mask: torch.Tensor = None):
  203. hidden_states = input_features.unsqueeze(1)
  204. current_lengths = attention_mask.sum(-1) if attention_mask is not None else None
  205. for layer in self.layers:
  206. hidden_states = layer(hidden_states)
  207. # mask the hidden states
  208. if isinstance(layer, nn.Conv2d) and attention_mask is not None:
  209. current_lengths = self._get_output_length(current_lengths, layer)
  210. current_seq_length = hidden_states.shape[2]
  211. channel_mask = (
  212. torch.arange(current_seq_length, device=attention_mask.device) < current_lengths[:, None]
  213. )
  214. hidden_states *= channel_mask[:, None, :, None]
  215. hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], hidden_states.shape[2], -1)
  216. hidden_states = self.linear(hidden_states)
  217. return hidden_states
  218. class ParakeetEncoderBlock(GradientCheckpointingLayer):
  219. def __init__(self, config: ParakeetEncoderConfig, layer_idx: int | None = None):
  220. super().__init__()
  221. self.gradient_checkpointing = False
  222. self.feed_forward1 = ParakeetEncoderFeedForward(config)
  223. self.self_attn = ParakeetEncoderAttention(config, layer_idx)
  224. self.conv = ParakeetEncoderConvolutionModule(config)
  225. self.feed_forward2 = ParakeetEncoderFeedForward(config)
  226. self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size)
  227. self.norm_self_att = nn.LayerNorm(config.hidden_size)
  228. self.norm_conv = nn.LayerNorm(config.hidden_size)
  229. self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size)
  230. self.norm_out = nn.LayerNorm(config.hidden_size)
  231. def forward(
  232. self,
  233. hidden_states: torch.Tensor,
  234. attention_mask: torch.Tensor | None = None,
  235. position_embeddings: torch.Tensor | None = None,
  236. **kwargs: Unpack[TransformersKwargs],
  237. ) -> torch.Tensor:
  238. residual = hidden_states
  239. hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
  240. hidden_states = residual + 0.5 * hidden_states # the conformer architecture uses a factor of 0.5
  241. normalized_hidden_states = self.norm_self_att(hidden_states)
  242. attn_output, _ = self.self_attn(
  243. hidden_states=normalized_hidden_states,
  244. attention_mask=attention_mask,
  245. position_embeddings=position_embeddings,
  246. **kwargs,
  247. )
  248. hidden_states = hidden_states + attn_output
  249. conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
  250. hidden_states = hidden_states + conv_output
  251. ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
  252. hidden_states = hidden_states + 0.5 * ff2_output # the conformer architecture uses a factor of 0.5
  253. hidden_states = self.norm_out(hidden_states)
  254. return hidden_states
  255. @auto_docstring
  256. class ParakeetPreTrainedModel(PreTrainedModel):
  257. config: ParakeetCTCConfig
  258. base_model_prefix = "model"
  259. main_input_name = "input_features"
  260. input_modalities = "audio"
  261. supports_gradient_checkpointing = True
  262. _no_split_modules = ["ParakeetEncoderBlock"]
  263. _supports_flat_attention_mask = True
  264. _supports_sdpa = True
  265. _supports_flex_attn = True
  266. # TODO: @eustlb, add support when flash attention supports custom attention bias
  267. _supports_flash_attn = False
  268. _can_compile_fullgraph = True
  269. _supports_attention_backend = True
  270. _can_record_outputs = {
  271. "hidden_states": ParakeetEncoderBlock,
  272. "attentions": ParakeetEncoderAttention,
  273. }
  274. @torch.no_grad()
  275. def _init_weights(self, module):
  276. super()._init_weights(module)
  277. if hasattr(self.config, "initializer_range"):
  278. std = self.config.initializer_range
  279. else:
  280. # 0.02 is the standard default value across the library
  281. std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
  282. if isinstance(module, ParakeetEncoderAttention):
  283. # Initialize positional bias parameters
  284. init.normal_(module.bias_u, mean=0.0, std=std)
  285. init.normal_(module.bias_v, mean=0.0, std=std)
  286. elif isinstance(module, ParakeetEncoderRelPositionalEncoding):
  287. inv_freq = 1.0 / (
  288. 10000.0 ** (torch.arange(0, self.config.hidden_size, 2, dtype=torch.int64) / self.config.hidden_size)
  289. )
  290. init.copy_(module.inv_freq, inv_freq)
  291. def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
  292. encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config
  293. kernel_size = encoder_config.subsampling_conv_kernel_size
  294. stride = encoder_config.subsampling_conv_stride
  295. num_layers = int(math.log2(encoder_config.subsampling_factor))
  296. all_paddings = (kernel_size - 1) // 2 * 2
  297. add_pad = all_paddings - kernel_size
  298. lengths = input_lengths
  299. for _ in range(num_layers):
  300. lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
  301. lengths = torch.floor(lengths)
  302. return lengths.to(dtype=torch.int)
  303. def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: int | None = None):
  304. """
  305. Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
  306. when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
  307. """
  308. output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  309. # Use target_length if provided, otherwise use max length in batch
  310. max_length = target_length if target_length is not None else output_lengths.max()
  311. attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
  312. return attention_mask
  313. @auto_docstring(
  314. custom_intro="""
  315. The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
  316. """
  317. )
  318. class ParakeetEncoder(ParakeetPreTrainedModel):
  319. config: ParakeetEncoderConfig
  320. base_model_prefix = "encoder"
  321. def __init__(self, config: ParakeetEncoderConfig):
  322. super().__init__(config)
  323. self.config = config
  324. self.gradient_checkpointing = False
  325. self.dropout = config.dropout
  326. self.dropout_positions = config.dropout_positions
  327. self.layerdrop = config.layerdrop
  328. self.input_scale = math.sqrt(config.hidden_size) if config.scale_input else 1.0
  329. self.subsampling = ParakeetEncoderSubsamplingConv2D(config)
  330. self.encode_positions = ParakeetEncoderRelPositionalEncoding(config)
  331. self.layers = nn.ModuleList(
  332. [ParakeetEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  333. )
  334. self.post_init()
  335. @auto_docstring
  336. @merge_with_config_defaults
  337. @capture_outputs
  338. @can_return_tuple
  339. def forward(
  340. self,
  341. input_features: torch.Tensor,
  342. attention_mask: torch.Tensor | None = None,
  343. output_attention_mask: bool = True,
  344. **kwargs: Unpack[TransformersKwargs],
  345. ) -> BaseModelOutput:
  346. r"""
  347. output_attention_mask (`bool`, *optional*, defaults to `True`):
  348. Whether to return the output attention mask. Only effective when `attention_mask` is provided.
  349. Example:
  350. ```python
  351. >>> from transformers import AutoProcessor, ParakeetEncoder
  352. >>> from datasets import load_dataset, Audio
  353. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  354. >>> processor = AutoProcessor.from_pretrained(model_id)
  355. >>> encoder = ParakeetEncoder.from_pretrained(model_id)
  356. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  357. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  358. >>> inputs = processor(ds[0]["audio"]["array"])
  359. >>> encoder_outputs = encoder(**inputs)
  360. >>> print(encoder_outputs.last_hidden_state.shape)
  361. ```
  362. """
  363. hidden_states = self.subsampling(input_features, attention_mask)
  364. hidden_states = hidden_states * self.input_scale
  365. position_embeddings = self.encode_positions(hidden_states)
  366. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  367. position_embeddings = nn.functional.dropout(
  368. position_embeddings, p=self.dropout_positions, training=self.training
  369. )
  370. if attention_mask is not None:
  371. output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
  372. attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
  373. attention_mask = attention_mask & attention_mask.transpose(1, 2)
  374. attention_mask = attention_mask.unsqueeze(1)
  375. for encoder_layer in self.layers:
  376. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  377. to_drop = False
  378. if self.training:
  379. dropout_probability = torch.rand([])
  380. if dropout_probability < self.layerdrop: # skip the layer
  381. to_drop = True
  382. if not to_drop:
  383. hidden_states = encoder_layer(
  384. hidden_states,
  385. attention_mask=attention_mask,
  386. position_embeddings=position_embeddings,
  387. **kwargs,
  388. )
  389. return ParakeetEncoderModelOutput(
  390. last_hidden_state=hidden_states,
  391. attention_mask=output_mask.int() if attention_mask is not None and output_attention_mask else None,
  392. )
  393. @dataclass
  394. class ParakeetGenerateOutput(ModelOutput):
  395. """
  396. Outputs of Parakeet models.
  397. Args:
  398. sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  399. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  400. if all batches finished early due to the `eos_token_id`.
  401. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  402. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  403. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  404. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  405. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  406. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  407. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  408. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  409. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  410. `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
  411. """
  412. sequences: torch.LongTensor
  413. logits: tuple[torch.FloatTensor] | None = None
  414. attentions: tuple[tuple[torch.FloatTensor]] | None = None
  415. hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
  416. @auto_docstring(
  417. custom_intro="""
  418. Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
  419. """
  420. )
  421. class ParakeetForCTC(ParakeetPreTrainedModel):
  422. config: ParakeetCTCConfig
  423. def __init__(self, config: ParakeetCTCConfig):
  424. super().__init__(config)
  425. self.encoder = ParakeetEncoder(config.encoder_config)
  426. # Conv rather than linear to be consistent with NeMO decoding layer
  427. self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
  428. self.post_init()
  429. @auto_docstring
  430. @can_return_tuple
  431. def forward(
  432. self,
  433. input_features: torch.Tensor,
  434. attention_mask: torch.Tensor | None = None,
  435. labels: torch.Tensor | None = None,
  436. **kwargs: Unpack[TransformersKwargs],
  437. ) -> CausalLMOutput:
  438. r"""
  439. Example:
  440. ```python
  441. >>> from transformers import AutoProcessor, ParakeetForCTC
  442. >>> from datasets import load_dataset, Audio
  443. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  444. >>> processor = AutoProcessor.from_pretrained(model_id)
  445. >>> model = ParakeetForCTC.from_pretrained(model_id)
  446. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  447. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  448. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  449. >>> outputs = model(**inputs)
  450. >>> print(outputs.loss)
  451. ```"""
  452. encoder_outputs = self.encoder(
  453. input_features=input_features,
  454. attention_mask=attention_mask,
  455. **kwargs,
  456. )
  457. hidden_states = encoder_outputs.last_hidden_state
  458. logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
  459. loss = None
  460. if labels is not None:
  461. # retrieve loss input_lengths from attention_mask
  462. attention_mask = (
  463. attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
  464. )
  465. input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  466. # assuming that padded tokens are filled with -100
  467. # when not being attended to
  468. labels_mask = labels != self.config.pad_token_id
  469. target_lengths = labels_mask.sum(-1)
  470. flattened_targets = labels.masked_select(labels_mask)
  471. # ctc_loss doesn't support fp16
  472. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  473. with torch.backends.cudnn.flags(enabled=False):
  474. loss = nn.functional.ctc_loss(
  475. log_probs,
  476. flattened_targets,
  477. input_lengths,
  478. target_lengths,
  479. blank=self.config.pad_token_id,
  480. reduction=self.config.ctc_loss_reduction,
  481. zero_infinity=self.config.ctc_zero_infinity,
  482. )
  483. return CausalLMOutput(
  484. loss=loss,
  485. logits=logits,
  486. hidden_states=encoder_outputs.hidden_states,
  487. attentions=encoder_outputs.attentions,
  488. )
  489. @torch.no_grad()
  490. def generate(
  491. self,
  492. input_features: torch.Tensor,
  493. attention_mask: torch.Tensor | None = None,
  494. return_dict_in_generate: bool = False,
  495. **kwargs: Unpack[TransformersKwargs],
  496. ) -> ParakeetGenerateOutput | torch.LongTensor:
  497. r"""
  498. Example:
  499. ```python
  500. >>> from transformers import AutoProcessor, ParakeetForCTC
  501. >>> from datasets import load_dataset, Audio
  502. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  503. >>> processor = AutoProcessor.from_pretrained(model_id)
  504. >>> model = ParakeetForCTC.from_pretrained(model_id)
  505. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  506. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  507. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  508. >>> predicted_ids = model.generate(**inputs)
  509. >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
  510. >>> print(transcription)
  511. ```
  512. """
  513. kwargs["return_dict"] = True
  514. outputs: CausalLMOutput = self.forward(
  515. input_features=input_features,
  516. attention_mask=attention_mask,
  517. **kwargs,
  518. )
  519. # greedy decoding
  520. sequences = outputs.logits.argmax(dim=-1)
  521. # mask out padded tokens
  522. if attention_mask is not None:
  523. attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
  524. sequences[~attention_mask] = self.config.pad_token_id
  525. if return_dict_in_generate:
  526. return ParakeetGenerateOutput(
  527. sequences=sequences,
  528. logits=outputs.logits,
  529. attentions=outputs.attentions,
  530. hidden_states=outputs.hidden_states,
  531. )
  532. return sequences
# Public names exported by this module.
__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]