modular_moonshine.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. from dataclasses import dataclass
  16. import torch
  17. import torch.nn as nn
  18. from huggingface_hub.dataclasses import strict
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...generation import GenerationMixin
  23. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPast,
  29. BaseModelOutputWithPastAndCrossAttentions,
  30. Seq2SeqLMOutput,
  31. Seq2SeqModelOutput,
  32. )
  33. from ...modeling_rope_utils import RopeParameters
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  37. from ...utils.generic import merge_with_config_defaults
  38. from ...utils.output_capturing import OutputRecorder, capture_outputs
  39. from ..glm.modeling_glm import GlmAttention, GlmRotaryEmbedding, apply_rotary_pos_emb
  40. from ..llama.modeling_llama import LlamaDecoderLayer, LlamaModel, eager_attention_forward
  41. from ..whisper.modeling_whisper import WhisperModel, shift_tokens_right
# Module-level logger, obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
  43. @auto_docstring(checkpoint="UsefulSensors/moonshine-tiny")
  44. @strict
  45. class MoonshineConfig(PreTrainedConfig):
  46. r"""
  47. encoder_num_key_value_heads (`int`, *optional*):
  48. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  49. `encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
  50. `encoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  51. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  52. by meanpooling all the original heads within that group. For more details, check out [this
  53. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  54. `num_attention_heads`.
  55. decoder_num_key_value_heads (`int`, *optional*):
  56. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  57. `decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
  58. `decoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  59. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  60. by meanpooling all the original heads within that group. For more details, check out [this
  61. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  62. `decoder_num_attention_heads`.
  63. pad_head_dim_to_multiple_of (`int`, *optional*):
  64. Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
  65. optimized attention implementations.
  66. encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
  67. The non-linear activation function (function or string) in the encoder.
  68. decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  69. The non-linear activation function (function or string) in the decoder.
  70. Example:
  71. ```python
  72. >>> from transformers import MoonshineModel, MoonshineConfig
  73. >>> # Initializing a Moonshine style configuration
  74. >>> configuration = MoonshineConfig().from_pretrained("UsefulSensors/moonshine-tiny")
  75. >>> # Initializing a model from the configuration
  76. >>> model = MoonshineModel(configuration)
  77. >>> # Accessing the model configuration
  78. >>> configuration = model.config
  79. ```"""
  80. model_type = "moonshine"
  81. keys_to_ignore_at_inference = ["past_key_values"]
  82. attribute_map = {
  83. "num_key_value_heads": "decoder_num_key_value_heads",
  84. "num_attention_heads": "decoder_num_attention_heads",
  85. "num_hidden_layers": "decoder_num_hidden_layers",
  86. "hidden_act": "decoder_hidden_act",
  87. }
  88. vocab_size: int = 32768
  89. hidden_size: int = 288
  90. intermediate_size: int = 1152
  91. encoder_num_hidden_layers: int = 6
  92. decoder_num_hidden_layers: int = 6
  93. encoder_num_attention_heads: int = 8
  94. decoder_num_attention_heads: int = 8
  95. encoder_num_key_value_heads: int | None = None
  96. decoder_num_key_value_heads: int | None = None
  97. pad_head_dim_to_multiple_of: int | None = None
  98. encoder_hidden_act: str = "gelu"
  99. decoder_hidden_act: str = "silu"
  100. max_position_embeddings: int = 512
  101. initializer_range: float = 0.02
  102. decoder_start_token_id: int = 1
  103. use_cache: bool = True
  104. rope_parameters: RopeParameters | dict | None = None
  105. is_encoder_decoder: bool = True
  106. attention_bias: bool = False
  107. attention_dropout: float | int = 0.0
  108. bos_token_id: int | None = 1
  109. eos_token_id: int | list[int] | None = 2
  110. pad_token_id: int | None = None
  111. tie_word_embeddings: bool = True
  112. def __post_init__(self, **kwargs):
  113. if self.encoder_num_key_value_heads is None:
  114. self.encoder_num_key_value_heads = self.encoder_num_attention_heads
  115. if self.decoder_num_key_value_heads is None:
  116. self.decoder_num_key_value_heads = self.decoder_num_attention_heads
  117. kwargs.setdefault("partial_rotary_factor", 0.9) # assign default for BC
  118. super().__post_init__(**kwargs)
