modeling_parakeet.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/parakeet/modular_parakeet.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_parakeet.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. import torch
  24. from torch import nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import BaseModelOutput, CausalLMOutput
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
  33. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  34. from ...utils.output_capturing import capture_outputs
  35. from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
  36. @dataclass
  37. @auto_docstring(
  38. custom_intro="""
  39. Extends [~modeling_outputs.BaseModelOutput] to include the output attention mask since sequence length is not preserved in the model's forward.
  40. """
  41. )
  42. class ParakeetEncoderModelOutput(BaseModelOutput):
  43. attention_mask: torch.Tensor | None = None
  44. class ParakeetEncoderRelPositionalEncoding(nn.Module):
  45. """Relative positional encoding for Parakeet."""
  46. inv_freq: torch.Tensor # fix linting for `register_buffer`
  47. def __init__(self, config: ParakeetEncoderConfig, device=None):
  48. super().__init__()
  49. self.max_position_embeddings = config.max_position_embeddings
  50. base = 10000.0
  51. inv_freq = 1.0 / (
  52. base
  53. ** (
  54. torch.arange(0, config.hidden_size, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
  55. / config.hidden_size
  56. )
  57. )
  58. self.register_buffer("inv_freq", inv_freq, persistent=False)
  59. @torch.no_grad()
  60. def forward(self, hidden_states: torch.Tensor):
  61. seq_length = hidden_states.shape[1]
  62. if seq_length > self.max_position_embeddings:
  63. raise ValueError(
  64. f"Sequence Length: {seq_length} has to be less or equal than "
  65. f"config.max_position_embeddings {self.max_position_embeddings}."
  66. )
  67. position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
  68. inv_freq_expanded = (
  69. self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
  70. )
  71. position_ids_expanded = position_ids[None, None, :].float()
  72. device_type = (
  73. hidden_states.device.type
  74. if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
  75. else "cpu"
  76. )
  77. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  78. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  79. sin = freqs.sin()
  80. cos = freqs.cos()
  81. # interleave sin and cos
  82. pos_embed = torch.stack([sin, cos], dim=-1)
  83. pos_embed = pos_embed.reshape(*pos_embed.shape[:-2], -1)
  84. return pos_embed.to(dtype=hidden_states.dtype)
  85. class ParakeetEncoderFeedForward(nn.Module):
  86. def __init__(self, config: ParakeetEncoderConfig):
  87. super().__init__()
  88. self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
  89. self.activation = ACT2FN[config.hidden_act]
  90. self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
  91. self.activation_dropout = config.activation_dropout
  92. def forward(self, hidden_states):
  93. hidden_states = self.activation(self.linear1(hidden_states))
  94. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  95. hidden_states = self.linear2(hidden_states)
  96. return hidden_states
  97. class ParakeetEncoderConvolutionModule(nn.Module):
  98. def __init__(self, config: ParakeetEncoderConfig, module_config=None):
  99. """
  100. Args:
  101. config (ParakeetEncoderConfig): Configuration for the model.
  102. module_config (dict): Configuration for the module (e.g., encoder or decoder).
  103. """
  104. super().__init__()
  105. channels = config.hidden_size
  106. # kernel_size should be an odd number for 'SAME' padding
  107. if module_config is None:
  108. # e.g. using `ParakeetEncoderEncoderConfig` in src/transformers/models/parakeet_encoder/configuration_parakeet_encoder.py
  109. kernel_size = config.conv_kernel_size
  110. self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
  111. else:
  112. kernel_size = module_config["kernel_size"]
  113. self.activation = ACT2FN[module_config.get("activation", "silu")]
  114. self.padding = (kernel_size - 1) // 2
  115. self.pointwise_conv1 = nn.Conv1d(
  116. channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
  117. )
  118. self.depthwise_conv = nn.Conv1d(
  119. channels,
  120. channels,
  121. kernel_size,
  122. stride=1,
  123. padding=self.padding,
  124. groups=channels,
  125. bias=config.convolution_bias,
  126. )
  127. self.norm = nn.BatchNorm1d(channels)
  128. self.pointwise_conv2 = nn.Conv1d(
  129. channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
  130. )
  131. def forward(self, hidden_states, attention_mask=None):
  132. """
  133. Compute convolution module.
  134. Args:
  135. hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
  136. attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
  137. Returns:
  138. `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
  139. """
  140. # exchange the temporal dimension and the feature dimension
  141. hidden_states = hidden_states.transpose(1, 2)
  142. # GLU mechanism, (batch_size, 2*channel, dim)
  143. hidden_states = self.pointwise_conv1(hidden_states)
  144. # (batch_size, channel, dim)
  145. hidden_states = nn.functional.glu(hidden_states, dim=1)
  146. # Apply padding mask before convolution
  147. if attention_mask is not None:
  148. if attention_mask.dtype == torch.bool:
  149. all_masked_rows = torch.all(~attention_mask, dim=2)
  150. else:
  151. all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
  152. hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
  153. # 1D Depthwise Conv
  154. hidden_states = self.depthwise_conv(hidden_states)
  155. hidden_states = self.norm(hidden_states)
  156. hidden_states = self.activation(hidden_states)
  157. hidden_states = self.pointwise_conv2(hidden_states)
  158. return hidden_states.transpose(1, 2)
  159. def rotate_half(x):
  160. """Rotates half the hidden dims of the input."""
  161. x1 = x[..., : x.shape[-1] // 2]
  162. x2 = x[..., x.shape[-1] // 2 :]
  163. return torch.cat((-x2, x1), dim=-1)
  164. @use_kernel_func_from_hub("rotary_pos_emb")
  165. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  166. """Applies Rotary Position Embedding to the query and key tensors.
  167. Args:
  168. q (`torch.Tensor`): The query tensor.
  169. k (`torch.Tensor`): The key tensor.
  170. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  171. sin (`torch.Tensor`): The sine part of the rotary embedding.
  172. unsqueeze_dim (`int`, *optional*, defaults to 1):
  173. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  174. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  175. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  176. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  177. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  178. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  179. Returns:
  180. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  181. """
  182. cos = cos.unsqueeze(unsqueeze_dim)
  183. sin = sin.unsqueeze(unsqueeze_dim)
  184. q_embed = (q * cos) + (rotate_half(q) * sin)
  185. k_embed = (k * cos) + (rotate_half(k) * sin)
  186. return q_embed, k_embed
  187. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  188. """
  189. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  190. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  191. """
  192. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  193. if n_rep == 1:
  194. return hidden_states
  195. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  196. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  197. def eager_attention_forward(
  198. module: nn.Module,
  199. query: torch.Tensor,
  200. key: torch.Tensor,
  201. value: torch.Tensor,
  202. attention_mask: torch.Tensor | None,
  203. scaling: float,
  204. dropout: float = 0.0,
  205. **kwargs: Unpack[TransformersKwargs],
  206. ):
  207. key_states = repeat_kv(key, module.num_key_value_groups)
  208. value_states = repeat_kv(value, module.num_key_value_groups)
  209. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  210. if attention_mask is not None:
  211. attn_weights = attn_weights + attention_mask
  212. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  213. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  214. attn_output = torch.matmul(attn_weights, value_states)
  215. attn_output = attn_output.transpose(1, 2).contiguous()
  216. return attn_output, attn_weights
  217. @use_kernelized_func(apply_rotary_pos_emb)
  218. class ParakeetEncoderAttention(nn.Module):
  219. """Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860."""
  220. def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
  221. super().__init__()
  222. self.config = config
  223. self.layer_idx = layer_idx
  224. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  225. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  226. self.scaling = self.head_dim**-0.5
  227. self.attention_dropout = config.attention_dropout
  228. self.is_causal = False
  229. self.q_proj = nn.Linear(
  230. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  231. )
  232. self.k_proj = nn.Linear(
  233. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  234. )
  235. self.v_proj = nn.Linear(
  236. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  237. )
  238. self.o_proj = nn.Linear(
  239. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  240. )
  241. # W_{k,R} projection
  242. self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  243. # global content bias
  244. self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
  245. # global positional bias
  246. self.bias_v = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
  247. def forward(
  248. self,
  249. hidden_states: torch.Tensor,
  250. position_embeddings: torch.Tensor | None,
  251. attention_mask: torch.Tensor | None = None,
  252. **kwargs: Unpack[TransformersKwargs],
  253. ) -> tuple[torch.Tensor, torch.Tensor]:
  254. input_shape = hidden_states.shape[:-1]
  255. batch_size, seq_length = input_shape
  256. hidden_shape = (batch_size, seq_length, -1, self.head_dim)
  257. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  258. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  259. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  260. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  261. self.config._attn_implementation, eager_attention_forward
  262. )
  263. query_states_with_bias_u = query_states + self.bias_u.view(
  264. 1, self.config.num_attention_heads, 1, self.head_dim
  265. )
  266. query_states_with_bias_v = query_states + self.bias_v.view(
  267. 1, self.config.num_attention_heads, 1, self.head_dim
  268. )
  269. relative_key_states = self.relative_k_proj(position_embeddings)
  270. relative_key_states = relative_key_states.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
  271. # terms (b) and (d)
  272. matrix_bd = query_states_with_bias_v @ relative_key_states.permute(0, 2, 3, 1)
  273. matrix_bd = self._rel_shift(matrix_bd)
  274. matrix_bd = matrix_bd[..., :seq_length]
  275. matrix_bd = matrix_bd * self.scaling
  276. if attention_mask is not None:
  277. # here the original codebase uses -10000.0 rather than float("-inf") and then manual masked fill with 0.0s
  278. # see: https://github.com/NVIDIA-NeMo/NeMo/blob/8cfedd7203462cb251a914e700e5605444277561/nemo/collections/asr/parts/submodules/multi_head_attention.py#L320-L340
  279. # we rather went for a straight-forward approach with float("-inf")
  280. matrix_bd = matrix_bd.masked_fill_(attention_mask.logical_not(), float("-inf"))
  281. # will compute matrix_ac - terms (a) and (c) - and add matrix_bd
  282. attn_output, attn_weights = attention_interface(
  283. self,
  284. query=query_states_with_bias_u,
  285. key=key_states,
  286. value=value_states,
  287. attention_mask=matrix_bd,
  288. dropout=0.0 if not self.training else self.attention_dropout,
  289. scaling=self.scaling,
  290. **kwargs,
  291. )
  292. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  293. attn_output = self.o_proj(attn_output)
  294. return attn_output, attn_weights
  295. def _rel_shift(self, attention_scores):
  296. """Relative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860."""
  297. batch_size, num_heads, query_length, position_length = attention_scores.shape
  298. attention_scores = nn.functional.pad(attention_scores, pad=(1, 0))
  299. attention_scores = attention_scores.view(batch_size, num_heads, -1, query_length)
  300. attention_scores = attention_scores[:, :, 1:].view(batch_size, num_heads, query_length, position_length)
  301. return attention_scores
  302. class ParakeetEncoderSubsamplingConv2D(nn.Module):
  303. def __init__(self, config: ParakeetEncoderConfig):
  304. super().__init__()
  305. self.kernel_size = config.subsampling_conv_kernel_size
  306. self.stride = config.subsampling_conv_stride
  307. self.channels = config.subsampling_conv_channels
  308. self.padding = (self.kernel_size - 1) // 2
  309. self.num_layers = int(math.log2(config.subsampling_factor))
  310. # define layers
  311. self.layers = nn.ModuleList()
  312. self.layers.append(
  313. nn.Conv2d(1, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  314. )
  315. self.layers.append(nn.ReLU())
  316. for i in range(self.num_layers - 1):
  317. # depthwise conv
  318. self.layers.append(
  319. nn.Conv2d(
  320. self.channels,
  321. self.channels,
  322. kernel_size=self.kernel_size,
  323. stride=self.stride,
  324. padding=self.padding,
  325. groups=self.channels,
  326. )
  327. )
  328. # pointwise conv
  329. self.layers.append(nn.Conv2d(self.channels, self.channels, kernel_size=1))
  330. # activation
  331. self.layers.append(nn.ReLU())
  332. out_length = config.num_mel_bins // (self.stride**self.num_layers)
  333. self.linear = nn.Linear(config.subsampling_conv_channels * out_length, config.hidden_size, bias=True)
  334. def _get_output_length(self, input_lengths: torch.Tensor, conv_layer: nn.Conv2d):
  335. if hasattr(conv_layer, "stride") and conv_layer.stride != (1, 1):
  336. padding = conv_layer.padding
  337. kernel_size = conv_layer.kernel_size[0]
  338. stride = conv_layer.stride[0]
  339. output_lengths = (input_lengths + padding[0] + padding[1] - kernel_size) // stride + 1
  340. return output_lengths
  341. return input_lengths
  342. def forward(self, input_features: torch.Tensor, attention_mask: torch.Tensor = None):
  343. hidden_states = input_features.unsqueeze(1)
  344. current_lengths = attention_mask.sum(-1) if attention_mask is not None else None
  345. for layer in self.layers:
  346. hidden_states = layer(hidden_states)
  347. # mask the hidden states
  348. if isinstance(layer, nn.Conv2d) and attention_mask is not None:
  349. current_lengths = self._get_output_length(current_lengths, layer)
  350. current_seq_length = hidden_states.shape[2]
  351. channel_mask = (
  352. torch.arange(current_seq_length, device=attention_mask.device) < current_lengths[:, None]
  353. )
  354. hidden_states *= channel_mask[:, None, :, None]
  355. hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], hidden_states.shape[2], -1)
  356. hidden_states = self.linear(hidden_states)
  357. return hidden_states
  358. class ParakeetEncoderBlock(GradientCheckpointingLayer):
  359. def __init__(self, config: ParakeetEncoderConfig, layer_idx: int | None = None):
  360. super().__init__()
  361. self.gradient_checkpointing = False
  362. self.feed_forward1 = ParakeetEncoderFeedForward(config)
  363. self.self_attn = ParakeetEncoderAttention(config, layer_idx)
  364. self.conv = ParakeetEncoderConvolutionModule(config)
  365. self.feed_forward2 = ParakeetEncoderFeedForward(config)
  366. self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size)
  367. self.norm_self_att = nn.LayerNorm(config.hidden_size)
  368. self.norm_conv = nn.LayerNorm(config.hidden_size)
  369. self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size)
  370. self.norm_out = nn.LayerNorm(config.hidden_size)
  371. def forward(
  372. self,
  373. hidden_states: torch.Tensor,
  374. attention_mask: torch.Tensor | None = None,
  375. position_embeddings: torch.Tensor | None = None,
  376. **kwargs: Unpack[TransformersKwargs],
  377. ) -> torch.Tensor:
  378. residual = hidden_states
  379. hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
  380. hidden_states = residual + 0.5 * hidden_states # the conformer architecture uses a factor of 0.5
  381. normalized_hidden_states = self.norm_self_att(hidden_states)
  382. attn_output, _ = self.self_attn(
  383. hidden_states=normalized_hidden_states,
  384. attention_mask=attention_mask,
  385. position_embeddings=position_embeddings,
  386. **kwargs,
  387. )
  388. hidden_states = hidden_states + attn_output
  389. conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
  390. hidden_states = hidden_states + conv_output
  391. ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
  392. hidden_states = hidden_states + 0.5 * ff2_output # the conformer architecture uses a factor of 0.5
  393. hidden_states = self.norm_out(hidden_states)
  394. return hidden_states
  395. @auto_docstring
  396. class ParakeetPreTrainedModel(PreTrainedModel):
  397. config: ParakeetCTCConfig
  398. base_model_prefix = "model"
  399. main_input_name = "input_features"
  400. input_modalities = "audio"
  401. supports_gradient_checkpointing = True
  402. _no_split_modules = ["ParakeetEncoderBlock"]
  403. _supports_flat_attention_mask = True
  404. _supports_sdpa = True
  405. _supports_flex_attn = True
  406. # TODO: @eustlb, add support when flash attention supports custom attention bias
  407. _supports_flash_attn = False
  408. _can_compile_fullgraph = True
  409. _supports_attention_backend = True
  410. _can_record_outputs = {
  411. "hidden_states": ParakeetEncoderBlock,
  412. "attentions": ParakeetEncoderAttention,
  413. }
  414. @torch.no_grad()
  415. def _init_weights(self, module):
  416. super()._init_weights(module)
  417. if hasattr(self.config, "initializer_range"):
  418. std = self.config.initializer_range
  419. else:
  420. # 0.02 is the standard default value across the library
  421. std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
  422. if isinstance(module, ParakeetEncoderAttention):
  423. # Initialize positional bias parameters
  424. init.normal_(module.bias_u, mean=0.0, std=std)
  425. init.normal_(module.bias_v, mean=0.0, std=std)
  426. elif isinstance(module, ParakeetEncoderRelPositionalEncoding):
  427. inv_freq = 1.0 / (
  428. 10000.0 ** (torch.arange(0, self.config.hidden_size, 2, dtype=torch.int64) / self.config.hidden_size)
  429. )
  430. init.copy_(module.inv_freq, inv_freq)
  431. def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
  432. encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config
  433. kernel_size = encoder_config.subsampling_conv_kernel_size
  434. stride = encoder_config.subsampling_conv_stride
  435. num_layers = int(math.log2(encoder_config.subsampling_factor))
  436. all_paddings = (kernel_size - 1) // 2 * 2
  437. add_pad = all_paddings - kernel_size
  438. lengths = input_lengths
  439. for _ in range(num_layers):
  440. lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
  441. lengths = torch.floor(lengths)
  442. return lengths.to(dtype=torch.int)
  443. def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: int | None = None):
  444. """
  445. Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
  446. when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
  447. """
  448. output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  449. # Use target_length if provided, otherwise use max length in batch
  450. max_length = target_length if target_length is not None else output_lengths.max()
  451. attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
  452. return attention_mask
  453. @auto_docstring(
  454. custom_intro="""
  455. The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
  456. """
  457. )
  458. class ParakeetEncoder(ParakeetPreTrainedModel):
  459. config: ParakeetEncoderConfig
  460. base_model_prefix = "encoder"
  461. def __init__(self, config: ParakeetEncoderConfig):
  462. super().__init__(config)
  463. self.config = config
  464. self.gradient_checkpointing = False
  465. self.dropout = config.dropout
  466. self.dropout_positions = config.dropout_positions
  467. self.layerdrop = config.layerdrop
  468. self.input_scale = math.sqrt(config.hidden_size) if config.scale_input else 1.0
  469. self.subsampling = ParakeetEncoderSubsamplingConv2D(config)
  470. self.encode_positions = ParakeetEncoderRelPositionalEncoding(config)
  471. self.layers = nn.ModuleList(
  472. [ParakeetEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  473. )
  474. self.post_init()
  475. @auto_docstring
  476. @merge_with_config_defaults
  477. @capture_outputs
  478. @can_return_tuple
  479. def forward(
  480. self,
  481. input_features: torch.Tensor,
  482. attention_mask: torch.Tensor | None = None,
  483. output_attention_mask: bool = True,
  484. **kwargs: Unpack[TransformersKwargs],
  485. ) -> BaseModelOutput:
  486. r"""
  487. output_attention_mask (`bool`, *optional*, defaults to `True`):
  488. Whether to return the output attention mask. Only effective when `attention_mask` is provided.
  489. Example:
  490. ```python
  491. >>> from transformers import AutoProcessor, ParakeetEncoder
  492. >>> from datasets import load_dataset, Audio
  493. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  494. >>> processor = AutoProcessor.from_pretrained(model_id)
  495. >>> encoder = ParakeetEncoder.from_pretrained(model_id)
  496. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  497. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  498. >>> inputs = processor(ds[0]["audio"]["array"])
  499. >>> encoder_outputs = encoder(**inputs)
  500. >>> print(encoder_outputs.last_hidden_state.shape)
  501. ```
  502. """
  503. hidden_states = self.subsampling(input_features, attention_mask)
  504. hidden_states = hidden_states * self.input_scale
  505. position_embeddings = self.encode_positions(hidden_states)
  506. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  507. position_embeddings = nn.functional.dropout(
  508. position_embeddings, p=self.dropout_positions, training=self.training
  509. )
  510. if attention_mask is not None:
  511. output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
  512. attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
  513. attention_mask = attention_mask & attention_mask.transpose(1, 2)
  514. attention_mask = attention_mask.unsqueeze(1)
  515. for encoder_layer in self.layers:
  516. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  517. to_drop = False
  518. if self.training:
  519. dropout_probability = torch.rand([])
  520. if dropout_probability < self.layerdrop: # skip the layer
  521. to_drop = True
  522. if not to_drop:
  523. hidden_states = encoder_layer(
  524. hidden_states,
  525. attention_mask=attention_mask,
  526. position_embeddings=position_embeddings,
  527. **kwargs,
  528. )
  529. return ParakeetEncoderModelOutput(
  530. last_hidden_state=hidden_states,
  531. attention_mask=output_mask.int() if attention_mask is not None and output_attention_mask else None,
  532. )
  533. @dataclass
  534. class ParakeetGenerateOutput(ModelOutput):
  535. """
  536. Outputs of Parakeet models.
  537. Args:
  538. sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  539. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  540. if all batches finished early due to the `eos_token_id`.
  541. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  542. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  543. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  544. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  545. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  546. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  547. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  548. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  549. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  550. `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
  551. """
  552. sequences: torch.LongTensor
  553. logits: tuple[torch.FloatTensor] | None = None
  554. attentions: tuple[tuple[torch.FloatTensor]] | None = None
  555. hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
  556. @auto_docstring(
  557. custom_intro="""
  558. Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
  559. """
  560. )
  561. class ParakeetForCTC(ParakeetPreTrainedModel):
  562. config: ParakeetCTCConfig
  563. def __init__(self, config: ParakeetCTCConfig):
  564. super().__init__(config)
  565. self.encoder = ParakeetEncoder(config.encoder_config)
  566. # Conv rather than linear to be consistent with NeMO decoding layer
  567. self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
  568. self.post_init()
  569. @auto_docstring
  570. @can_return_tuple
  571. def forward(
  572. self,
  573. input_features: torch.Tensor,
  574. attention_mask: torch.Tensor | None = None,
  575. labels: torch.Tensor | None = None,
  576. **kwargs: Unpack[TransformersKwargs],
  577. ) -> CausalLMOutput:
  578. r"""
  579. Example:
  580. ```python
  581. >>> from transformers import AutoProcessor, ParakeetForCTC
  582. >>> from datasets import load_dataset, Audio
  583. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  584. >>> processor = AutoProcessor.from_pretrained(model_id)
  585. >>> model = ParakeetForCTC.from_pretrained(model_id)
  586. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  587. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  588. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  589. >>> outputs = model(**inputs)
  590. >>> print(outputs.loss)
  591. ```"""
  592. encoder_outputs = self.encoder(
  593. input_features=input_features,
  594. attention_mask=attention_mask,
  595. **kwargs,
  596. )
  597. hidden_states = encoder_outputs.last_hidden_state
  598. logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
  599. loss = None
  600. if labels is not None:
  601. # retrieve loss input_lengths from attention_mask
  602. attention_mask = (
  603. attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
  604. )
  605. input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  606. # assuming that padded tokens are filled with -100
  607. # when not being attended to
  608. labels_mask = labels != self.config.pad_token_id
  609. target_lengths = labels_mask.sum(-1)
  610. flattened_targets = labels.masked_select(labels_mask)
  611. # ctc_loss doesn't support fp16
  612. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  613. with torch.backends.cudnn.flags(enabled=False):
  614. loss = nn.functional.ctc_loss(
  615. log_probs,
  616. flattened_targets,
  617. input_lengths,
  618. target_lengths,
  619. blank=self.config.pad_token_id,
  620. reduction=self.config.ctc_loss_reduction,
  621. zero_infinity=self.config.ctc_zero_infinity,
  622. )
  623. return CausalLMOutput(
  624. loss=loss,
  625. logits=logits,
  626. hidden_states=encoder_outputs.hidden_states,
  627. attentions=encoder_outputs.attentions,
  628. )
  629. @torch.no_grad()
  630. def generate(
  631. self,
  632. input_features: torch.Tensor,
  633. attention_mask: torch.Tensor | None = None,
  634. return_dict_in_generate: bool = False,
  635. **kwargs: Unpack[TransformersKwargs],
  636. ) -> ParakeetGenerateOutput | torch.LongTensor:
  637. r"""
  638. Example:
  639. ```python
  640. >>> from transformers import AutoProcessor, ParakeetForCTC
  641. >>> from datasets import load_dataset, Audio
  642. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  643. >>> processor = AutoProcessor.from_pretrained(model_id)
  644. >>> model = ParakeetForCTC.from_pretrained(model_id)
  645. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  646. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  647. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  648. >>> predicted_ids = model.generate(**inputs)
  649. >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
  650. >>> print(transcription)
  651. ```
  652. """
  653. kwargs["return_dict"] = True
  654. outputs: CausalLMOutput = self.forward(
  655. input_features=input_features,
  656. attention_mask=attention_mask,
  657. **kwargs,
  658. )
  659. # greedy decoding
  660. sequences = outputs.logits.argmax(dim=-1)
  661. # mask out padded tokens
  662. if attention_mask is not None:
  663. attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
  664. sequences[~attention_mask] = self.config.pad_token_id
  665. if return_dict_in_generate:
  666. return ParakeetGenerateOutput(
  667. sequences=sequences,
  668. logits=outputs.logits,
  669. attentions=outputs.attentions,
  670. hidden_states=outputs.hidden_states,
  671. )
  672. return sequences
  673. __all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]