modeling_gemma2.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. import torch.nn as nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  30. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import (
  33. GenericForSequenceClassification,
  34. GenericForTokenClassification,
  35. GradientCheckpointingLayer,
  36. )
  37. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  38. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  39. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  40. from ...processing_utils import Unpack
  41. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  42. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  43. from ...utils.output_capturing import capture_outputs
  44. from .configuration_gemma2 import Gemma2Config
  45. class Gemma2RMSNorm(nn.Module):
  46. def __init__(self, dim: int, eps: float = 1e-6):
  47. super().__init__()
  48. self.eps = eps
  49. self.weight = nn.Parameter(torch.zeros(dim))
  50. def _norm(self, x):
  51. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  52. def forward(self, x):
  53. output = self._norm(x.float())
  54. # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
  55. # See https://github.com/huggingface/transformers/pull/29402
  56. output = output * (1.0 + self.weight.float())
  57. return output.type_as(x)
  58. def extra_repr(self):
  59. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  60. class Gemma2MLP(nn.Module):
  61. def __init__(self, config):
  62. super().__init__()
  63. self.config = config
  64. self.hidden_size = config.hidden_size
  65. self.intermediate_size = config.intermediate_size
  66. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  67. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  68. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  69. self.act_fn = ACT2FN[config.hidden_activation]
  70. def forward(self, x):
  71. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  72. return down_proj
  73. class Gemma2RotaryEmbedding(nn.Module):
  74. inv_freq: torch.Tensor # fix linting for `register_buffer`
  75. def __init__(self, config: Gemma2Config, device=None):
  76. super().__init__()
  77. self.max_seq_len_cached = config.max_position_embeddings
  78. self.original_max_seq_len = config.max_position_embeddings
  79. self.config = config
  80. self.rope_type = self.config.rope_parameters["rope_type"]
  81. rope_init_fn: Callable = self.compute_default_rope_parameters
  82. if self.rope_type != "default":
  83. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  84. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  85. self.register_buffer("inv_freq", inv_freq, persistent=False)
  86. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  87. @staticmethod
  88. def compute_default_rope_parameters(
  89. config: Gemma2Config | None = None,
  90. device: Optional["torch.device"] = None,
  91. seq_len: int | None = None,
  92. ) -> tuple["torch.Tensor", float]:
  93. """
  94. Computes the inverse frequencies according to the original RoPE implementation
  95. Args:
  96. config ([`~transformers.PreTrainedConfig`]):
  97. The model configuration.
  98. device (`torch.device`):
  99. The device to use for initialization of the inverse frequencies.
  100. seq_len (`int`, *optional*):
  101. The current sequence length. Unused for this type of RoPE.
  102. Returns:
  103. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  104. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  105. """
  106. base = config.rope_parameters["rope_theta"]
  107. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  108. attention_factor = 1.0 # Unused in this type of RoPE
  109. # Compute the inverse frequencies
  110. inv_freq = 1.0 / (
  111. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  112. )
  113. return inv_freq, attention_factor
  114. @torch.no_grad()
  115. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  116. def forward(self, x, position_ids):
  117. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  118. position_ids_expanded = position_ids[:, None, :].float()
  119. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  120. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  121. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  122. emb = torch.cat((freqs, freqs), dim=-1)
  123. cos = emb.cos() * self.attention_scaling
  124. sin = emb.sin() * self.attention_scaling
  125. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  126. def rotate_half(x):
  127. """Rotates half the hidden dims of the input."""
  128. x1 = x[..., : x.shape[-1] // 2]
  129. x2 = x[..., x.shape[-1] // 2 :]
  130. return torch.cat((-x2, x1), dim=-1)
  131. @use_kernel_func_from_hub("rotary_pos_emb")
  132. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  133. """Applies Rotary Position Embedding to the query and key tensors.
  134. Args:
  135. q (`torch.Tensor`): The query tensor.
  136. k (`torch.Tensor`): The key tensor.
  137. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  138. sin (`torch.Tensor`): The sine part of the rotary embedding.
  139. unsqueeze_dim (`int`, *optional*, defaults to 1):
  140. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  141. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  142. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  143. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  144. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  145. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  146. Returns:
  147. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  148. """
  149. cos = cos.unsqueeze(unsqueeze_dim)
  150. sin = sin.unsqueeze(unsqueeze_dim)
  151. q_embed = (q * cos) + (rotate_half(q) * sin)
  152. k_embed = (k * cos) + (rotate_half(k) * sin)
  153. return q_embed, k_embed
  154. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  155. """
  156. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  157. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  158. """
  159. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  160. if n_rep == 1:
  161. return hidden_states
  162. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  163. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  164. def eager_attention_forward(
  165. module: nn.Module,
  166. query: torch.Tensor,
  167. key: torch.Tensor,
  168. value: torch.Tensor,
  169. attention_mask: torch.Tensor | None,
  170. dropout: float | int = 0.0,
  171. scaling: float | None = None,
  172. softcap: float | None = None,
  173. **kwargs,
  174. ) -> tuple[torch.Tensor, torch.Tensor]:
  175. if scaling is None:
  176. scaling = module.head_dim**-0.5
  177. key_states = repeat_kv(key, module.num_key_value_groups)
  178. value_states = repeat_kv(value, module.num_key_value_groups)
  179. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  180. if softcap is not None:
  181. attn_weights = attn_weights / softcap
  182. attn_weights = torch.tanh(attn_weights)
  183. attn_weights = attn_weights * softcap
  184. if attention_mask is not None:
  185. attn_weights = attn_weights + attention_mask
  186. # upcast attention to fp32
  187. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  188. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  189. attn_output = torch.matmul(attn_weights, value_states)
  190. attn_output = attn_output.transpose(1, 2).contiguous()
  191. return attn_output, attn_weights
  192. @use_kernelized_func(apply_rotary_pos_emb)
  193. class Gemma2Attention(nn.Module):
  194. """Multi-headed attention from 'Attention Is All You Need' paper"""
  195. def __init__(self, config: Gemma2Config, layer_idx: int):
  196. super().__init__()
  197. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  198. self.config = config
  199. self.layer_idx = layer_idx
  200. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  201. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  202. self.scaling = config.query_pre_attn_scalar**-0.5
  203. self.attention_dropout = self.config.attention_dropout
  204. self.is_causal = not getattr(config, "use_bidirectional_attention", False)
  205. self.q_proj = nn.Linear(
  206. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  207. )
  208. self.k_proj = nn.Linear(
  209. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  210. )
  211. self.v_proj = nn.Linear(
  212. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  213. )
  214. self.o_proj = nn.Linear(
  215. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  216. )
  217. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  218. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  219. def forward(
  220. self,
  221. hidden_states: torch.Tensor,
  222. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  223. attention_mask: torch.Tensor | None = None,
  224. past_key_values: Cache | None = None,
  225. **kwargs: Unpack[FlashAttentionKwargs],
  226. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  227. input_shape = hidden_states.shape[:-1]
  228. hidden_shape = (*input_shape, -1, self.head_dim)
  229. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  230. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  231. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  232. cos, sin = position_embeddings
  233. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  234. if past_key_values is not None:
  235. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  236. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  237. self.config._attn_implementation, eager_attention_forward
  238. )
  239. attn_output, attn_weights = attention_interface(
  240. self,
  241. query_states,
  242. key_states,
  243. value_states,
  244. attention_mask,
  245. dropout=self.attention_dropout if self.training else 0.0,
  246. scaling=self.scaling,
  247. sliding_window=self.sliding_window,
  248. softcap=self.attn_logit_softcapping,
  249. **kwargs,
  250. )
  251. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  252. attn_output = self.o_proj(attn_output)
  253. return attn_output, attn_weights
  254. class Gemma2DecoderLayer(GradientCheckpointingLayer):
  255. def __init__(self, config: Gemma2Config, layer_idx: int):
  256. super().__init__()
  257. self.hidden_size = config.hidden_size
  258. self.config = config
  259. self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
  260. self.mlp = Gemma2MLP(config)
  261. self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  262. self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  263. self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  264. self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  265. def forward(
  266. self,
  267. hidden_states: torch.Tensor,
  268. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  269. attention_mask: torch.Tensor | None = None,
  270. position_ids: torch.LongTensor | None = None,
  271. past_key_values: Cache | None = None,
  272. **kwargs,
  273. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  274. residual = hidden_states
  275. hidden_states = self.input_layernorm(hidden_states)
  276. # Self Attention
  277. hidden_states, _ = self.self_attn(
  278. hidden_states=hidden_states,
  279. position_embeddings=position_embeddings,
  280. attention_mask=attention_mask,
  281. position_ids=position_ids,
  282. past_key_values=past_key_values,
  283. **kwargs,
  284. )
  285. hidden_states = self.post_attention_layernorm(hidden_states)
  286. hidden_states = residual + hidden_states
  287. residual = hidden_states
  288. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  289. hidden_states = self.mlp(hidden_states)
  290. hidden_states = self.post_feedforward_layernorm(hidden_states)
  291. hidden_states = residual + hidden_states
  292. return hidden_states
  293. class Gemma2TextScaledWordEmbedding(nn.Embedding):
  294. """
  295. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  296. """
  297. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  298. super().__init__(num_embeddings, embedding_dim, padding_idx)
  299. self.scalar_embed_scale = embed_scale
  300. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  301. def forward(self, input_ids: torch.Tensor):
  302. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  303. @auto_docstring
  304. class Gemma2PreTrainedModel(PreTrainedModel):
  305. config: Gemma2Config
  306. base_model_prefix = "model"
  307. supports_gradient_checkpointing = True
  308. _no_split_modules = ["Gemma2DecoderLayer"]
  309. _skip_keys_device_placement = ["past_key_values"]
  310. _supports_flash_attn = True
  311. _supports_sdpa = True
  312. _supports_flex_attn = True
  313. _can_compile_fullgraph = True
  314. _supports_attention_backend = True
  315. _can_record_outputs = {
  316. "hidden_states": Gemma2DecoderLayer,
  317. "attentions": Gemma2Attention,
  318. }
  319. @torch.no_grad()
  320. def _init_weights(self, module):
  321. super()._init_weights(module)
  322. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  323. if "RMSNorm" in module.__class__.__name__:
  324. init.zeros_(module.weight)
  325. elif isinstance(module, Gemma2TextScaledWordEmbedding):
  326. init.constant_(module.embed_scale, module.scalar_embed_scale)
  327. @auto_docstring
  328. class Gemma2Model(Gemma2PreTrainedModel):
  329. def __init__(self, config: Gemma2Config):
  330. super().__init__(config)
  331. self.padding_idx = config.pad_token_id
  332. self.vocab_size = config.vocab_size
  333. # Gemma23 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  334. self.embed_tokens = Gemma2TextScaledWordEmbedding(
  335. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  336. )
  337. self.layers = nn.ModuleList(
  338. [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  339. )
  340. self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  341. self.rotary_emb = Gemma2RotaryEmbedding(config)
  342. self.gradient_checkpointing = False
  343. # Initialize weights and apply final processing
  344. self.post_init()
  345. @merge_with_config_defaults
  346. @capture_outputs
  347. @auto_docstring
  348. def forward(
  349. self,
  350. input_ids: torch.LongTensor | None = None,
  351. attention_mask: torch.Tensor | None = None,
  352. position_ids: torch.LongTensor | None = None,
  353. past_key_values: Cache | None = None,
  354. inputs_embeds: torch.FloatTensor | None = None,
  355. use_cache: bool | None = None,
  356. **kwargs: Unpack[TransformersKwargs],
  357. ) -> BaseModelOutputWithPast:
  358. if (input_ids is None) ^ (inputs_embeds is not None):
  359. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  360. if inputs_embeds is None:
  361. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  362. if use_cache and past_key_values is None:
  363. past_key_values = DynamicCache(config=self.config)
  364. if position_ids is None:
  365. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  366. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  367. position_ids = position_ids.unsqueeze(0)
  368. # It may already have been prepared by e.g. `generate`
  369. if not isinstance(causal_mask_mapping := attention_mask, dict):
  370. # Prepare mask arguments
  371. mask_kwargs = {
  372. "config": self.config,
  373. "inputs_embeds": inputs_embeds,
  374. "attention_mask": attention_mask,
  375. "past_key_values": past_key_values,
  376. "position_ids": position_ids,
  377. }
  378. # Create the masks
  379. causal_mask_mapping = {
  380. "full_attention": create_causal_mask(**mask_kwargs),
  381. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  382. }
  383. # embed positions
  384. hidden_states = inputs_embeds
  385. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  386. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  387. hidden_states = decoder_layer(
  388. hidden_states,
  389. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  390. position_embeddings=position_embeddings,
  391. position_ids=position_ids,
  392. past_key_values=past_key_values,
  393. **kwargs,
  394. )
  395. hidden_states = self.norm(hidden_states)
  396. return BaseModelOutputWithPast(
  397. last_hidden_state=hidden_states,
  398. past_key_values=past_key_values,
  399. )
  400. @auto_docstring
  401. class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
  402. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  403. _tp_plan = {"lm_head": "colwise_gather_output"}
  404. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  405. def __init__(self, config):
  406. super().__init__(config)
  407. self.model = Gemma2Model(config)
  408. self.vocab_size = config.vocab_size
  409. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  410. # Initialize weights and apply final processing
  411. self.post_init()
  412. @can_return_tuple
  413. @auto_docstring
  414. def forward(
  415. self,
  416. input_ids: torch.LongTensor | None = None,
  417. attention_mask: torch.Tensor | None = None,
  418. position_ids: torch.LongTensor | None = None,
  419. past_key_values: Cache | None = None,
  420. inputs_embeds: torch.FloatTensor | None = None,
  421. labels: torch.LongTensor | None = None,
  422. use_cache: bool | None = None,
  423. logits_to_keep: int | torch.Tensor = 0,
  424. **kwargs: Unpack[TransformersKwargs],
  425. ) -> CausalLMOutputWithPast:
  426. r"""
  427. Example:
  428. ```python
  429. >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
  430. >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
  431. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  432. >>> prompt = "What is your favorite condiment?"
  433. >>> inputs = tokenizer(prompt, return_tensors="pt")
  434. >>> # Generate
  435. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  436. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  437. "What is your favorite condiment?"
  438. ```"""
  439. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  440. outputs: BaseModelOutputWithPast = self.model(
  441. input_ids=input_ids,
  442. attention_mask=attention_mask,
  443. position_ids=position_ids,
  444. past_key_values=past_key_values,
  445. inputs_embeds=inputs_embeds,
  446. use_cache=use_cache,
  447. **kwargs,
  448. )
  449. hidden_states = outputs.last_hidden_state
  450. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  451. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  452. logits = self.lm_head(hidden_states[:, slice_indices, :])
  453. if self.config.final_logit_softcapping is not None:
  454. logits = logits / self.config.final_logit_softcapping
  455. logits = torch.tanh(logits)
  456. logits = logits * self.config.final_logit_softcapping
  457. loss = None
  458. if labels is not None:
  459. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  460. return CausalLMOutputWithPast(
  461. loss=loss,
  462. logits=logits,
  463. past_key_values=outputs.past_key_values,
  464. hidden_states=outputs.hidden_states,
  465. attentions=outputs.attentions,
  466. )
  467. class Gemma2ForSequenceClassification(GenericForSequenceClassification, Gemma2PreTrainedModel):
  468. pass
  469. class Gemma2ForTokenClassification(GenericForTokenClassification, Gemma2PreTrainedModel):
  470. pass
  471. __all__ = [
  472. "Gemma2ForCausalLM",
  473. "Gemma2Model",
  474. "Gemma2PreTrainedModel",
  475. "Gemma2ForSequenceClassification",
  476. "Gemma2ForTokenClassification",
  477. ]