modeling_glmasr.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glmasr/modular_glmasr.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glmasr.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 the HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache
  24. from ...generation import GenerationMixin
  25. from ...integrations import use_kernelized_func
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
  28. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, is_torch_available
  32. from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
  33. from ...utils.output_capturing import capture_outputs
  34. from ..auto import AutoModel, AutoModelForCausalLM
  35. from .configuration_glmasr import GlmAsrConfig, GlmAsrEncoderConfig
  36. if is_torch_available():
  37. import torch
  38. from torch import nn
  39. class GlmAsrRotaryEmbedding(nn.Module):
  40. inv_freq: torch.Tensor # fix linting for `register_buffer`
  41. def __init__(self, config: GlmAsrConfig, device=None):
  42. super().__init__()
  43. self.max_seq_len_cached = config.max_position_embeddings
  44. self.original_max_seq_len = config.max_position_embeddings
  45. self.config = config
  46. self.rope_type = self.config.rope_parameters["rope_type"]
  47. rope_init_fn: Callable = self.compute_default_rope_parameters
  48. if self.rope_type != "default":
  49. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  50. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  51. self.register_buffer("inv_freq", inv_freq, persistent=False)
  52. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  53. @staticmethod
  54. def compute_default_rope_parameters(
  55. config: GlmAsrConfig | None = None,
  56. device: Optional["torch.device"] = None,
  57. seq_len: int | None = None,
  58. ) -> tuple["torch.Tensor", float]:
  59. """
  60. Computes the inverse frequencies according to the original RoPE implementation
  61. Args:
  62. config ([`~transformers.PreTrainedConfig`]):
  63. The model configuration.
  64. device (`torch.device`):
  65. The device to use for initialization of the inverse frequencies.
  66. seq_len (`int`, *optional*):
  67. The current sequence length. Unused for this type of RoPE.
  68. Returns:
  69. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  70. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  71. """
  72. base = config.rope_parameters["rope_theta"]
  73. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  74. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  75. dim = int(head_dim * partial_rotary_factor)
  76. attention_factor = 1.0 # Unused in this type of RoPE
  77. # Compute the inverse frequencies
  78. inv_freq = 1.0 / (
  79. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  80. )
  81. return inv_freq, attention_factor
  82. @torch.no_grad()
  83. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  84. def forward(self, x, position_ids):
  85. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  86. position_ids_expanded = position_ids[:, None, :].float()
  87. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  88. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  89. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  90. emb = torch.cat((freqs, freqs), dim=-1)
  91. cos = emb.cos() * self.attention_scaling
  92. sin = emb.sin() * self.attention_scaling
  93. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  94. def rotate_half(x):
  95. """Rotates half the hidden dims of the input."""
  96. x1 = x[..., : x.shape[-1] // 2]
  97. x2 = x[..., x.shape[-1] // 2 :]
  98. return torch.cat((-x2, x1), dim=-1)
  99. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  100. """
  101. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  102. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  103. """
  104. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  105. if n_rep == 1:
  106. return hidden_states
  107. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  108. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  109. def eager_attention_forward(
  110. module: nn.Module,
  111. query: torch.Tensor,
  112. key: torch.Tensor,
  113. value: torch.Tensor,
  114. attention_mask: torch.Tensor | None,
  115. scaling: float,
  116. dropout: float = 0.0,
  117. **kwargs: Unpack[TransformersKwargs],
  118. ):
  119. key_states = repeat_kv(key, module.num_key_value_groups)
  120. value_states = repeat_kv(value, module.num_key_value_groups)
  121. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  122. if attention_mask is not None:
  123. attn_weights = attn_weights + attention_mask
  124. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  125. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  126. attn_output = torch.matmul(attn_weights, value_states)
  127. attn_output = attn_output.transpose(1, 2).contiguous()
  128. return attn_output, attn_weights
  129. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  130. cos = cos.unsqueeze(unsqueeze_dim)
  131. sin = sin.unsqueeze(unsqueeze_dim)
  132. rotary_dim = cos.shape[-1]
  133. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  134. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  135. # Apply rotary embeddings on the first half or full tensor
  136. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  137. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  138. # Concatenate back to full shape
  139. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  140. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  141. return q_embed, k_embed
  142. @use_kernelized_func(apply_rotary_pos_emb)
  143. class GlmAsrAttention(nn.Module):
  144. """Multi-headed attention from 'Attention Is All You Need' paper"""
  145. def __init__(self, config: GlmAsrConfig, layer_idx: int):
  146. super().__init__()
  147. self.config = config
  148. self.layer_idx = layer_idx
  149. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  150. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  151. self.scaling = self.head_dim**-0.5
  152. self.attention_dropout = config.attention_dropout
  153. self.is_causal = False
  154. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
  155. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  156. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  157. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
  158. def forward(
  159. self,
  160. hidden_states: torch.Tensor,
  161. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  162. **kwargs: Unpack[TransformersKwargs],
  163. ) -> tuple[torch.Tensor, torch.Tensor]:
  164. input_shape = hidden_states.shape[:-1]
  165. hidden_shape = (*input_shape, -1, self.head_dim)
  166. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  167. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  168. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  169. cos, sin = position_embeddings
  170. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  171. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  172. self.config._attn_implementation, eager_attention_forward
  173. )
  174. attn_output, attn_weights = attention_interface(
  175. self,
  176. query_states,
  177. key_states,
  178. value_states,
  179. attention_mask=None,
  180. dropout=0.0 if not self.training else self.attention_dropout,
  181. scaling=self.scaling,
  182. **kwargs,
  183. )
  184. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  185. attn_output = self.o_proj(attn_output)
  186. return attn_output, attn_weights
  187. class GlmAsrMLP(nn.Module):
  188. def __init__(self, config):
  189. super().__init__()
  190. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  191. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  192. self.act_fn = ACT2FN[config.hidden_act]
  193. def forward(self, hidden_states: torch.Tensor):
  194. hidden_states = self.fc1(hidden_states)
  195. hidden_states = self.act_fn(hidden_states)
  196. hidden_states = self.fc2(hidden_states)
  197. return hidden_states
  198. class GlmAsrEncoderLayer(GradientCheckpointingLayer):
  199. def __init__(self, config: GlmAsrConfig, layer_idx: int):
  200. super().__init__()
  201. self.hidden_size = config.hidden_size
  202. self.self_attn = GlmAsrAttention(config=config, layer_idx=layer_idx)
  203. self.mlp = GlmAsrMLP(config)
  204. self.input_layernorm = nn.LayerNorm(config.hidden_size)
  205. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
  206. def forward(
  207. self,
  208. hidden_states: torch.Tensor,
  209. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  210. **kwargs: Unpack[TransformersKwargs],
  211. ) -> torch.Tensor:
  212. residual = hidden_states
  213. hidden_states = self.input_layernorm(hidden_states)
  214. # Self Attention
  215. hidden_states, _ = self.self_attn(
  216. hidden_states=hidden_states,
  217. position_embeddings=position_embeddings,
  218. **kwargs,
  219. )
  220. hidden_states = residual + hidden_states
  221. # Fully Connected
  222. residual = hidden_states
  223. hidden_states = self.post_attention_layernorm(hidden_states)
  224. hidden_states = self.mlp(hidden_states)
  225. hidden_states = residual + hidden_states
  226. return hidden_states
  227. @auto_docstring
  228. class GlmAsrPreTrainedModel(PreTrainedModel):
  229. config: GlmAsrConfig
  230. base_model_prefix = "model"
  231. input_modalities = ("audio", "text")
  232. supports_gradient_checkpointing = True
  233. _no_split_modules = ["GlmAsrAttention"]
  234. _skip_keys_device_placement = "past_key_values"
  235. _supports_flash_attn = True
  236. _supports_sdpa = True
  237. # TODO: @eustlb, this is what WhisperEncoder should look like
  238. class GlmAsrEncoder(GlmAsrPreTrainedModel):
  239. config: GlmAsrEncoderConfig
  240. main_input_name = "input_features"
  241. input_modalities = "audio"
  242. _no_split_modules = ["GlmAsrEncoderLayer"]
  243. _can_record_outputs = {
  244. "hidden_states": GlmAsrEncoderLayer,
  245. "attentions": GlmAsrAttention,
  246. }
  247. def __init__(self, config: GlmAsrEncoderConfig):
  248. super().__init__(config)
  249. self.conv1 = nn.Conv1d(config.num_mel_bins, config.hidden_size, kernel_size=3, padding=1)
  250. self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
  251. self.layers = nn.ModuleList(
  252. [GlmAsrEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  253. )
  254. self.norm = nn.LayerNorm(config.hidden_size)
  255. self.rotary_emb = GlmAsrRotaryEmbedding(config=config)
  256. self.gradient_checkpointing = False
  257. self.post_init()
  258. @merge_with_config_defaults
  259. @capture_outputs
  260. @auto_docstring
  261. def forward(self, input_features, **kwargs: Unpack[TransformersKwargs]):
  262. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  263. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  264. inputs_embeds = inputs_embeds.transpose(1, 2)
  265. hidden_states = inputs_embeds
  266. position_embeddings = self.rotary_emb(
  267. hidden_states, position_ids=torch.arange(hidden_states.shape[1], device=hidden_states.device)[None, :]
  268. )
  269. for encoder_layer in self.layers:
  270. hidden_states = encoder_layer(hidden_states, position_embeddings=position_embeddings, **kwargs)
  271. hidden_states = self.norm(hidden_states)
  272. return BaseModelOutputWithPooling(last_hidden_state=hidden_states)
  273. class GlmAsrMultiModalProjector(nn.Module):
  274. """
  275. Audio adaptor (small MLP) that projects GlmAsrEncoder features
  276. to the LLM embedding space so they can replace `<sound>` tokens.
  277. """
  278. def __init__(self, config: GlmAsrConfig):
  279. super().__init__()
  280. self.linear_1 = nn.Linear(config.audio_config.intermediate_size, config.text_config.hidden_size * 2)
  281. self.act = ACT2FN[config.projector_hidden_act]
  282. self.linear_2 = nn.Linear(config.text_config.hidden_size * 2, config.text_config.hidden_size)
  283. def forward(self, audio_features):
  284. hidden_states = self.linear_1(audio_features)
  285. hidden_states = self.act(hidden_states)
  286. hidden_states = self.linear_2(hidden_states)
  287. return hidden_states
  288. @auto_docstring(
  289. custom_intro="""
  290. The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model.
  291. """
  292. )
  293. class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin):
  294. _keep_in_fp32_modules_strict = None
  295. _tp_plan = None
  296. _pp_plan = None
  297. def __init__(self, config):
  298. super().__init__(config)
  299. self.vocab_size = config.text_config.vocab_size
  300. self.audio_tower = AutoModel.from_config(config.audio_config)
  301. self.language_model = AutoModelForCausalLM.from_config(config.text_config)
  302. self.multi_modal_projector = GlmAsrMultiModalProjector(config)
  303. # Initialize weights and apply final processing
  304. self.post_init()
  305. def get_input_embeddings(self):
  306. return self.language_model.get_input_embeddings()
  307. def set_input_embeddings(self, value):
  308. self.language_model.set_input_embeddings(value)
  309. def get_output_embeddings(self):
  310. return self.language_model.get_output_embeddings()
  311. def set_output_embeddings(self, new_embeddings):
  312. self.language_model.set_output_embeddings(new_embeddings)
  313. def set_decoder(self, decoder):
  314. self.language_model.set_decoder(decoder)
  315. def get_decoder(self):
  316. return self.language_model.get_decoder()
  317. @can_return_tuple
  318. @auto_docstring(
  319. custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector."
  320. )
  321. def get_audio_features(
  322. self,
  323. input_features: torch.FloatTensor,
  324. input_features_mask: torch.Tensor,
  325. **kwargs: Unpack[TransformersKwargs],
  326. ) -> tuple | BaseModelOutputWithPooling:
  327. r"""
  328. input_features (`torch.FloatTensor`):
  329. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  330. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
  331. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  332. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  333. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  334. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
  335. Mask to avoid performing attention on padded feature indices.
  336. """
  337. audio_outputs = self.audio_tower(input_features, return_dict=True, **kwargs)
  338. audio_hidden_states = audio_outputs.last_hidden_state
  339. audio_hidden_states = audio_hidden_states.reshape(
  340. input_features.shape[0], -1, self.config.audio_config.intermediate_size
  341. )
  342. audio_embeds = self.multi_modal_projector(audio_hidden_states)
  343. audio_lengths = input_features_mask.sum(-1)
  344. for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
  345. audio_lengths = (audio_lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1
  346. merge_factor = 4
  347. post_lengths = (audio_lengths - merge_factor) // merge_factor + 1
  348. valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
  349. audio_outputs.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)]
  350. return audio_outputs
  351. @can_return_tuple
  352. @auto_docstring
  353. def forward(
  354. self,
  355. input_ids: torch.LongTensor | None = None,
  356. input_features: torch.FloatTensor | None = None,
  357. input_features_mask: torch.Tensor | None = None,
  358. attention_mask: torch.Tensor | None = None,
  359. position_ids: torch.LongTensor | None = None,
  360. past_key_values: Cache | None = None,
  361. inputs_embeds: torch.FloatTensor | None = None,
  362. labels: torch.LongTensor | None = None,
  363. use_cache: bool | None = None,
  364. logits_to_keep: int | torch.Tensor = 0,
  365. **kwargs: Unpack[TransformersKwargs],
  366. ) -> CausalLMOutputWithPast:
  367. r"""
  368. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
  369. Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
  370. - 1 for tokens that are **not masked**,
  371. - 0 for tokens that are **masked**.
  372. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  373. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  374. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  375. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  376. Example:
  377. ```python
  378. >>> from transformers import GlmAsrForConditionalGeneration, AutoProcessor
  379. >>> model_id = "zai-org/GLM-ASR-Nano-2512"
  380. >>> processor = AutoProcessor.from_pretrained(model_id)
  381. >>> model = GlmAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto")
  382. >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
  383. >>> inputs = inputs.to(model.device, dtype=model.dtype)
  384. >>> outputs = model.generate(**inputs, do_sample=False, max_new_tokens=500)
  385. >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
  386. >>> print(decoded_outputs)
  387. ```"""
  388. if inputs_embeds is None:
  389. inputs_embeds = self.get_input_embeddings()(input_ids)
  390. if input_features is not None and input_ids is not None:
  391. audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output
  392. # replace text-audio token placeholders with audio embeddings
  393. audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
  394. inputs_embeds = inputs_embeds.masked_scatter(
  395. audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
  396. )
  397. outputs: CausalLMOutputWithPast = self.language_model(
  398. inputs_embeds=inputs_embeds,
  399. attention_mask=attention_mask,
  400. position_ids=position_ids,
  401. past_key_values=past_key_values,
  402. labels=labels,
  403. use_cache=use_cache,
  404. logits_to_keep=logits_to_keep,
  405. **kwargs,
  406. )
  407. return outputs
  408. def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
  409. input_features = kwargs.pop("input_features", None)
  410. input_features_mask = kwargs.pop("input_features_mask", None)
  411. model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
  412. if is_first_iteration or not model_inputs.get("use_cache", False):
  413. if input_features is not None:
  414. model_inputs["input_features"] = input_features
  415. if input_features_mask is not None:
  416. model_inputs["input_features_mask"] = input_features_mask
  417. return model_inputs
  418. __all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrPreTrainedModel"]