modeling_gemma.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. from torch import nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  30. from ...masking_utils import create_causal_mask
  31. from ...modeling_layers import (
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_gemma import GemmaConfig
  44. class GemmaTextScaledWordEmbedding(nn.Embedding):
  45. """
  46. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  47. """
  48. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  49. super().__init__(num_embeddings, embedding_dim, padding_idx)
  50. self.scalar_embed_scale = embed_scale
  51. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  52. def forward(self, input_ids: torch.Tensor):
  53. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  54. class GemmaRMSNorm(nn.Module):
  55. def __init__(self, dim: int, eps: float = 1e-6):
  56. super().__init__()
  57. self.eps = eps
  58. self.weight = nn.Parameter(torch.zeros(dim))
  59. def _norm(self, x):
  60. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  61. def forward(self, x):
  62. output = self._norm(x.float())
  63. # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
  64. # See https://github.com/huggingface/transformers/pull/29402
  65. output = output * (1.0 + self.weight.float())
  66. return output.type_as(x)
  67. def extra_repr(self):
  68. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  69. class GemmaMLP(nn.Module):
  70. def __init__(self, config):
  71. super().__init__()
  72. self.config = config
  73. self.hidden_size = config.hidden_size
  74. self.intermediate_size = config.intermediate_size
  75. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  76. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  77. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  78. self.act_fn = ACT2FN[config.hidden_act]
  79. def forward(self, x):
  80. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  81. return down_proj
  82. class GemmaRotaryEmbedding(nn.Module):
  83. inv_freq: torch.Tensor # fix linting for `register_buffer`
  84. def __init__(self, config: GemmaConfig, device=None):
  85. super().__init__()
  86. self.max_seq_len_cached = config.max_position_embeddings
  87. self.original_max_seq_len = config.max_position_embeddings
  88. self.config = config
  89. self.rope_type = self.config.rope_parameters["rope_type"]
  90. rope_init_fn: Callable = self.compute_default_rope_parameters
  91. if self.rope_type != "default":
  92. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  93. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  94. self.register_buffer("inv_freq", inv_freq, persistent=False)
  95. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  96. @staticmethod
  97. def compute_default_rope_parameters(
  98. config: GemmaConfig | None = None,
  99. device: Optional["torch.device"] = None,
  100. seq_len: int | None = None,
  101. ) -> tuple["torch.Tensor", float]:
  102. """
  103. Computes the inverse frequencies according to the original RoPE implementation
  104. Args:
  105. config ([`~transformers.PreTrainedConfig`]):
  106. The model configuration.
  107. device (`torch.device`):
  108. The device to use for initialization of the inverse frequencies.
  109. seq_len (`int`, *optional*):
  110. The current sequence length. Unused for this type of RoPE.
  111. Returns:
  112. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  113. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  114. """
  115. base = config.rope_parameters["rope_theta"]
  116. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  117. attention_factor = 1.0 # Unused in this type of RoPE
  118. # Compute the inverse frequencies
  119. inv_freq = 1.0 / (
  120. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  121. )
  122. return inv_freq, attention_factor
  123. @torch.no_grad()
  124. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  125. def forward(self, x, position_ids):
  126. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  127. position_ids_expanded = position_ids[:, None, :].float()
  128. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  129. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  130. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  131. emb = torch.cat((freqs, freqs), dim=-1)
  132. cos = emb.cos() * self.attention_scaling
  133. sin = emb.sin() * self.attention_scaling
  134. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  135. def rotate_half(x):
  136. """Rotates half the hidden dims of the input."""
  137. x1 = x[..., : x.shape[-1] // 2]
  138. x2 = x[..., x.shape[-1] // 2 :]
  139. return torch.cat((-x2, x1), dim=-1)
  140. @use_kernel_func_from_hub("rotary_pos_emb")
  141. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  142. """Applies Rotary Position Embedding to the query and key tensors.
  143. Args:
  144. q (`torch.Tensor`): The query tensor.
  145. k (`torch.Tensor`): The key tensor.
  146. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  147. sin (`torch.Tensor`): The sine part of the rotary embedding.
  148. unsqueeze_dim (`int`, *optional*, defaults to 1):
  149. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  150. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  151. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  152. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  153. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  154. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  155. Returns:
  156. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  157. """
  158. cos = cos.unsqueeze(unsqueeze_dim)
  159. sin = sin.unsqueeze(unsqueeze_dim)
  160. q_embed = (q * cos) + (rotate_half(q) * sin)
  161. k_embed = (k * cos) + (rotate_half(k) * sin)
  162. return q_embed, k_embed
  163. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  164. """
  165. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  166. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  167. """
  168. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  169. if n_rep == 1:
  170. return hidden_states
  171. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  172. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  173. def eager_attention_forward(
  174. module: nn.Module,
  175. query: torch.Tensor,
  176. key: torch.Tensor,
  177. value: torch.Tensor,
  178. attention_mask: torch.Tensor | None,
  179. scaling: float,
  180. dropout: float = 0.0,
  181. **kwargs: Unpack[TransformersKwargs],
  182. ):
  183. key_states = repeat_kv(key, module.num_key_value_groups)
  184. value_states = repeat_kv(value, module.num_key_value_groups)
  185. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  186. if attention_mask is not None:
  187. attn_weights = attn_weights + attention_mask
  188. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  189. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  190. attn_output = torch.matmul(attn_weights, value_states)
  191. attn_output = attn_output.transpose(1, 2).contiguous()
  192. return attn_output, attn_weights
  193. @use_kernelized_func(apply_rotary_pos_emb)
  194. class GemmaAttention(nn.Module):
  195. """Multi-headed attention from 'Attention Is All You Need' paper"""
  196. def __init__(self, config: GemmaConfig, layer_idx: int):
  197. super().__init__()
  198. self.config = config
  199. self.layer_idx = layer_idx
  200. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  201. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  202. self.scaling = self.head_dim**-0.5
  203. self.attention_dropout = config.attention_dropout
  204. self.is_causal = not getattr(config, "use_bidirectional_attention", False)
  205. self.q_proj = nn.Linear(
  206. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  207. )
  208. self.k_proj = nn.Linear(
  209. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  210. )
  211. self.v_proj = nn.Linear(
  212. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  213. )
  214. self.o_proj = nn.Linear(
  215. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  216. )
  217. def forward(
  218. self,
  219. hidden_states: torch.Tensor,
  220. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  221. attention_mask: torch.Tensor | None = None,
  222. past_key_values: Cache | None = None,
  223. **kwargs: Unpack[TransformersKwargs],
  224. ) -> tuple[torch.Tensor, torch.Tensor]:
  225. input_shape = hidden_states.shape[:-1]
  226. hidden_shape = (*input_shape, -1, self.head_dim)
  227. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  228. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  229. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  230. cos, sin = position_embeddings
  231. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  232. if past_key_values is not None:
  233. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  234. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  235. self.config._attn_implementation, eager_attention_forward
  236. )
  237. attn_output, attn_weights = attention_interface(
  238. self,
  239. query_states,
  240. key_states,
  241. value_states,
  242. attention_mask,
  243. dropout=0.0 if not self.training else self.attention_dropout,
  244. scaling=self.scaling,
  245. **kwargs,
  246. )
  247. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  248. attn_output = self.o_proj(attn_output)
  249. return attn_output, attn_weights
  250. class GemmaDecoderLayer(GradientCheckpointingLayer):
  251. def __init__(self, config: GemmaConfig, layer_idx: int):
  252. super().__init__()
  253. self.hidden_size = config.hidden_size
  254. self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
  255. self.mlp = GemmaMLP(config)
  256. self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  257. self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  258. def forward(
  259. self,
  260. hidden_states: torch.Tensor,
  261. attention_mask: torch.Tensor | None = None,
  262. position_ids: torch.LongTensor | None = None,
  263. past_key_values: Cache | None = None,
  264. use_cache: bool | None = False,
  265. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  266. **kwargs: Unpack[TransformersKwargs],
  267. ) -> torch.Tensor:
  268. residual = hidden_states
  269. hidden_states = self.input_layernorm(hidden_states)
  270. # Self Attention
  271. hidden_states, _ = self.self_attn(
  272. hidden_states=hidden_states,
  273. attention_mask=attention_mask,
  274. position_ids=position_ids,
  275. past_key_values=past_key_values,
  276. use_cache=use_cache,
  277. position_embeddings=position_embeddings,
  278. **kwargs,
  279. )
  280. hidden_states = residual + hidden_states
  281. # Fully Connected
  282. residual = hidden_states
  283. hidden_states = self.post_attention_layernorm(hidden_states)
  284. hidden_states = self.mlp(hidden_states)
  285. hidden_states = residual + hidden_states
  286. return hidden_states
  287. @auto_docstring
  288. class GemmaPreTrainedModel(PreTrainedModel):
  289. config: GemmaConfig
  290. base_model_prefix = "model"
  291. supports_gradient_checkpointing = True
  292. _no_split_modules = ["GemmaDecoderLayer"]
  293. _skip_keys_device_placement = ["past_key_values"]
  294. _supports_flash_attn = True
  295. _supports_sdpa = True
  296. _supports_flex_attn = True
  297. _can_compile_fullgraph = True
  298. _supports_attention_backend = True
  299. _can_record_outputs = {
  300. "hidden_states": GemmaDecoderLayer,
  301. "attentions": GemmaAttention,
  302. }
  303. @torch.no_grad()
  304. def _init_weights(self, module):
  305. super()._init_weights(module)
  306. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  307. if "RMSNorm" in module.__class__.__name__:
  308. init.zeros_(module.weight)
  309. elif isinstance(module, GemmaTextScaledWordEmbedding):
  310. init.constant_(module.embed_scale, module.scalar_embed_scale)
  311. @auto_docstring
  312. class GemmaModel(GemmaPreTrainedModel):
  313. def __init__(self, config: GemmaConfig):
  314. super().__init__(config)
  315. self.padding_idx = config.pad_token_id
  316. self.vocab_size = config.vocab_size
  317. # Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  318. self.embed_tokens = GemmaTextScaledWordEmbedding(
  319. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  320. )
  321. self.layers = nn.ModuleList(
  322. [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  323. )
  324. self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  325. self.rotary_emb = GemmaRotaryEmbedding(config=config)
  326. self.gradient_checkpointing = False
  327. # Initialize weights and apply final processing
  328. self.post_init()
  329. @merge_with_config_defaults
  330. @capture_outputs
  331. @auto_docstring
  332. def forward(
  333. self,
  334. input_ids: torch.LongTensor | None = None,
  335. attention_mask: torch.Tensor | None = None,
  336. position_ids: torch.LongTensor | None = None,
  337. past_key_values: Cache | None = None,
  338. inputs_embeds: torch.FloatTensor | None = None,
  339. use_cache: bool | None = None,
  340. **kwargs: Unpack[TransformersKwargs],
  341. ) -> BaseModelOutputWithPast:
  342. if (input_ids is None) ^ (inputs_embeds is not None):
  343. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  344. if inputs_embeds is None:
  345. inputs_embeds = self.embed_tokens(input_ids)
  346. if use_cache and past_key_values is None:
  347. past_key_values = DynamicCache(config=self.config)
  348. if position_ids is None:
  349. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  350. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  351. position_ids = position_ids.unsqueeze(0)
  352. causal_mask = create_causal_mask(
  353. config=self.config,
  354. inputs_embeds=inputs_embeds,
  355. attention_mask=attention_mask,
  356. past_key_values=past_key_values,
  357. position_ids=position_ids,
  358. )
  359. # embed positions
  360. hidden_states = inputs_embeds
  361. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  362. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  363. hidden_states = decoder_layer(
  364. hidden_states,
  365. attention_mask=causal_mask,
  366. position_ids=position_ids,
  367. past_key_values=past_key_values,
  368. use_cache=use_cache,
  369. position_embeddings=position_embeddings,
  370. **kwargs,
  371. )
  372. hidden_states = self.norm(hidden_states)
  373. return BaseModelOutputWithPast(
  374. last_hidden_state=hidden_states,
  375. past_key_values=past_key_values if use_cache else None,
  376. )
  377. @auto_docstring
  378. class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
  379. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  380. _tp_plan = {"lm_head": "colwise_gather_output"}
  381. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  382. def __init__(self, config):
  383. super().__init__(config)
  384. self.model = GemmaModel(config)
  385. self.vocab_size = config.vocab_size
  386. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  387. # Initialize weights and apply final processing
  388. self.post_init()
  389. @can_return_tuple
  390. @auto_docstring
  391. def forward(
  392. self,
  393. input_ids: torch.LongTensor | None = None,
  394. attention_mask: torch.Tensor | None = None,
  395. position_ids: torch.LongTensor | None = None,
  396. past_key_values: Cache | None = None,
  397. inputs_embeds: torch.FloatTensor | None = None,
  398. labels: torch.LongTensor | None = None,
  399. use_cache: bool | None = None,
  400. logits_to_keep: int | torch.Tensor = 0,
  401. **kwargs: Unpack[TransformersKwargs],
  402. ) -> CausalLMOutputWithPast:
  403. r"""
  404. Example:
  405. ```python
  406. >>> from transformers import AutoTokenizer, GemmaForCausalLM
  407. >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
  408. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
  409. >>> prompt = "What is your favorite condiment?"
  410. >>> inputs = tokenizer(prompt, return_tensors="pt")
  411. >>> # Generate
  412. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  413. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  414. "What is your favorite condiment?"
  415. ```"""
  416. outputs: BaseModelOutputWithPast = self.model(
  417. input_ids=input_ids,
  418. attention_mask=attention_mask,
  419. position_ids=position_ids,
  420. past_key_values=past_key_values,
  421. inputs_embeds=inputs_embeds,
  422. use_cache=use_cache,
  423. **kwargs,
  424. )
  425. hidden_states = outputs.last_hidden_state
  426. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  427. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  428. logits = self.lm_head(hidden_states[:, slice_indices, :])
  429. loss = None
  430. if labels is not None:
  431. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  432. return CausalLMOutputWithPast(
  433. loss=loss,
  434. logits=logits,
  435. past_key_values=outputs.past_key_values,
  436. hidden_states=outputs.hidden_states,
  437. attentions=outputs.attentions,
  438. )
class GemmaForSequenceClassification(GenericForSequenceClassification, GemmaPreTrainedModel):
    # No Gemma-specific code. Behaviour comes from GenericForSequenceClassification;
    # GemmaPreTrainedModel contributes the GemmaConfig type and the Gemma weight initialisation.
    pass
class GemmaForTokenClassification(GenericForTokenClassification, GemmaPreTrainedModel):
    # No Gemma-specific code. Behaviour comes from GenericForTokenClassification;
    # GemmaPreTrainedModel contributes the GemmaConfig type and the Gemma weight initialisation.
    pass
# Public API of this module (the names exported by `from ... import *`).
__all__ = [
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
    "GemmaPreTrainedModel",
]