@dataclass
@auto_docstring(
    custom_intro="""
    Extends [~modeling_outputs.BaseModelOutput] to include the output attention mask since sequence length is not preserved in the model's forward.
    """
)
class MoonshineEncoderModelOutput(BaseModelOutput):
    # Encoder attention mask resampled to the encoder's output length (filled in by `MoonshineEncoder.forward`);
    # `None` when no input mask was supplied.
    attention_mask: torch.Tensor | None = None
  127. class MoonshineEncoderMLP(nn.Module):
  128. def __init__(self, config, hidden_act):
  129. super().__init__()
  130. self.config = config
  131. self.activation_fn = ACT2FN[hidden_act]
  132. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  133. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  134. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  135. hidden_states = self.fc1(hidden_states)
  136. hidden_states = self.activation_fn(hidden_states)
  137. hidden_states = self.fc2(hidden_states)
  138. return hidden_states
  139. class MoonshineDecoderMLP(nn.Module):
  140. def __init__(self, config, hidden_act):
  141. super().__init__()
  142. self.config = config
  143. self.activation_fn = ACT2FN[hidden_act]
  144. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size * 2)
  145. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  146. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  147. hidden_states = self.fc1(hidden_states)
  148. hidden_states, gate = hidden_states.chunk(2, dim=-1)
  149. hidden_states = self.activation_fn(gate) * hidden_states
  150. hidden_states = self.fc2(hidden_states)
  151. return hidden_states
  152. class MoonshineRotaryEmbedding(GlmRotaryEmbedding):
  153. pass
  154. class MoonshineAttention(GlmAttention):
  155. def __init__(
  156. self,
  157. config: MoonshineConfig,
  158. layer_idx: int,
  159. is_causal: bool,
  160. num_attention_heads: int,
  161. num_key_value_heads: int,
  162. ):
  163. config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
  164. super().__init__(config, layer_idx)
  165. self.is_causal = is_causal
  166. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  167. # Pad head dimension to the next specified multiple.
  168. if self.config.pad_head_dim_to_multiple_of is not None:
  169. target_multiple = self.config.pad_head_dim_to_multiple_of
  170. target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
  171. self.head_dim_padding = target_head_dim - self.head_dim
  172. else:
  173. self.head_dim_padding = 0
  174. def forward(
  175. self,
  176. hidden_states: torch.Tensor,
  177. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  178. attention_mask: torch.Tensor | None = None,
  179. past_key_values: Cache | None = None,
  180. key_value_states: torch.Tensor | None = None,
  181. **kwargs: Unpack[FlashAttentionKwargs],
  182. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  183. bsz, q_len = hidden_states.shape[:-1]
  184. query_states = (
  185. self.q_proj(hidden_states).view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
  186. )
  187. is_cross_attention = key_value_states is not None
  188. if past_key_values is not None:
  189. is_updated = past_key_values.is_updated.get(self.layer_idx)
  190. if is_cross_attention:
  191. # after the first generated id, we can subsequently re-use all key/value_states from cache
  192. past_key_values.is_updated[self.layer_idx] = True
  193. past_key_values = past_key_values.cross_attention_cache
  194. else:
  195. past_key_values = past_key_values.self_attention_cache
  196. # use key_value_states if cross attention
  197. current_states = key_value_states if key_value_states is not None else hidden_states
  198. if is_cross_attention and past_key_values and is_updated:
  199. key_states = past_key_values.layers[self.layer_idx].keys
  200. value_states = past_key_values.layers[self.layer_idx].values
  201. else:
  202. key_states = (
  203. self.k_proj(current_states)
  204. .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
  205. .transpose(1, 2)
  206. )
  207. value_states = (
  208. self.v_proj(current_states)
  209. .view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
  210. .transpose(1, 2)
  211. )
  212. if is_cross_attention and past_key_values is not None:
  213. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  214. if not is_cross_attention:
  215. cos, sin = position_embeddings
  216. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  217. if past_key_values is not None:
  218. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  219. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  220. self.config._attn_implementation, eager_attention_forward
  221. )
  222. is_causal = self.is_causal and attention_mask is None and q_len > 1
  223. if self.head_dim_padding > 0:
  224. query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
  225. key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
  226. value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
  227. attn_output, attn_weights = attention_interface(
  228. self,
  229. query_states,
  230. key_states,
  231. value_states,
  232. attention_mask,
  233. dropout=0.0 if not self.training else self.attention_dropout,
  234. scaling=self.scaling,
  235. is_causal=is_causal,
  236. **kwargs,
  237. )
  238. if self.head_dim_padding > 0:
  239. attn_output = attn_output[..., : -self.head_dim_padding]
  240. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  241. attn_output = self.o_proj(attn_output)
  242. return attn_output, attn_weights
  243. class MoonshineEncoderLayer(LlamaDecoderLayer):
  244. def __init__(self, config: MoonshineConfig, layer_idx: int):
  245. super().__init__(config, layer_idx)
  246. self.self_attn = MoonshineAttention(
  247. config=config,
  248. layer_idx=layer_idx,
  249. is_causal=False,
  250. num_attention_heads=config.encoder_num_attention_heads,
  251. num_key_value_heads=config.encoder_num_key_value_heads,
  252. )
  253. self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act)
  254. self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  255. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  256. class MoonshineDecoderLayer(GradientCheckpointingLayer):
  257. def __init__(self, config: MoonshineConfig, layer_idx: int | None = None):
  258. super().__init__()
  259. self.hidden_size = config.hidden_size
  260. self.self_attn = MoonshineAttention(
  261. config=config,
  262. layer_idx=layer_idx,
  263. is_causal=True,
  264. num_attention_heads=config.num_attention_heads,
  265. num_key_value_heads=config.num_key_value_heads,
  266. )
  267. self.encoder_attn = MoonshineAttention(
  268. config=config,
  269. layer_idx=layer_idx,
  270. is_causal=False,
  271. num_attention_heads=config.num_attention_heads,
  272. num_key_value_heads=config.num_key_value_heads,
  273. )
  274. self.mlp = MoonshineDecoderMLP(config, config.hidden_act)
  275. self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  276. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  277. self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
  278. def forward(
  279. self,
  280. hidden_states: torch.Tensor,
  281. attention_mask: torch.Tensor | None = None,
  282. encoder_hidden_states: torch.Tensor | None = None,
  283. encoder_attention_mask: torch.Tensor | None = None,
  284. position_ids: torch.LongTensor | None = None,
  285. encoder_position_ids: torch.LongTensor | None = None,
  286. past_key_values: Cache | None = None,
  287. use_cache: bool | None = False,
  288. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  289. encoder_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  290. **kwargs: Unpack[TransformersKwargs],
  291. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  292. residual = hidden_states
  293. hidden_states = self.input_layernorm(hidden_states)
  294. hidden_states, _ = self.self_attn(
  295. hidden_states=hidden_states,
  296. attention_mask=attention_mask,
  297. position_ids=position_ids,
  298. past_key_values=past_key_values,
  299. use_cache=use_cache,
  300. position_embeddings=position_embeddings,
  301. **kwargs,
  302. )
  303. hidden_states = residual + hidden_states
  304. if encoder_hidden_states is not None:
  305. residual = hidden_states
  306. hidden_states = self.post_attention_layernorm(hidden_states)
  307. hidden_states, _ = self.encoder_attn(
  308. hidden_states=hidden_states,
  309. key_value_states=encoder_hidden_states,
  310. attention_mask=encoder_attention_mask,
  311. past_key_values=past_key_values,
  312. use_cache=use_cache,
  313. )
  314. hidden_states = residual + hidden_states
  315. residual = hidden_states
  316. hidden_states = self.final_layernorm(hidden_states)
  317. hidden_states = self.mlp(hidden_states)
  318. hidden_states = residual + hidden_states
  319. return hidden_states
  320. @auto_docstring
  321. class MoonshinePreTrainedModel(PreTrainedModel):
  322. config: MoonshineConfig
  323. base_model_prefix = "model"
  324. main_input_name = "input_values"
  325. input_modalities = "audio"
  326. supports_gradient_checkpointing = True
  327. _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
  328. _supports_flash_attn = True
  329. _supports_sdpa = True
  330. _can_compile_fullgraph = True
  331. # TODO arthur, how do we separate when it cross / self coming from different layer?
  332. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
  333. """
  334. Computes the output length of the convolutional layers
  335. """
  336. output_conv1_length = int((input_lengths - 127) / 64 + 1)
  337. output_conv2_length = int((output_conv1_length - 7) / 3 + 1)
  338. output_conv3_length = int((output_conv2_length - 3) / 2 + 1)
  339. return output_conv3_length
  340. class MoonshineEncoder(MoonshinePreTrainedModel):
  341. """
  342. Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`]
  343. Args:
  344. config: MoonshineConfig
  345. """
  346. main_input_name = "input_values"
  347. _can_record_outputs = {
  348. "attentions": MoonshineAttention,
  349. "hidden_states": MoonshineEncoderLayer,
  350. }
  351. def __init__(self, config: MoonshineConfig):
  352. super().__init__(config)
  353. self.config = config
  354. embed_dim = config.hidden_size
  355. self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False)
  356. self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3)
  357. self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2)
  358. self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5)
  359. self.layers = nn.ModuleList(
  360. [MoonshineEncoderLayer(config, idx) for idx in range(config.encoder_num_hidden_layers)]
  361. )
  362. self.layer_norm = nn.LayerNorm(embed_dim, bias=False)
  363. self.rotary_emb = MoonshineRotaryEmbedding(config=config)
  364. self.gradient_checkpointing = False
  365. self.post_init()
  366. def get_input_embeddings(self) -> nn.Module:
  367. return self.conv1
  368. def set_input_embeddings(self, value: nn.Module):
  369. self.conv1 = value
  370. @merge_with_config_defaults
  371. @capture_outputs
  372. def forward(
  373. self,
  374. input_values: torch.FloatTensor,
  375. attention_mask: torch.Tensor | None = None,
  376. **kwargs: Unpack[TransformersKwargs],
  377. ) -> tuple | BaseModelOutputWithPast:
  378. r"""
  379. Args:
  380. input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
  381. Float values of the raw speech waveform. Raw speech waveform can be
  382. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  383. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  384. the soundfile library (`pip install soundfile`). To prepare the array into
  385. `input_values`, the [`AutoFeatureExtractor`] should be used for padding
  386. and conversion into a tensor of type `torch.FloatTensor`.
  387. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  388. Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
  389. - 1 for tokens that are **not masked**,
  390. - 0 for tokens that are **masked**.
  391. [What are attention masks?](../glossary#attention-mask)
  392. """
  393. input_values = input_values.unsqueeze(1)
  394. hidden_states = nn.functional.tanh(self.conv1(input_values))
  395. hidden_states = self.groupnorm(hidden_states)
  396. hidden_states = nn.functional.gelu(self.conv2(hidden_states))
  397. hidden_states = nn.functional.gelu(self.conv3(hidden_states))
  398. hidden_states = hidden_states.permute(0, 2, 1)
  399. # attention mask downsampling
  400. output_attention_mask = None
  401. if attention_mask is not None:
  402. mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
  403. downsample_stride = 64 * 3 * 2 # conv strides
  404. attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
  405. output_attention_mask = attention_mask
  406. attention_mask = create_bidirectional_mask(
  407. config=self.config,
  408. inputs_embeds=hidden_states,
  409. attention_mask=attention_mask,
  410. encoder_hidden_states=hidden_states,
  411. )
  412. position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
  413. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  414. for encoder_layer in self.layers:
  415. hidden_states = encoder_layer(
  416. hidden_states,
  417. attention_mask=attention_mask,
  418. position_ids=position_ids,
  419. position_embeddings=position_embeddings,
  420. **kwargs,
  421. )
  422. hidden_states = self.layer_norm(hidden_states)
  423. return MoonshineEncoderModelOutput(
  424. last_hidden_state=hidden_states,
  425. attention_mask=output_attention_mask.int() if output_attention_mask is not None else None,
  426. )
  427. class MoonshineDecoder(LlamaModel):
  428. main_input_name = "input_ids"
  429. _can_record_outputs = {
  430. "attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="self_attn"),
  431. "hidden_states": MoonshineDecoderLayer,
  432. "cross_attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="encoder_attn"),
  433. }
  434. def __init__(self, config: MoonshineConfig):
  435. super().__init__(config)
  436. self.norm = nn.LayerNorm(config.hidden_size, bias=False)
  437. self.layers = nn.ModuleList([MoonshineDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)])
  438. @merge_with_config_defaults
  439. @capture_outputs
  440. def forward(
  441. self,
  442. input_ids: torch.LongTensor | None = None,
  443. attention_mask: torch.Tensor | None = None,
  444. position_ids: torch.LongTensor | None = None,
  445. past_key_values: Cache | None = None,
  446. inputs_embeds: torch.FloatTensor | None = None,
  447. use_cache: bool | None = None,
  448. encoder_hidden_states: torch.FloatTensor | None = None,
  449. encoder_attention_mask: torch.Tensor | None = None,
  450. **kwargs: Unpack[TransformersKwargs],
  451. ) -> tuple | BaseModelOutputWithPast:
  452. r"""
  453. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  454. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  455. of the decoder.
  456. encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  457. Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
  458. - 1 for tokens that are **not masked**,
  459. - 0 for tokens that are **masked**.
  460. [What are attention masks?](../glossary#attention-mask)
  461. """
  462. if (input_ids is None) ^ (inputs_embeds is not None):
  463. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  464. if inputs_embeds is None:
  465. inputs_embeds = self.embed_tokens(input_ids)
  466. if use_cache and past_key_values is None:
  467. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  468. if position_ids is None:
  469. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  470. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  471. position_ids = position_ids.unsqueeze(0)
  472. causal_mask = create_causal_mask(
  473. config=self.config,
  474. inputs_embeds=inputs_embeds,
  475. attention_mask=attention_mask,
  476. past_key_values=past_key_values,
  477. position_ids=position_ids,
  478. )
  479. encoder_attention_mask = create_bidirectional_mask(
  480. config=self.config,
  481. inputs_embeds=inputs_embeds,
  482. attention_mask=encoder_attention_mask,
  483. encoder_hidden_states=encoder_hidden_states,
  484. )
  485. hidden_states = inputs_embeds
  486. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  487. for decoder_layer in self.layers:
  488. hidden_states = decoder_layer(
  489. hidden_states,
  490. causal_mask,
  491. encoder_hidden_states, # as a positional argument for gradient checkpointing
  492. encoder_attention_mask=encoder_attention_mask,
  493. position_ids=position_ids,
  494. past_key_values=past_key_values,
  495. use_cache=use_cache,
  496. position_embeddings=position_embeddings,
  497. **kwargs,
  498. )
  499. hidden_states = self.norm(hidden_states)
  500. return BaseModelOutputWithPastAndCrossAttentions(
  501. last_hidden_state=hidden_states,
  502. past_key_values=past_key_values if use_cache else None,
  503. )
  504. class MoonshineModel(WhisperModel):
  505. def _mask_input_features(self):
  506. raise AttributeError("Not needed for Moonshine")
  507. @can_return_tuple
  508. @auto_docstring
  509. def forward(
  510. self,
  511. input_values: torch.FloatTensor | None = None,
  512. attention_mask: torch.LongTensor | None = None,
  513. decoder_input_ids: torch.LongTensor | None = None,
  514. decoder_attention_mask: torch.LongTensor | None = None,
  515. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  516. past_key_values: EncoderDecoderCache | None = None,
  517. decoder_inputs_embeds: tuple[torch.FloatTensor] | None = None,
  518. decoder_position_ids: tuple[torch.LongTensor] | None = None,
  519. use_cache: bool | None = None,
  520. **kwargs: Unpack[TransformersKwargs],
  521. ) -> Seq2SeqModelOutput:
  522. r"""
  523. input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
  524. Float values of the raw speech waveform. Raw speech waveform can be
  525. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  526. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  527. the soundfile library (`pip install soundfile`). To prepare the array into
  528. `input_values`, the [`AutoFeatureExtractor`] should be used for padding
  529. and conversion into a tensor of type `torch.FloatTensor`.
  530. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  531. Indices of positions of each input sequence tokens in the position embeddings.
  532. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`
  533. Example:
  534. ```python
  535. >>> import torch
  536. >>> from transformers import AutoFeatureExtractor, MoonshineModel
  537. >>> from datasets import load_dataset
  538. >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
  539. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
  540. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  541. >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
  542. >>> input_values = inputs.input_values
  543. >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
  544. >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
  545. >>> list(last_hidden_state.shape)
  546. [1, 2, 288]
  547. ```
  548. """
  549. if encoder_outputs is None:
  550. encoder_outputs: BaseModelOutput = self.encoder(input_values, attention_mask=attention_mask, **kwargs)
  551. decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
  552. input_ids=decoder_input_ids,
  553. attention_mask=decoder_attention_mask,
  554. encoder_hidden_states=encoder_outputs.last_hidden_state,
  555. encoder_attention_mask=encoder_outputs.attention_mask,
  556. past_key_values=past_key_values,
  557. inputs_embeds=decoder_inputs_embeds,
  558. position_ids=decoder_position_ids,
  559. use_cache=use_cache,
  560. **kwargs,
  561. )
  562. return Seq2SeqModelOutput(
  563. last_hidden_state=decoder_outputs.last_hidden_state,
  564. past_key_values=decoder_outputs.past_key_values,
  565. decoder_hidden_states=decoder_outputs.hidden_states,
  566. decoder_attentions=decoder_outputs.attentions,
  567. cross_attentions=decoder_outputs.cross_attentions,
  568. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  569. encoder_hidden_states=encoder_outputs.hidden_states,
  570. encoder_attentions=encoder_outputs.attentions,
  571. )
  572. @auto_docstring(
  573. custom_intro="""
  574. The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.
  575. """
  576. )
  577. class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):
  578. _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"}
  579. def __init__(self, config: MoonshineConfig):
  580. super().__init__(config)
  581. self.model = MoonshineModel(config)
  582. self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  583. # Initialize weights and apply final processing
  584. self.post_init()
  585. def get_output_embeddings(self):
  586. return self.proj_out
  587. def set_output_embeddings(self, new_embeddings):
  588. self.proj_out = new_embeddings
  589. def get_input_embeddings(self) -> nn.Module:
  590. return self.model.get_input_embeddings()
  591. @can_return_tuple
  592. @auto_docstring
  593. def forward(
  594. self,
  595. input_values: torch.FloatTensor | None = None,
  596. attention_mask: torch.LongTensor | None = None,
  597. decoder_input_ids: torch.LongTensor | None = None,
  598. decoder_attention_mask: torch.LongTensor | None = None,
  599. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  600. past_key_values: EncoderDecoderCache | None = None,
  601. decoder_inputs_embeds: tuple[torch.FloatTensor] | None = None,
  602. decoder_position_ids: tuple[torch.LongTensor] | None = None,
  603. use_cache: bool | None = None,
  604. labels: torch.LongTensor | None = None,
  605. **kwargs: Unpack[TransformersKwargs],
  606. ) -> Seq2SeqLMOutput:
  607. r"""
  608. input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
  609. Float values of the raw speech waveform. Raw speech waveform can be
  610. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  611. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  612. the soundfile library (`pip install soundfile`). To prepare the array into
  613. `input_values`, the [`AutoFeatureExtractor`] should be used for padding
  614. and conversion into a tensor of type `torch.FloatTensor`.
  615. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  616. Indices of positions of each input sequence tokens in the position embeddings.
  617. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`
  618. Example:
  619. ```python
  620. >>> import torch
  621. >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
  622. >>> from datasets import load_dataset
  623. >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
  624. >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
  625. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  626. >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
  627. >>> input_values = inputs.input_values
  628. >>> generated_ids = model.generate(input_values, max_new_tokens=100)
  629. >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  630. >>> transcription
  631. 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
  632. ```"""
  633. if labels is not None:
  634. if decoder_input_ids is None and decoder_inputs_embeds is None:
  635. decoder_input_ids = shift_tokens_right(
  636. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  637. )
  638. outputs: Seq2SeqModelOutput = self.model(
  639. input_values,
  640. attention_mask=attention_mask,
  641. decoder_input_ids=decoder_input_ids,
  642. encoder_outputs=encoder_outputs,
  643. decoder_attention_mask=decoder_attention_mask,
  644. past_key_values=past_key_values,
  645. decoder_inputs_embeds=decoder_inputs_embeds,
  646. decoder_position_ids=decoder_position_ids,
  647. use_cache=use_cache,
  648. **kwargs,
  649. )
  650. logits = self.proj_out(outputs.last_hidden_state)
  651. loss = None
  652. if labels is not None:
  653. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
  654. return Seq2SeqLMOutput(
  655. loss=loss,
  656. logits=logits,
  657. past_key_values=outputs.past_key_values,
  658. decoder_hidden_states=outputs.decoder_hidden_states,
  659. decoder_attentions=outputs.decoder_attentions,
  660. cross_attentions=outputs.cross_attentions,
  661. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  662. encoder_hidden_states=outputs.encoder_hidden_states,
  663. encoder_attentions=outputs.encoder_attentions,
  664. )
# Public names exported by this module.
__all__ = [
    "MoonshineConfig",
    "MoonshineModel",
    "MoonshinePreTrainedModel",
    "MoonshineForConditionalGeneration",
]