modeling_lasr.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/lasr/modular_lasr.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_lasr.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team and Google LLC. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from dataclasses import dataclass
  22. from typing import Optional
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  27. from ...masking_utils import create_bidirectional_mask
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import BaseModelOutput, CausalLMOutput
  30. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
  34. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  35. from ...utils.output_capturing import capture_outputs
  36. from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig
  37. class LasrEncoderSubsampling(nn.Module):
  38. def __init__(self, config: LasrEncoderConfig):
  39. super().__init__()
  40. self.dense_0 = nn.Linear(config.num_mel_bins, config.hidden_size)
  41. self.conv_0 = nn.Conv1d(
  42. config.hidden_size,
  43. config.hidden_size,
  44. kernel_size=config.subsampling_conv_kernel_size,
  45. stride=config.subsampling_conv_stride,
  46. )
  47. self.conv_1 = nn.Conv1d(
  48. config.hidden_size,
  49. config.subsampling_conv_channels,
  50. kernel_size=config.subsampling_conv_kernel_size,
  51. stride=config.subsampling_conv_stride,
  52. )
  53. self.dense_1 = nn.Linear(config.subsampling_conv_channels, config.hidden_size)
  54. self.act_fn = nn.ReLU()
  55. def forward(self, input_features: torch.Tensor) -> torch.Tensor:
  56. hidden_states = self.act_fn(self.dense_0(input_features))
  57. hidden_states = hidden_states.transpose(1, 2)
  58. hidden_states = self.act_fn(self.conv_0(hidden_states))
  59. hidden_states = self.act_fn(self.conv_1(hidden_states))
  60. hidden_states = hidden_states.transpose(1, 2)
  61. return self.dense_1(hidden_states)
  62. class LasrEncoderRotaryEmbedding(nn.Module):
  63. inv_freq: torch.Tensor # fix linting for `register_buffer`
  64. def __init__(self, config: LasrEncoderConfig, device=None):
  65. super().__init__()
  66. self.max_seq_len_cached = config.max_position_embeddings
  67. self.original_max_seq_len = config.max_position_embeddings
  68. self.config = config
  69. self.rope_type = self.config.rope_parameters["rope_type"]
  70. rope_init_fn: Callable = self.compute_default_rope_parameters
  71. if self.rope_type != "default":
  72. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  73. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  74. self.register_buffer("inv_freq", inv_freq, persistent=False)
  75. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  76. @staticmethod
  77. def compute_default_rope_parameters(
  78. config: LasrEncoderConfig | None = None,
  79. device: Optional["torch.device"] = None,
  80. seq_len: int | None = None,
  81. ) -> tuple["torch.Tensor", float]:
  82. """
  83. Computes the inverse frequencies according to the original RoPE implementation
  84. Args:
  85. config ([`~transformers.PreTrainedConfig`]):
  86. The model configuration.
  87. device (`torch.device`):
  88. The device to use for initialization of the inverse frequencies.
  89. seq_len (`int`, *optional*):
  90. The current sequence length. Unused for this type of RoPE.
  91. Returns:
  92. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  93. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  94. """
  95. base = config.rope_parameters["rope_theta"]
  96. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  97. attention_factor = 1.0 # Unused in this type of RoPE
  98. # Compute the inverse frequencies
  99. inv_freq = 1.0 / (
  100. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  101. )
  102. return inv_freq, attention_factor
  103. @torch.no_grad()
  104. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  105. def forward(self, x, position_ids):
  106. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  107. position_ids_expanded = position_ids[:, None, :].float()
  108. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  109. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  110. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  111. emb = torch.cat((freqs, freqs), dim=-1)
  112. cos = emb.cos() * self.attention_scaling
  113. sin = emb.sin() * self.attention_scaling
  114. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  115. def rotate_half(x):
  116. """Rotates half the hidden dims of the input."""
  117. x1 = x[..., : x.shape[-1] // 2]
  118. x2 = x[..., x.shape[-1] // 2 :]
  119. return torch.cat((-x2, x1), dim=-1)
  120. @use_kernel_func_from_hub("rotary_pos_emb")
  121. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  122. """Applies Rotary Position Embedding to the query and key tensors.
  123. Args:
  124. q (`torch.Tensor`): The query tensor.
  125. k (`torch.Tensor`): The key tensor.
  126. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  127. sin (`torch.Tensor`): The sine part of the rotary embedding.
  128. unsqueeze_dim (`int`, *optional*, defaults to 1):
  129. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  130. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  131. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  132. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  133. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  134. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  135. Returns:
  136. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  137. """
  138. cos = cos.unsqueeze(unsqueeze_dim)
  139. sin = sin.unsqueeze(unsqueeze_dim)
  140. q_embed = (q * cos) + (rotate_half(q) * sin)
  141. k_embed = (k * cos) + (rotate_half(k) * sin)
  142. return q_embed, k_embed
  143. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  144. """
  145. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  146. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  147. """
  148. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  149. if n_rep == 1:
  150. return hidden_states
  151. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  152. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  153. def eager_attention_forward(
  154. module: nn.Module,
  155. query: torch.Tensor,
  156. key: torch.Tensor,
  157. value: torch.Tensor,
  158. attention_mask: torch.Tensor | None,
  159. scaling: float,
  160. dropout: float = 0.0,
  161. **kwargs: Unpack[TransformersKwargs],
  162. ):
  163. key_states = repeat_kv(key, module.num_key_value_groups)
  164. value_states = repeat_kv(value, module.num_key_value_groups)
  165. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  166. if attention_mask is not None:
  167. attn_weights = attn_weights + attention_mask
  168. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  169. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  170. attn_output = torch.matmul(attn_weights, value_states)
  171. attn_output = attn_output.transpose(1, 2).contiguous()
  172. return attn_output, attn_weights
  173. @use_kernelized_func(apply_rotary_pos_emb)
  174. class LasrEncoderAttention(nn.Module):
  175. """Multi-headed attention from 'Attention Is All You Need' paper"""
  176. def __init__(self, config: LasrEncoderConfig, layer_idx: int):
  177. super().__init__()
  178. self.config = config
  179. self.layer_idx = layer_idx
  180. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  181. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  182. self.scaling = self.head_dim**-0.5
  183. self.attention_dropout = config.attention_dropout
  184. self.is_causal = False
  185. self.q_proj = nn.Linear(
  186. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  187. )
  188. self.k_proj = nn.Linear(
  189. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  190. )
  191. self.v_proj = nn.Linear(
  192. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  193. )
  194. self.o_proj = nn.Linear(
  195. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  196. )
  197. def forward(
  198. self,
  199. hidden_states: torch.Tensor,
  200. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  201. attention_mask: torch.Tensor | None = None,
  202. **kwargs: Unpack[TransformersKwargs],
  203. ) -> tuple[torch.Tensor, torch.Tensor]:
  204. input_shape = hidden_states.shape[:-1]
  205. hidden_shape = (*input_shape, -1, self.head_dim)
  206. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  207. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  208. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  209. cos, sin = position_embeddings
  210. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  211. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  212. self.config._attn_implementation, eager_attention_forward
  213. )
  214. attn_output, attn_weights = attention_interface(
  215. self,
  216. query_states,
  217. key_states,
  218. value_states,
  219. attention_mask,
  220. dropout=0.0 if not self.training else self.attention_dropout,
  221. scaling=self.scaling,
  222. **kwargs,
  223. )
  224. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  225. attn_output = self.o_proj(attn_output)
  226. return attn_output, attn_weights
  227. class LasrEncoderConvolutionModule(nn.Module):
  228. def __init__(self, config: LasrEncoderConfig, module_config=None):
  229. """
  230. Args:
  231. config (LasrEncoderConfig): Configuration for the model.
  232. module_config (dict): Configuration for the module (e.g., encoder or decoder).
  233. """
  234. super().__init__()
  235. channels = config.hidden_size
  236. # kernel_size should be an odd number for 'SAME' padding
  237. if module_config is None:
  238. # e.g. using `LasrEncoderEncoderConfig` in src/transformers/models/lasr_encoder/configuration_lasr_encoder.py
  239. kernel_size = config.conv_kernel_size
  240. self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
  241. else:
  242. kernel_size = module_config["kernel_size"]
  243. self.activation = ACT2FN[module_config.get("activation", "silu")]
  244. self.padding = "same"
  245. self.pointwise_conv1 = nn.Conv1d(
  246. channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
  247. )
  248. self.depthwise_conv = nn.Conv1d(
  249. channels,
  250. channels,
  251. kernel_size,
  252. stride=1,
  253. padding=self.padding,
  254. groups=channels,
  255. bias=config.convolution_bias,
  256. )
  257. self.norm = nn.BatchNorm1d(config.hidden_size, momentum=config.batch_norm_momentum)
  258. self.pointwise_conv2 = nn.Conv1d(
  259. channels, channels, kernel_size=1, stride=1, padding=0, bias=config.convolution_bias
  260. )
  261. def forward(self, hidden_states, attention_mask=None):
  262. """
  263. Compute convolution module.
  264. Args:
  265. hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
  266. attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
  267. Returns:
  268. `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
  269. """
  270. # exchange the temporal dimension and the feature dimension
  271. hidden_states = hidden_states.transpose(1, 2)
  272. # GLU mechanism, (batch_size, 2*channel, dim)
  273. hidden_states = self.pointwise_conv1(hidden_states)
  274. # (batch_size, channel, dim)
  275. hidden_states = nn.functional.glu(hidden_states, dim=1)
  276. # Apply padding mask before convolution
  277. if attention_mask is not None:
  278. if attention_mask.dtype == torch.bool:
  279. all_masked_rows = torch.all(~attention_mask, dim=2)
  280. else:
  281. all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
  282. hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
  283. # 1D Depthwise Conv
  284. hidden_states = self.depthwise_conv(hidden_states)
  285. hidden_states = self.norm(hidden_states)
  286. hidden_states = self.activation(hidden_states)
  287. hidden_states = self.pointwise_conv2(hidden_states)
  288. return hidden_states.transpose(1, 2)
  289. class LasrEncoderFeedForward(nn.Module):
  290. def __init__(self, config: LasrEncoderConfig):
  291. super().__init__()
  292. self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
  293. self.activation = ACT2FN[config.hidden_act]
  294. self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
  295. self.activation_dropout = config.activation_dropout
  296. def forward(self, hidden_states):
  297. hidden_states = self.activation(self.linear1(hidden_states))
  298. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  299. hidden_states = self.linear2(hidden_states)
  300. return hidden_states
  301. class LasrEncoderBlock(GradientCheckpointingLayer):
  302. def __init__(self, config: LasrEncoderConfig, layer_idx: int):
  303. super().__init__()
  304. self.gradient_checkpointing = False
  305. self.feed_forward1 = LasrEncoderFeedForward(config)
  306. self.self_attn = LasrEncoderAttention(config, layer_idx)
  307. self.conv = LasrEncoderConvolutionModule(config)
  308. self.feed_forward2 = LasrEncoderFeedForward(config)
  309. self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  310. self.norm_self_att = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  311. self.norm_conv = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  312. self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  313. self.norm_out = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, bias=False)
  314. self.feed_forward_residual_weights = config.feed_forward_residual_weights
  315. self.conv_residual_weights = config.conv_residual_weights
  316. def forward(
  317. self,
  318. hidden_states: torch.Tensor,
  319. attention_mask: torch.Tensor | None = None,
  320. position_embeddings: torch.Tensor | None = None,
  321. **kwargs: Unpack[TransformersKwargs],
  322. ) -> torch.Tensor:
  323. residual = hidden_states
  324. hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
  325. hidden_states = (
  326. self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
  327. )
  328. normalized_hidden_states = self.norm_self_att(hidden_states)
  329. attn_output, _ = self.self_attn(
  330. hidden_states=normalized_hidden_states,
  331. attention_mask=attention_mask,
  332. position_embeddings=position_embeddings,
  333. **kwargs,
  334. )
  335. hidden_states = hidden_states + attn_output
  336. conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
  337. hidden_states = self.conv_residual_weights[0] * hidden_states + self.conv_residual_weights[1] * conv_output
  338. residual = hidden_states
  339. hidden_states = self.feed_forward2(self.norm_feed_forward2(hidden_states))
  340. hidden_states = (
  341. self.feed_forward_residual_weights[0] * residual + self.feed_forward_residual_weights[1] * hidden_states
  342. )
  343. hidden_states = self.norm_out(hidden_states)
  344. return hidden_states
  345. @auto_docstring
  346. class LasrPreTrainedModel(PreTrainedModel):
  347. config: LasrCTCConfig
  348. base_model_prefix = "model"
  349. main_input_name = "input_features"
  350. input_modalities = "audio"
  351. supports_gradient_checkpointing = True
  352. _no_split_modules = ["LasrEncoderBlock"]
  353. _supports_flat_attention_mask = True
  354. _supports_sdpa = True
  355. # padding is incompatible with flex attention as the resulting mask cannot be used to apply padding
  356. _supports_flex_attn = False
  357. # TODO: @eustlb, add support when flash attention supports custom attention bias
  358. _supports_flash_attn = False
  359. _can_compile_fullgraph = True
  360. _supports_attention_backend = True
  361. _can_record_outputs = {
  362. "hidden_states": LasrEncoderBlock,
  363. "attentions": LasrEncoderAttention,
  364. }
  365. @torch.no_grad()
  366. def _init_weights(self, module):
  367. super()._init_weights(module)
  368. def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
  369. encoder_config = self.config.encoder_config if isinstance(self.config, LasrCTCConfig) else self.config
  370. kernel_size = encoder_config.subsampling_conv_kernel_size
  371. stride = encoder_config.subsampling_conv_stride
  372. num_layers = 2
  373. for _ in range(num_layers):
  374. input_lengths = (input_lengths - kernel_size) // stride + 1
  375. return input_lengths
  376. def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: int | None = None):
  377. """
  378. Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
  379. when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
  380. """
  381. output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  382. # Use target_length if provided, otherwise use max length in batch
  383. max_length = target_length if target_length is not None else output_lengths.max()
  384. attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
  385. return attention_mask
  386. @auto_docstring(
  387. custom_intro="""
  388. The LasrEncoder model, based on the Conformer architecture](https://arxiv.org/abs/2005.08100).
  389. """
  390. )
  391. class LasrEncoder(LasrPreTrainedModel):
  392. config: LasrEncoderConfig
  393. base_model_prefix = "encoder"
  394. def __init__(self, config: LasrEncoderConfig):
  395. super().__init__(config)
  396. self.gradient_checkpointing = False
  397. self.dropout = config.dropout
  398. self.dropout_positions = config.dropout_positions
  399. self.layerdrop = config.layerdrop
  400. self.subsampler = LasrEncoderSubsampling(config)
  401. self.rotary_emb = LasrEncoderRotaryEmbedding(config)
  402. self.layers = nn.ModuleList(
  403. [LasrEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  404. )
  405. self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
  406. self.post_init()
  407. @auto_docstring
  408. @merge_with_config_defaults
  409. @capture_outputs
  410. @can_return_tuple
  411. def forward(
  412. self,
  413. input_features: torch.Tensor,
  414. attention_mask: torch.Tensor | None = None,
  415. **kwargs: Unpack[TransformersKwargs],
  416. ) -> BaseModelOutput:
  417. r"""
  418. Example:
  419. ```python
  420. >>> from transformers import AutoProcessor, LasrEncoder
  421. >>> from datasets import load_dataset, Audio
  422. >>> model_id = TODO
  423. >>> processor = AutoProcessor.from_pretrained(model_id)
  424. >>> encoder = ParakeetEncoder.from_pretrained(model_id)
  425. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  426. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  427. >>> inputs = processor(ds[0]["audio"]["array"])
  428. >>> encoder_outputs = encoder(**inputs)
  429. >>> print(encoder_outputs.last_hidden_state.shape)
  430. ```
  431. """
  432. hidden_states = self.subsampler(input_features)
  433. cos, sin = self.rotary_emb(
  434. hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
  435. )
  436. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  437. cos = nn.functional.dropout(cos, p=self.dropout_positions, training=self.training)
  438. sin = nn.functional.dropout(sin, p=self.dropout_positions, training=self.training)
  439. if attention_mask is not None:
  440. attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
  441. attention_mask = create_bidirectional_mask(
  442. config=self.config,
  443. inputs_embeds=hidden_states,
  444. attention_mask=attention_mask,
  445. )
  446. for encoder_layer in self.layers:
  447. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  448. to_drop = False
  449. if self.training:
  450. dropout_probability = torch.rand([])
  451. if dropout_probability < self.layerdrop: # skip the layer
  452. to_drop = True
  453. if not to_drop:
  454. hidden_states = encoder_layer(
  455. hidden_states,
  456. attention_mask=attention_mask,
  457. position_embeddings=(cos, sin),
  458. **kwargs,
  459. )
  460. hidden_states = self.out_norm(hidden_states)
  461. return BaseModelOutput(last_hidden_state=hidden_states)
  462. @dataclass
  463. class LasrGenerateOutput(ModelOutput):
  464. """
  465. Outputs of Lasr models.
  466. Args:
  467. sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  468. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  469. if all batches finished early due to the `eos_token_id`.
  470. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  471. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  472. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  473. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  474. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  475. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  476. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  477. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  478. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  479. `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
  480. """
  481. sequences: torch.LongTensor
  482. logits: tuple[torch.FloatTensor] | None = None
  483. attentions: tuple[tuple[torch.FloatTensor]] | None = None
  484. hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
  485. @auto_docstring(
  486. custom_intro="""
  487. Lasr Encoder with a Connectionist Temporal Classification (CTC) head.
  488. """
  489. )
  490. class LasrForCTC(LasrPreTrainedModel):
  491. config: LasrCTCConfig
  492. def __init__(self, config: LasrCTCConfig):
  493. super().__init__(config)
  494. self.encoder = LasrEncoder(config.encoder_config)
  495. # Conv rather than linear to be consistent with NeMO decoding layer
  496. self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
  497. self.post_init()
  498. @auto_docstring
  499. @can_return_tuple
  500. def forward(
  501. self,
  502. input_features: torch.Tensor,
  503. attention_mask: torch.Tensor | None = None,
  504. labels: torch.Tensor | None = None,
  505. **kwargs: Unpack[TransformersKwargs],
  506. ) -> CausalLMOutput:
  507. r"""
  508. Example:
  509. ```python
  510. >>> from transformers import AutoProcessor, LasrForCTC
  511. >>> from datasets import load_dataset, Audio
  512. >>> model_id = "nvidia/lasr-ctc-1.1b"
  513. >>> processor = AutoProcessor.from_pretrained(model_id)
  514. >>> model = LasrForCTC.from_pretrained(model_id)
  515. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  516. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  517. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  518. >>> outputs = model(**inputs)
  519. >>> print(outputs.loss)
  520. ```"""
  521. encoder_outputs = self.encoder(
  522. input_features=input_features,
  523. attention_mask=attention_mask,
  524. **kwargs,
  525. )
  526. hidden_states = encoder_outputs.last_hidden_state
  527. logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
  528. loss = None
  529. if labels is not None:
  530. # retrieve loss input_lengths from attention_mask
  531. attention_mask = (
  532. attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
  533. )
  534. input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  535. # assuming that padded tokens are filled with -100
  536. # when not being attended to
  537. labels_mask = labels != self.config.pad_token_id
  538. target_lengths = labels_mask.sum(-1)
  539. flattened_targets = labels.masked_select(labels_mask)
  540. # ctc_loss doesn't support fp16
  541. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  542. with torch.backends.cudnn.flags(enabled=False):
  543. loss = nn.functional.ctc_loss(
  544. log_probs,
  545. flattened_targets,
  546. input_lengths,
  547. target_lengths,
  548. blank=self.config.pad_token_id,
  549. reduction=self.config.ctc_loss_reduction,
  550. zero_infinity=self.config.ctc_zero_infinity,
  551. )
  552. return CausalLMOutput(
  553. loss=loss,
  554. logits=logits,
  555. hidden_states=encoder_outputs.hidden_states,
  556. attentions=encoder_outputs.attentions,
  557. )
  558. @torch.no_grad()
  559. def generate(
  560. self,
  561. input_features: torch.Tensor,
  562. attention_mask: torch.Tensor | None = None,
  563. return_dict_in_generate: bool = False,
  564. **kwargs: Unpack[TransformersKwargs],
  565. ) -> LasrGenerateOutput | torch.LongTensor:
  566. r"""
  567. Example:
  568. ```python
  569. >>> from transformers import AutoProcessor, LasrForCTC
  570. >>> from datasets import load_dataset, Audio
  571. >>> model_id = TODO
  572. >>> processor = AutoProcessor.from_pretrained(model_id)
  573. >>> model = LasrForCTC.from_pretrained(model_id)
  574. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  575. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  576. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  577. >>> predicted_ids = model.generate(**inputs)
  578. >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
  579. >>> print(transcription)
  580. ```
  581. """
  582. kwargs["return_dict"] = True
  583. outputs: CausalLMOutput = self.forward(
  584. input_features=input_features,
  585. attention_mask=attention_mask,
  586. **kwargs,
  587. )
  588. # greedy decoding
  589. sequences = outputs.logits.argmax(dim=-1)
  590. # mask out padded tokens
  591. if attention_mask is not None:
  592. attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
  593. sequences[~attention_mask] = self.config.pad_token_id
  594. if return_dict_in_generate:
  595. return LasrGenerateOutput(
  596. sequences=sequences,
  597. logits=outputs.logits,
  598. attentions=outputs.attentions,
  599. hidden_states=outputs.hidden_states,
  600. )
  601. return sequences
  602. __all__ = ["LasrForCTC", "LasrEncoder", "LasrPreTrainedModel"]