modeling_moonshine.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_moonshine.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from dataclasses import dataclass
  22. from typing import Optional
  23. import torch
  24. import torch.nn as nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernelized_func
  29. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import (
  33. BaseModelOutput,
  34. BaseModelOutputWithPast,
  35. BaseModelOutputWithPastAndCrossAttentions,
  36. Seq2SeqLMOutput,
  37. Seq2SeqModelOutput,
  38. )
  39. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  43. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  44. from ...utils.output_capturing import OutputRecorder, capture_outputs
  45. from .configuration_moonshine import MoonshineConfig
  46. @dataclass
  47. @auto_docstring(
  48. custom_intro="""
  49. Extends [~modeling_outputs.BaseModelOutput] to include the output attention mask since sequence length is not preserved in the model's forward.
  50. """
  51. )
  52. class MoonshineEncoderModelOutput(BaseModelOutput):
  53. attention_mask: torch.Tensor | None = None
  54. class MoonshineEncoderMLP(nn.Module):
  55. def __init__(self, config, hidden_act):
  56. super().__init__()
  57. self.config = config
  58. self.activation_fn = ACT2FN[hidden_act]
  59. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  60. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  61. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  62. hidden_states = self.fc1(hidden_states)
  63. hidden_states = self.activation_fn(hidden_states)
  64. hidden_states = self.fc2(hidden_states)
  65. return hidden_states
  66. class MoonshineDecoderMLP(nn.Module):
  67. def __init__(self, config, hidden_act):
  68. super().__init__()
  69. self.config = config
  70. self.activation_fn = ACT2FN[hidden_act]
  71. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size * 2)
  72. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  73. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  74. hidden_states = self.fc1(hidden_states)
  75. hidden_states, gate = hidden_states.chunk(2, dim=-1)
  76. hidden_states = self.activation_fn(gate) * hidden_states
  77. hidden_states = self.fc2(hidden_states)
  78. return hidden_states
  79. class MoonshineRotaryEmbedding(nn.Module):
  80. inv_freq: torch.Tensor # fix linting for `register_buffer`
  81. def __init__(self, config: MoonshineConfig, device=None):
  82. super().__init__()
  83. self.max_seq_len_cached = config.max_position_embeddings
  84. self.original_max_seq_len = config.max_position_embeddings
  85. self.config = config
  86. self.rope_type = self.config.rope_parameters["rope_type"]
  87. rope_init_fn: Callable = self.compute_default_rope_parameters
  88. if self.rope_type != "default":
  89. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  90. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  91. self.register_buffer("inv_freq", inv_freq, persistent=False)
  92. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  93. @staticmethod
  94. def compute_default_rope_parameters(
  95. config: MoonshineConfig | None = None,
  96. device: Optional["torch.device"] = None,
  97. seq_len: int | None = None,
  98. ) -> tuple["torch.Tensor", float]:
  99. """
  100. Computes the inverse frequencies according to the original RoPE implementation
  101. Args:
  102. config ([`~transformers.PreTrainedConfig`]):
  103. The model configuration.
  104. device (`torch.device`):
  105. The device to use for initialization of the inverse frequencies.
  106. seq_len (`int`, *optional*):
  107. The current sequence length. Unused for this type of RoPE.
  108. Returns:
  109. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  110. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  111. """
  112. base = config.rope_parameters["rope_theta"]
  113. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  114. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  115. dim = int(head_dim * partial_rotary_factor)
  116. attention_factor = 1.0 # Unused in this type of RoPE
  117. # Compute the inverse frequencies
  118. inv_freq = 1.0 / (
  119. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  120. )
  121. return inv_freq, attention_factor
  122. @torch.no_grad()
  123. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  124. def forward(self, x, position_ids):
  125. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  126. position_ids_expanded = position_ids[:, None, :].float()
  127. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  128. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  129. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  130. emb = torch.cat((freqs, freqs), dim=-1)
  131. cos = emb.cos() * self.attention_scaling
  132. sin = emb.sin() * self.attention_scaling
  133. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  134. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  135. """
  136. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  137. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  138. """
  139. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  140. if n_rep == 1:
  141. return hidden_states
  142. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  143. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  144. def eager_attention_forward(
  145. module: nn.Module,
  146. query: torch.Tensor,
  147. key: torch.Tensor,
  148. value: torch.Tensor,
  149. attention_mask: torch.Tensor | None,
  150. scaling: float,
  151. dropout: float = 0.0,
  152. **kwargs: Unpack[TransformersKwargs],
  153. ):
  154. key_states = repeat_kv(key, module.num_key_value_groups)
  155. value_states = repeat_kv(value, module.num_key_value_groups)
  156. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  157. if attention_mask is not None:
  158. attn_weights = attn_weights + attention_mask
  159. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  160. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  161. attn_output = torch.matmul(attn_weights, value_states)
  162. attn_output = attn_output.transpose(1, 2).contiguous()
  163. return attn_output, attn_weights
  164. def rotate_half(x):
  165. """Rotates half the hidden dims of the input."""
  166. x1 = x[..., 0::2]
  167. x2 = x[..., 1::2]
  168. return torch.stack((-x2, x1), dim=-1).flatten(-2)
  169. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  170. """Applies Rotary Position Embedding to the query and key tensors.
  171. Args:
  172. q (`torch.Tensor`): The query tensor.
  173. k (`torch.Tensor`): The key tensor.
  174. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  175. sin (`torch.Tensor`): The sine part of the rotary embedding.
  176. unsqueeze_dim (`int`, *optional*, defaults to 1):
  177. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  178. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  179. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  180. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  181. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  182. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  183. Returns:
  184. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  185. """
  186. cos = cos.unsqueeze(unsqueeze_dim)
  187. sin = sin.unsqueeze(unsqueeze_dim)
  188. # Interleave them instead of usual shape
  189. cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
  190. sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
  191. # Keep half or full tensor for later concatenation
  192. rotary_dim = cos.shape[-1]
  193. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  194. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  195. # Apply rotary embeddings on the first half or full tensor
  196. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  197. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  198. # Concatenate back to full shape
  199. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  200. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  201. return q_embed, k_embed
  202. @use_kernelized_func(apply_rotary_pos_emb)
  203. class MoonshineAttention(nn.Module):
  204. """Multi-headed attention from 'Attention Is All You Need' paper"""
  205. def __init__(
  206. self,
  207. config: MoonshineConfig,
  208. layer_idx: int,
  209. is_causal: bool,
  210. num_attention_heads: int,
  211. num_key_value_heads: int,
  212. ):
  213. super().__init__()
  214. config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
  215. self.config = config
  216. self.layer_idx = layer_idx
  217. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  218. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  219. self.scaling = self.head_dim**-0.5
  220. self.attention_dropout = config.attention_dropout
  221. self.is_causal = is_causal
  222. self.q_proj = nn.Linear(
  223. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  224. )
  225. self.k_proj = nn.Linear(
  226. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  227. )
  228. self.v_proj = nn.Linear(
  229. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  230. )
  231. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  232. # Pad head dimension to the next specified multiple.
  233. if self.config.pad_head_dim_to_multiple_of is not None:
  234. target_multiple = self.config.pad_head_dim_to_multiple_of
  235. target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
  236. self.head_dim_padding = target_head_dim - self.head_dim
  237. else:
  238. self.head_dim_padding = 0
  239. def forward(
  240. self,
  241. hidden_states: torch.Tensor,
  242. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  243. attention_mask: torch.Tensor | None = None,
  244. past_key_values: Cache | None = None,
  245. key_value_states: torch.Tensor | None = None,
  246. **kwargs: Unpack[FlashAttentionKwargs],
  247. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  248. bsz, q_len = hidden_states.shape[:-1]
  249. query_states = (
  250. self.q_proj(hidden_states).view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
  251. )
  252. is_cross_attention = key_value_states is not None
  253. if past_key_values is not None:
  254. is_updated = past_key_values.is_updated.get(self.layer_idx)
  255. if is_cross_attention:
  256. # after the first generated id, we can subsequently re-use all key/value_states from cache
  257. past_key_values.is_updated[self.layer_idx] = True
  258. past_key_values = past_key_values.cross_attention_cache
  259. else:
  260. past_key_values = past_key_values.self_attention_cache
  261. # use key_value_states if cross attention
  262. current_states = key_value_states if key_value_states is not None else hidden_states
  263. if is_cross_attention and past_key_values and is_updated:
  264. key_states = past_key_values.layers[self.layer_idx].keys
  265. value_states = past_key_values.layers[self.layer_idx].values
  266. else:
  267. key_states = (
  268. self.k_proj(current_states)
  269. .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
  270. .transpose(1, 2)
  271. )
  272. value_states = (
  273. self.v_proj(current_states)
  274. .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
  275. .transpose(1, 2)
  276. )
  277. if is_cross_attention and past_key_values is not None:
  278. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  279. if not is_cross_attention:
  280. cos, sin = position_embeddings
  281. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  282. if past_key_values is not None:
  283. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  284. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  285. self.config._attn_implementation, eager_attention_forward
  286. )
  287. is_causal = self.is_causal and attention_mask is None and q_len > 1
  288. if self.head_dim_padding > 0:
  289. query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
  290. key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
  291. value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
  292. attn_output, attn_weights = attention_interface(
  293. self,
  294. query_states,
  295. key_states,
  296. value_states,
  297. attention_mask,
  298. dropout=0.0 if not self.training else self.attention_dropout,
  299. scaling=self.scaling,
  300. is_causal=is_causal,
  301. **kwargs,
  302. )
  303. if self.head_dim_padding > 0:
  304. attn_output = attn_output[..., : -self.head_dim_padding]
  305. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  306. attn_output = self.o_proj(attn_output)
  307. return attn_output, attn_weights
  308. class MoonshineEncoderLayer(GradientCheckpointingLayer):
  309. def __init__(self, config: MoonshineConfig, layer_idx: int):
  310. super().__init__()
  311. self.hidden_size = config.hidden_size
  312. self.self_attn = MoonshineAttention(
  313. config=config,
  314. layer_idx=layer_idx,
  315. is_causal=False,
  316. num_attention_heads=config.encoder_num_attention_heads,
  317. num_key_value_heads=config.encoder_num_key_value_heads,
  318. )
  319. self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act)
  320. self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  321. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  322. def forward(
  323. self,
  324. hidden_states: torch.Tensor,
  325. attention_mask: torch.Tensor | None = None,
  326. position_ids: torch.LongTensor | None = None,
  327. past_key_values: Cache | None = None,
  328. use_cache: bool | None = False,
  329. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  330. **kwargs: Unpack[TransformersKwargs],
  331. ) -> torch.Tensor:
  332. residual = hidden_states
  333. hidden_states = self.input_layernorm(hidden_states)
  334. # Self Attention
  335. hidden_states, _ = self.self_attn(
  336. hidden_states=hidden_states,
  337. attention_mask=attention_mask,
  338. position_ids=position_ids,
  339. past_key_values=past_key_values,
  340. use_cache=use_cache,
  341. position_embeddings=position_embeddings,
  342. **kwargs,
  343. )
  344. hidden_states = residual + hidden_states
  345. # Fully Connected
  346. residual = hidden_states
  347. hidden_states = self.post_attention_layernorm(hidden_states)
  348. hidden_states = self.mlp(hidden_states)
  349. hidden_states = residual + hidden_states
  350. return hidden_states
  351. class MoonshineDecoderLayer(GradientCheckpointingLayer):
  352. def __init__(self, config: MoonshineConfig, layer_idx: int | None = None):
  353. super().__init__()
  354. self.hidden_size = config.hidden_size
  355. self.self_attn = MoonshineAttention(
  356. config=config,
  357. layer_idx=layer_idx,
  358. is_causal=True,
  359. num_attention_heads=config.num_attention_heads,
  360. num_key_value_heads=config.num_key_value_heads,
  361. )
  362. self.encoder_attn = MoonshineAttention(
  363. config=config,
  364. layer_idx=layer_idx,
  365. is_causal=False,
  366. num_attention_heads=config.num_attention_heads,
  367. num_key_value_heads=config.num_key_value_heads,
  368. )
  369. self.mlp = MoonshineDecoderMLP(config, config.hidden_act)
  370. self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  371. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  372. self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  373. def forward(
  374. self,
  375. hidden_states: torch.Tensor,
  376. attention_mask: torch.Tensor | None = None,
  377. encoder_hidden_states: torch.Tensor | None = None,
  378. encoder_attention_mask: torch.Tensor | None = None,
  379. position_ids: torch.LongTensor | None = None,
  380. encoder_position_ids: torch.LongTensor | None = None,
  381. past_key_values: Cache | None = None,
  382. use_cache: bool | None = False,
  383. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  384. encoder_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  385. **kwargs: Unpack[TransformersKwargs],
  386. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  387. residual = hidden_states
  388. hidden_states = self.input_layernorm(hidden_states)
  389. hidden_states, _ = self.self_attn(
  390. hidden_states=hidden_states,
  391. attention_mask=attention_mask,
  392. position_ids=position_ids,
  393. past_key_values=past_key_values,
  394. use_cache=use_cache,
  395. position_embeddings=position_embeddings,
  396. **kwargs,
  397. )
  398. hidden_states = residual + hidden_states
  399. if encoder_hidden_states is not None:
  400. residual = hidden_states
  401. hidden_states = self.post_attention_layernorm(hidden_states)
  402. hidden_states, _ = self.encoder_attn(
  403. hidden_states=hidden_states,
  404. key_value_states=encoder_hidden_states,
  405. attention_mask=encoder_attention_mask,
  406. past_key_values=past_key_values,
  407. use_cache=use_cache,
  408. )
  409. hidden_states = residual + hidden_states
  410. residual = hidden_states
  411. hidden_states = self.final_layernorm(hidden_states)
  412. hidden_states = self.mlp(hidden_states)
  413. hidden_states = residual + hidden_states
  414. return hidden_states
  415. @auto_docstring
  416. class MoonshinePreTrainedModel(PreTrainedModel):
  417. config: MoonshineConfig
  418. base_model_prefix = "model"
  419. main_input_name = "input_values"
  420. input_modalities = "audio"
  421. supports_gradient_checkpointing = True
  422. _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
  423. _supports_flash_attn = True
  424. _supports_sdpa = True
  425. _can_compile_fullgraph = True
  426. # TODO arthur, how do we separate when it cross / self coming from different layer?
  427. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
  428. """
  429. Computes the output length of the convolutional layers
  430. """
  431. output_conv1_length = int((input_lengths - 127) / 64 + 1)
  432. output_conv2_length = int((output_conv1_length - 7) / 3 + 1)
  433. output_conv3_length = int((output_conv2_length - 3) / 2 + 1)
  434. return output_conv3_length
  435. class MoonshineEncoder(MoonshinePreTrainedModel):
  436. """
  437. Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`]
  438. Args:
  439. config: MoonshineConfig
  440. """
  441. main_input_name = "input_values"
  442. _can_record_outputs = {
  443. "attentions": MoonshineAttention,
  444. "hidden_states": MoonshineEncoderLayer,
  445. }
  446. def __init__(self, config: MoonshineConfig):
  447. super().__init__(config)
  448. self.config = config
  449. embed_dim = config.hidden_size
  450. self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False)
  451. self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3)
  452. self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2)
  453. self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5)
  454. self.layers = nn.ModuleList(
  455. [MoonshineEncoderLayer(config, idx) for idx in range(config.encoder_num_hidden_layers)]
  456. )
  457. self.layer_norm = nn.LayerNorm(embed_dim, bias=False)
  458. self.rotary_emb = MoonshineRotaryEmbedding(config=config)
  459. self.gradient_checkpointing = False
  460. self.post_init()
  461. def get_input_embeddings(self) -> nn.Module:
  462. return self.conv1
  463. def set_input_embeddings(self, value: nn.Module):
  464. self.conv1 = value
  465. @merge_with_config_defaults
  466. @capture_outputs
  467. def forward(
  468. self,
  469. input_values: torch.FloatTensor,
  470. attention_mask: torch.Tensor | None = None,
  471. **kwargs: Unpack[TransformersKwargs],
  472. ) -> tuple | BaseModelOutputWithPast:
  473. r"""
  474. Args:
  475. input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
  476. Float values of the raw speech waveform. Raw speech waveform can be
  477. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  478. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  479. the soundfile library (`pip install soundfile`). To prepare the array into
  480. `input_values`, the [`AutoFeatureExtractor`] should be used for padding
  481. and conversion into a tensor of type `torch.FloatTensor`.
  482. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  483. Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
  484. - 1 for tokens that are **not masked**,
  485. - 0 for tokens that are **masked**.
  486. [What are attention masks?](../glossary#attention-mask)
  487. """
  488. input_values = input_values.unsqueeze(1)
  489. hidden_states = nn.functional.tanh(self.conv1(input_values))
  490. hidden_states = self.groupnorm(hidden_states)
  491. hidden_states = nn.functional.gelu(self.conv2(hidden_states))
  492. hidden_states = nn.functional.gelu(self.conv3(hidden_states))
  493. hidden_states = hidden_states.permute(0, 2, 1)
  494. # attention mask downsampling
  495. output_attention_mask = None
  496. if attention_mask is not None:
  497. mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
  498. downsample_stride = 64 * 3 * 2 # conv strides
  499. attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
  500. output_attention_mask = attention_mask
  501. attention_mask = create_bidirectional_mask(
  502. config=self.config,
  503. inputs_embeds=hidden_states,
  504. attention_mask=attention_mask,
  505. encoder_hidden_states=hidden_states,
  506. )
  507. position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
  508. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  509. for encoder_layer in self.layers:
  510. hidden_states = encoder_layer(
  511. hidden_states,
  512. attention_mask=attention_mask,
  513. position_ids=position_ids,
  514. position_embeddings=position_embeddings,
  515. **kwargs,
  516. )
  517. hidden_states = self.layer_norm(hidden_states)
  518. return MoonshineEncoderModelOutput(
  519. last_hidden_state=hidden_states,
  520. attention_mask=output_attention_mask.int() if output_attention_mask is not None else None,
  521. )
  522. @auto_docstring
  523. class MoonshineDecoder(MoonshinePreTrainedModel):
  524. main_input_name = "input_ids"
  525. _can_record_outputs = {
  526. "attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="self_attn"),
  527. "hidden_states": MoonshineDecoderLayer,
  528. "cross_attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="encoder_attn"),
  529. }
  530. def __init__(self, config: MoonshineConfig):
  531. super().__init__(config)
  532. self.padding_idx = config.pad_token_id
  533. self.vocab_size = config.vocab_size
  534. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  535. self.layers = nn.ModuleList([MoonshineDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)])
  536. self.norm = nn.LayerNorm(config.hidden_size, bias=False)
  537. self.rotary_emb = MoonshineRotaryEmbedding(config=config)
  538. self.gradient_checkpointing = False
  539. # Initialize weights and apply final processing
  540. self.post_init()
  541. @merge_with_config_defaults
  542. @capture_outputs
  543. def forward(
  544. self,
  545. input_ids: torch.LongTensor | None = None,
  546. attention_mask: torch.Tensor | None = None,
  547. position_ids: torch.LongTensor | None = None,
  548. past_key_values: Cache | None = None,
  549. inputs_embeds: torch.FloatTensor | None = None,
  550. use_cache: bool | None = None,
  551. encoder_hidden_states: torch.FloatTensor | None = None,
  552. encoder_attention_mask: torch.Tensor | None = None,
  553. **kwargs: Unpack[TransformersKwargs],
  554. ) -> tuple | BaseModelOutputWithPast:
  555. r"""
  556. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  557. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  558. of the decoder.
  559. encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  560. Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
  561. - 1 for tokens that are **not masked**,
  562. - 0 for tokens that are **masked**.
  563. [What are attention masks?](../glossary#attention-mask)
  564. """
  565. if (input_ids is None) ^ (inputs_embeds is not None):
  566. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  567. if inputs_embeds is None:
  568. inputs_embeds = self.embed_tokens(input_ids)
  569. if use_cache and past_key_values is None:
  570. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  571. if position_ids is None:
  572. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  573. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  574. position_ids = position_ids.unsqueeze(0)
  575. causal_mask = create_causal_mask(
  576. config=self.config,
  577. inputs_embeds=inputs_embeds,
  578. attention_mask=attention_mask,
  579. past_key_values=past_key_values,
  580. position_ids=position_ids,
  581. )
  582. encoder_attention_mask = create_bidirectional_mask(
  583. config=self.config,
  584. inputs_embeds=inputs_embeds,
  585. attention_mask=encoder_attention_mask,
  586. encoder_hidden_states=encoder_hidden_states,
  587. )
  588. hidden_states = inputs_embeds
  589. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  590. for decoder_layer in self.layers:
  591. hidden_states = decoder_layer(
  592. hidden_states,
  593. causal_mask,
  594. encoder_hidden_states, # as a positional argument for gradient checkpointing
  595. encoder_attention_mask=encoder_attention_mask,
  596. position_ids=position_ids,
  597. past_key_values=past_key_values,
  598. use_cache=use_cache,
  599. position_embeddings=position_embeddings,
  600. **kwargs,
  601. )
  602. hidden_states = self.norm(hidden_states)
  603. return BaseModelOutputWithPastAndCrossAttentions(
  604. last_hidden_state=hidden_states,
  605. past_key_values=past_key_values if use_cache else None,
  606. )
  607. @auto_docstring
  608. class MoonshineModel(MoonshinePreTrainedModel):
  609. def __init__(self, config: MoonshineConfig):
  610. super().__init__(config)
  611. self.encoder = MoonshineEncoder(config)
  612. self.decoder = MoonshineDecoder(config)
  613. # Initialize weights and apply final processing
  614. self.post_init()
  615. def get_input_embeddings(self):
  616. return self.decoder.embed_tokens
  617. def set_input_embeddings(self, value):
  618. self.decoder.embed_tokens = value
  619. def freeze_encoder(self):
  620. """
  621. Calling this function will disable the gradient computation for the Moonshine encoder so that its parameters will
  622. not be updated during training.
  623. """
  624. self.encoder._freeze_parameters()
  625. def _mask_input_features(self):
  626. """
  627. Masks extracted features along time axis and/or along feature axis according to
  628. [SpecAugment](https://huggingface.co/papers/1904.08779).
  629. """
  630. raise AttributeError("Not needed for Moonshine")
  631. @can_return_tuple
  632. @auto_docstring
  633. def forward(
  634. self,
  635. input_values: torch.FloatTensor | None = None,
  636. attention_mask: torch.LongTensor | None = None,
  637. decoder_input_ids: torch.LongTensor | None = None,
  638. decoder_attention_mask: torch.LongTensor | None = None,
  639. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  640. past_key_values: EncoderDecoderCache | None = None,
  641. decoder_inputs_embeds: tuple[torch.FloatTensor] | None = None,
  642. decoder_position_ids: tuple[torch.LongTensor] | None = None,
  643. use_cache: bool | None = None,
  644. **kwargs: Unpack[TransformersKwargs],
  645. ) -> Seq2SeqModelOutput:
  646. r"""
  647. input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
  648. Float values of the raw speech waveform. Raw speech waveform can be
  649. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  650. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  651. the soundfile library (`pip install soundfile`). To prepare the array into
  652. `input_values`, the [`AutoFeatureExtractor`] should be used for padding
  653. and conversion into a tensor of type `torch.FloatTensor`.
  654. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  655. Indices of positions of each input sequence tokens in the position embeddings.
  656. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`
  657. Example:
  658. ```python
  659. >>> import torch
  660. >>> from transformers import AutoFeatureExtractor, MoonshineModel
  661. >>> from datasets import load_dataset
  662. >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
  663. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
  664. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  665. >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
  666. >>> input_values = inputs.input_values
  667. >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
  668. >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
  669. >>> list(last_hidden_state.shape)
  670. [1, 2, 288]
  671. ```
  672. """
  673. if encoder_outputs is None:
  674. encoder_outputs: BaseModelOutput = self.encoder(input_values, attention_mask=attention_mask, **kwargs)
  675. decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
  676. input_ids=decoder_input_ids,
  677. attention_mask=decoder_attention_mask,
  678. encoder_hidden_states=encoder_outputs.last_hidden_state,
  679. encoder_attention_mask=encoder_outputs.attention_mask,
  680. past_key_values=past_key_values,
  681. inputs_embeds=decoder_inputs_embeds,
  682. position_ids=decoder_position_ids,
  683. use_cache=use_cache,
  684. **kwargs,
  685. )
  686. return Seq2SeqModelOutput(
  687. last_hidden_state=decoder_outputs.last_hidden_state,
  688. past_key_values=decoder_outputs.past_key_values,
  689. decoder_hidden_states=decoder_outputs.hidden_states,
  690. decoder_attentions=decoder_outputs.attentions,
  691. cross_attentions=decoder_outputs.cross_attentions,
  692. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  693. encoder_hidden_states=encoder_outputs.hidden_states,
  694. encoder_attentions=encoder_outputs.attentions,
  695. )
  696. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  697. """
  698. Shift input ids one token to the right.
  699. """
  700. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  701. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  702. shifted_input_ids[:, 0] = decoder_start_token_id
  703. if pad_token_id is None:
  704. raise ValueError("self.model.config.pad_token_id has to be defined.")
  705. # replace possible -100 values in labels by `pad_token_id`
  706. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  707. return shifted_input_ids
  708. @auto_docstring(
  709. custom_intro="""
  710. The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.
  711. """
  712. )
  713. class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):
  714. _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"}
  715. def __init__(self, config: MoonshineConfig):
  716. super().__init__(config)
  717. self.model = MoonshineModel(config)
  718. self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  719. # Initialize weights and apply final processing
  720. self.post_init()
  721. def get_output_embeddings(self):
  722. return self.proj_out
  723. def set_output_embeddings(self, new_embeddings):
  724. self.proj_out = new_embeddings
  725. def get_input_embeddings(self) -> nn.Module:
  726. return self.model.get_input_embeddings()
  727. @can_return_tuple
  728. @auto_docstring
  729. def forward(
  730. self,
  731. input_values: torch.FloatTensor | None = None,
  732. attention_mask: torch.LongTensor | None = None,
  733. decoder_input_ids: torch.LongTensor | None = None,
  734. decoder_attention_mask: torch.LongTensor | None = None,
  735. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  736. past_key_values: EncoderDecoderCache | None = None,
  737. decoder_inputs_embeds: tuple[torch.FloatTensor] | None = None,
  738. decoder_position_ids: tuple[torch.LongTensor] | None = None,
  739. use_cache: bool | None = None,
  740. labels: torch.LongTensor | None = None,
  741. **kwargs: Unpack[TransformersKwargs],
  742. ) -> Seq2SeqLMOutput:
  743. r"""
  744. input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
  745. Float values of the raw speech waveform. Raw speech waveform can be
  746. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  747. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  748. the soundfile library (`pip install soundfile`). To prepare the array into
  749. `input_values`, the [`AutoFeatureExtractor`] should be used for padding
  750. and conversion into a tensor of type `torch.FloatTensor`.
  751. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  752. Indices of positions of each input sequence tokens in the position embeddings.
  753. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`
  754. Example:
  755. ```python
  756. >>> import torch
  757. >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
  758. >>> from datasets import load_dataset
  759. >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
  760. >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
  761. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  762. >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
  763. >>> input_values = inputs.input_values
  764. >>> generated_ids = model.generate(input_values, max_new_tokens=100)
  765. >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  766. >>> transcription
  767. 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
  768. ```"""
  769. if labels is not None:
  770. if decoder_input_ids is None and decoder_inputs_embeds is None:
  771. decoder_input_ids = shift_tokens_right(
  772. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  773. )
  774. outputs: Seq2SeqModelOutput = self.model(
  775. input_values,
  776. attention_mask=attention_mask,
  777. decoder_input_ids=decoder_input_ids,
  778. encoder_outputs=encoder_outputs,
  779. decoder_attention_mask=decoder_attention_mask,
  780. past_key_values=past_key_values,
  781. decoder_inputs_embeds=decoder_inputs_embeds,
  782. decoder_position_ids=decoder_position_ids,
  783. use_cache=use_cache,
  784. **kwargs,
  785. )
  786. logits = self.proj_out(outputs.last_hidden_state)
  787. loss = None
  788. if labels is not None:
  789. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
  790. return Seq2SeqLMOutput(
  791. loss=loss,
  792. logits=logits,
  793. past_key_values=outputs.past_key_values,
  794. decoder_hidden_states=outputs.decoder_hidden_states,
  795. decoder_attentions=outputs.decoder_attentions,
  796. cross_attentions=outputs.cross_attentions,
  797. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  798. encoder_hidden_states=outputs.encoder_hidden_states,
  799. encoder_attentions=outputs.encoder_attentions,
  800. )
# Names exported by `from <module> import *`.
__all__ = ["MoonshineModel", "MoonshinePreTrainedModel", "MoonshineForConditionalGeneration"]