modeling_vaultgemma.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/vaultgemma/modular_vaultgemma.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_vaultgemma.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 the HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. import torch.nn as nn
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  29. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  33. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  37. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  38. from ...utils.output_capturing import capture_outputs
  39. from .configuration_vaultgemma import VaultGemmaConfig
  40. class VaultGemmaRMSNorm(nn.Module):
  41. def __init__(self, dim: int, eps: float = 1e-6):
  42. super().__init__()
  43. self.eps = eps
  44. self.weight = nn.Parameter(torch.zeros(dim))
  45. def _norm(self, x):
  46. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  47. def forward(self, x):
  48. output = self._norm(x.float())
  49. # Llama does x.to(float16) * w whilst VaultGemma is (x * w).to(float16)
  50. # See https://github.com/huggingface/transformers/pull/29402
  51. output = output * (1.0 + self.weight.float())
  52. return output.type_as(x)
  53. def extra_repr(self):
  54. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  55. class VaultGemmaMLP(nn.Module):
  56. def __init__(self, config):
  57. super().__init__()
  58. self.config = config
  59. self.hidden_size = config.hidden_size
  60. self.intermediate_size = config.intermediate_size
  61. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  62. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  63. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  64. self.act_fn = ACT2FN[config.hidden_activation]
  65. def forward(self, x):
  66. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  67. return down_proj
  68. def rotate_half(x):
  69. """Rotates half the hidden dims of the input."""
  70. x1 = x[..., : x.shape[-1] // 2]
  71. x2 = x[..., x.shape[-1] // 2 :]
  72. return torch.cat((-x2, x1), dim=-1)
  73. @use_kernel_func_from_hub("rotary_pos_emb")
  74. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  75. """Applies Rotary Position Embedding to the query and key tensors.
  76. Args:
  77. q (`torch.Tensor`): The query tensor.
  78. k (`torch.Tensor`): The key tensor.
  79. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  80. sin (`torch.Tensor`): The sine part of the rotary embedding.
  81. unsqueeze_dim (`int`, *optional*, defaults to 1):
  82. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  83. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  84. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  85. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  86. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  87. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  88. Returns:
  89. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  90. """
  91. cos = cos.unsqueeze(unsqueeze_dim)
  92. sin = sin.unsqueeze(unsqueeze_dim)
  93. q_embed = (q * cos) + (rotate_half(q) * sin)
  94. k_embed = (k * cos) + (rotate_half(k) * sin)
  95. return q_embed, k_embed
  96. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  97. """
  98. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  99. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  100. """
  101. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  102. if n_rep == 1:
  103. return hidden_states
  104. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  105. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  106. def eager_attention_forward(
  107. module: nn.Module,
  108. query: torch.Tensor,
  109. key: torch.Tensor,
  110. value: torch.Tensor,
  111. attention_mask: torch.Tensor | None,
  112. dropout: float | int = 0.0,
  113. scaling: float | None = None,
  114. softcap: float | None = None,
  115. **kwargs,
  116. ) -> tuple[torch.Tensor, torch.Tensor]:
  117. if scaling is None:
  118. scaling = module.head_dim**-0.5
  119. key_states = repeat_kv(key, module.num_key_value_groups)
  120. value_states = repeat_kv(value, module.num_key_value_groups)
  121. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  122. if softcap is not None:
  123. attn_weights = attn_weights / softcap
  124. attn_weights = torch.tanh(attn_weights)
  125. attn_weights = attn_weights * softcap
  126. if attention_mask is not None:
  127. attn_weights = attn_weights + attention_mask
  128. # upcast attention to fp32
  129. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  130. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  131. attn_output = torch.matmul(attn_weights, value_states)
  132. attn_output = attn_output.transpose(1, 2).contiguous()
  133. return attn_output, attn_weights
  134. @use_kernelized_func(apply_rotary_pos_emb)
  135. class VaultGemmaAttention(nn.Module):
  136. """Multi-headed attention from 'Attention Is All You Need' paper"""
  137. def __init__(self, config: VaultGemmaConfig, layer_idx: int):
  138. super().__init__()
  139. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  140. self.config = config
  141. self.layer_idx = layer_idx
  142. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  143. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  144. self.scaling = config.query_pre_attn_scalar**-0.5
  145. self.attention_dropout = self.config.attention_dropout
  146. self.is_causal = True
  147. self.q_proj = nn.Linear(
  148. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  149. )
  150. self.k_proj = nn.Linear(
  151. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  152. )
  153. self.v_proj = nn.Linear(
  154. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  155. )
  156. self.o_proj = nn.Linear(
  157. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  158. )
  159. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  160. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  161. def forward(
  162. self,
  163. hidden_states: torch.Tensor,
  164. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  165. attention_mask: torch.Tensor | None = None,
  166. past_key_values: Cache | None = None,
  167. **kwargs: Unpack[FlashAttentionKwargs],
  168. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  169. input_shape = hidden_states.shape[:-1]
  170. hidden_shape = (*input_shape, -1, self.head_dim)
  171. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  172. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  173. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  174. cos, sin = position_embeddings
  175. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  176. if past_key_values is not None:
  177. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  178. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  179. self.config._attn_implementation, eager_attention_forward
  180. )
  181. attn_output, attn_weights = attention_interface(
  182. self,
  183. query_states,
  184. key_states,
  185. value_states,
  186. attention_mask,
  187. dropout=self.attention_dropout if self.training else 0.0,
  188. scaling=self.scaling,
  189. sliding_window=self.sliding_window,
  190. softcap=self.attn_logit_softcapping,
  191. **kwargs,
  192. )
  193. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  194. attn_output = self.o_proj(attn_output)
  195. return attn_output, attn_weights
  196. class VaultGemmaDecoderLayer(GradientCheckpointingLayer):
  197. def __init__(self, config: VaultGemmaConfig, layer_idx: int):
  198. super().__init__()
  199. self.hidden_size = config.hidden_size
  200. self.config = config
  201. self.self_attn = VaultGemmaAttention(config=config, layer_idx=layer_idx)
  202. self.mlp = VaultGemmaMLP(config)
  203. self.input_layernorm = VaultGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  204. self.pre_feedforward_layernorm = VaultGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  205. def forward(
  206. self,
  207. hidden_states: torch.Tensor,
  208. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  209. attention_mask: torch.Tensor | None = None,
  210. position_ids: torch.LongTensor | None = None,
  211. past_key_values: Cache | None = None,
  212. **kwargs,
  213. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  214. residual = hidden_states
  215. hidden_states = self.input_layernorm(hidden_states)
  216. # Self Attention
  217. hidden_states, _ = self.self_attn(
  218. hidden_states=hidden_states,
  219. position_embeddings=position_embeddings,
  220. attention_mask=attention_mask,
  221. position_ids=position_ids,
  222. past_key_values=past_key_values,
  223. **kwargs,
  224. )
  225. hidden_states = residual + hidden_states
  226. residual = hidden_states
  227. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  228. hidden_states = self.mlp(hidden_states)
  229. hidden_states = residual + hidden_states
  230. return hidden_states
  231. class VaultGemmaRotaryEmbedding(nn.Module):
  232. inv_freq: torch.Tensor # fix linting for `register_buffer`
  233. def __init__(self, config: VaultGemmaConfig, device=None):
  234. super().__init__()
  235. self.max_seq_len_cached = config.max_position_embeddings
  236. self.original_max_seq_len = config.max_position_embeddings
  237. self.config = config
  238. self.rope_type = self.config.rope_parameters["rope_type"]
  239. rope_init_fn: Callable = self.compute_default_rope_parameters
  240. if self.rope_type != "default":
  241. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  242. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  243. self.register_buffer("inv_freq", inv_freq, persistent=False)
  244. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  245. @staticmethod
  246. def compute_default_rope_parameters(
  247. config: VaultGemmaConfig | None = None,
  248. device: Optional["torch.device"] = None,
  249. seq_len: int | None = None,
  250. ) -> tuple["torch.Tensor", float]:
  251. """
  252. Computes the inverse frequencies according to the original RoPE implementation
  253. Args:
  254. config ([`~transformers.PreTrainedConfig`]):
  255. The model configuration.
  256. device (`torch.device`):
  257. The device to use for initialization of the inverse frequencies.
  258. seq_len (`int`, *optional*):
  259. The current sequence length. Unused for this type of RoPE.
  260. Returns:
  261. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  262. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  263. """
  264. base = config.rope_parameters["rope_theta"]
  265. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  266. attention_factor = 1.0 # Unused in this type of RoPE
  267. # Compute the inverse frequencies
  268. inv_freq = 1.0 / (
  269. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  270. )
  271. return inv_freq, attention_factor
  272. @torch.no_grad()
  273. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  274. def forward(self, x, position_ids):
  275. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  276. position_ids_expanded = position_ids[:, None, :].float()
  277. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  278. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  279. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  280. emb = torch.cat((freqs, freqs), dim=-1)
  281. cos = emb.cos() * self.attention_scaling
  282. sin = emb.sin() * self.attention_scaling
  283. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  284. class VaultGemmaTextScaledWordEmbedding(nn.Embedding):
  285. """
  286. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  287. """
  288. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  289. super().__init__(num_embeddings, embedding_dim, padding_idx)
  290. self.scalar_embed_scale = embed_scale
  291. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  292. def forward(self, input_ids: torch.Tensor):
  293. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  294. @auto_docstring
  295. class VaultGemmaPreTrainedModel(PreTrainedModel):
  296. config: VaultGemmaConfig
  297. base_model_prefix = "model"
  298. supports_gradient_checkpointing = True
  299. _no_split_modules = ["VaultGemmaDecoderLayer"]
  300. _skip_keys_device_placement = ["past_key_values"]
  301. _supports_flash_attn = True
  302. _supports_sdpa = True
  303. _supports_flex_attn = True
  304. _can_compile_fullgraph = True
  305. _supports_attention_backend = True
  306. _can_record_outputs = {
  307. "hidden_states": VaultGemmaDecoderLayer,
  308. "attentions": VaultGemmaAttention,
  309. }
  310. @torch.no_grad()
  311. def _init_weights(self, module):
  312. super()._init_weights(module)
  313. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  314. if "RMSNorm" in module.__class__.__name__:
  315. init.zeros_(module.weight)
  316. elif isinstance(module, VaultGemmaTextScaledWordEmbedding):
  317. init.constant_(module.embed_scale, module.scalar_embed_scale)
  318. @auto_docstring
  319. class VaultGemmaModel(VaultGemmaPreTrainedModel):
  320. def __init__(self, config: VaultGemmaConfig):
  321. super().__init__(config)
  322. self.padding_idx = config.pad_token_id
  323. self.vocab_size = config.vocab_size
  324. # VaultGemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  325. self.embed_tokens = VaultGemmaTextScaledWordEmbedding(
  326. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  327. )
  328. self.layers = nn.ModuleList(
  329. [VaultGemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  330. )
  331. self.norm = VaultGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  332. self.rotary_emb = VaultGemmaRotaryEmbedding(config)
  333. self.gradient_checkpointing = False
  334. # Initialize weights and apply final processing
  335. self.post_init()
  336. @merge_with_config_defaults
  337. @capture_outputs
  338. @auto_docstring
  339. def forward(
  340. self,
  341. input_ids: torch.LongTensor | None = None,
  342. attention_mask: torch.Tensor | None = None,
  343. position_ids: torch.LongTensor | None = None,
  344. past_key_values: Cache | None = None,
  345. inputs_embeds: torch.FloatTensor | None = None,
  346. use_cache: bool | None = None,
  347. **kwargs: Unpack[TransformersKwargs],
  348. ) -> BaseModelOutputWithPast:
  349. if (input_ids is None) ^ (inputs_embeds is not None):
  350. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  351. if inputs_embeds is None:
  352. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  353. if use_cache and past_key_values is None:
  354. past_key_values = DynamicCache(config=self.config)
  355. if position_ids is None:
  356. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  357. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  358. position_ids = position_ids.unsqueeze(0)
  359. # It may already have been prepared by e.g. `generate`
  360. if not isinstance(causal_mask_mapping := attention_mask, dict):
  361. # Prepare mask arguments
  362. mask_kwargs = {
  363. "config": self.config,
  364. "inputs_embeds": inputs_embeds,
  365. "attention_mask": attention_mask,
  366. "past_key_values": past_key_values,
  367. "position_ids": position_ids,
  368. }
  369. # Create the masks
  370. causal_mask_mapping = {
  371. "full_attention": create_causal_mask(**mask_kwargs),
  372. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  373. }
  374. # embed positions
  375. hidden_states = inputs_embeds
  376. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  377. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  378. hidden_states = decoder_layer(
  379. hidden_states,
  380. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  381. position_embeddings=position_embeddings,
  382. position_ids=position_ids,
  383. past_key_values=past_key_values,
  384. **kwargs,
  385. )
  386. hidden_states = self.norm(hidden_states)
  387. return BaseModelOutputWithPast(
  388. last_hidden_state=hidden_states,
  389. past_key_values=past_key_values,
  390. )
  391. @auto_docstring
  392. class VaultGemmaForCausalLM(VaultGemmaPreTrainedModel, GenerationMixin):
  393. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  394. _tp_plan = {"lm_head": "colwise_gather_output"}
  395. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  396. def __init__(self, config):
  397. super().__init__(config)
  398. self.model = VaultGemmaModel(config)
  399. self.vocab_size = config.vocab_size
  400. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  401. # Initialize weights and apply final processing
  402. self.post_init()
  403. @can_return_tuple
  404. @auto_docstring
  405. def forward(
  406. self,
  407. input_ids: torch.LongTensor | None = None,
  408. attention_mask: torch.Tensor | None = None,
  409. position_ids: torch.LongTensor | None = None,
  410. past_key_values: Cache | None = None,
  411. inputs_embeds: torch.FloatTensor | None = None,
  412. labels: torch.LongTensor | None = None,
  413. use_cache: bool | None = None,
  414. logits_to_keep: int | torch.Tensor = 0,
  415. **kwargs: Unpack[TransformersKwargs],
  416. ) -> CausalLMOutputWithPast:
  417. r"""
  418. Example:
  419. ```python
  420. >>> from transformers import AutoTokenizer, VaultGemmaForCausalLM
  421. >>> model = VaultGemmaForCausalLM.from_pretrained("google/gemma-2-9b")
  422. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  423. >>> prompt = "What is your favorite condiment?"
  424. >>> inputs = tokenizer(prompt, return_tensors="pt")
  425. >>> # Generate
  426. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  427. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  428. "What is your favorite condiment?"
  429. ```"""
  430. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  431. outputs: BaseModelOutputWithPast = self.model(
  432. input_ids=input_ids,
  433. attention_mask=attention_mask,
  434. position_ids=position_ids,
  435. past_key_values=past_key_values,
  436. inputs_embeds=inputs_embeds,
  437. use_cache=use_cache,
  438. **kwargs,
  439. )
  440. hidden_states = outputs.last_hidden_state
  441. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  442. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  443. logits = self.lm_head(hidden_states[:, slice_indices, :])
  444. if self.config.final_logit_softcapping is not None:
  445. logits = logits / self.config.final_logit_softcapping
  446. logits = torch.tanh(logits)
  447. logits = logits * self.config.final_logit_softcapping
  448. loss = None
  449. if labels is not None:
  450. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  451. return CausalLMOutputWithPast(
  452. loss=loss,
  453. logits=logits,
  454. past_key_values=outputs.past_key_values,
  455. hidden_states=outputs.hidden_states,
  456. attentions=outputs.attentions,
  457. )
  458. __all__ = ["VaultGemmaForCausalLM", "VaultGemmaModel", "VaultGemmaPreTrainedModel"]