modeling_gemma3.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. from typing import Optional
  24. import torch
  25. import torch.nn as nn
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache
  29. from ...configuration_utils import PreTrainedConfig
  30. from ...generation import GenerationMixin
  31. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  32. from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
  33. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  34. from ...modeling_outputs import (
  35. BaseModelOutputWithPast,
  36. BaseModelOutputWithPooling,
  37. CausalLMOutputWithPast,
  38. SequenceClassifierOutputWithPast,
  39. )
  40. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
  44. from ...utils.deprecation import deprecate_kwarg
  45. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from ..auto import AutoModel
  48. from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
  49. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
Base class for Gemma3 outputs, with hidden states and attentions.
"""
)
class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # The only field added on top of `BaseModelOutputWithPast`; stays None unless the producer sets it.
    image_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
Base class for Gemma3 causal language model (or autoregressive) outputs.
"""
)
class Gemma3CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    """

    # Unlike Gemma3ModelOutputWithPast this derives from ModelOutput directly, so every field is declared here.
    # All fields default to None; only the ones the producing model fills in are set.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    # Not described in the docstring above; both default to None.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
  89. class Gemma3TextScaledWordEmbedding(nn.Embedding):
  90. """
  91. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  92. """
  93. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  94. super().__init__(num_embeddings, embedding_dim, padding_idx)
  95. self.scalar_embed_scale = embed_scale
  96. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  97. def forward(self, input_ids: torch.Tensor):
  98. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  99. class Gemma3MLP(nn.Module):
  100. def __init__(self, config: Gemma3TextConfig):
  101. super().__init__()
  102. self.config = config
  103. self.hidden_size = config.hidden_size
  104. self.intermediate_size = config.intermediate_size
  105. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  106. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  107. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  108. self.act_fn = ACT2FN[config.hidden_activation]
  109. def forward(self, x):
  110. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  111. return down_proj
  112. class Gemma3RMSNorm(nn.Module):
  113. def __init__(self, dim: int, eps: float = 1e-6):
  114. super().__init__()
  115. self.eps = eps
  116. self.weight = nn.Parameter(torch.zeros(dim))
  117. def _norm(self, x):
  118. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  119. def forward(self, x):
  120. output = self._norm(x.float())
  121. # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
  122. # See https://github.com/huggingface/transformers/pull/29402
  123. output = output * (1.0 + self.weight.float())
  124. return output.type_as(x)
  125. def extra_repr(self):
  126. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  127. class Gemma3RotaryEmbedding(nn.Module):
  128. inv_freq: torch.Tensor # fix linting for `register_buffer`
  129. def __init__(self, config: Gemma3TextConfig, device=None, layer_type=None):
  130. super().__init__()
  131. self.max_seq_len_cached = config.max_position_embeddings
  132. self.original_max_seq_len = config.max_position_embeddings
  133. self.config = config
  134. self.layer_types = list(set(config.layer_types))
  135. self.rope_type = {}
  136. for layer_type in self.layer_types:
  137. rope_params = self.config.rope_parameters[layer_type]
  138. if rope_params is None:
  139. continue
  140. self.rope_type[layer_type] = rope_params["rope_type"]
  141. rope_init_fn: Callable = self.compute_default_rope_parameters
  142. if self.rope_type[layer_type] != "default":
  143. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
  144. curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
  145. self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
  146. self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
  147. setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
  148. @staticmethod
  149. def compute_default_rope_parameters(
  150. config: Gemma3TextConfig | None = None,
  151. device: Optional["torch.device"] = None,
  152. seq_len: int | None = None,
  153. layer_type: str | None = None,
  154. ) -> tuple["torch.Tensor", float]:
  155. """
  156. Computes the inverse frequencies according to the original RoPE implementation
  157. Args:
  158. config ([`~transformers.PreTrainedConfig`]):
  159. The model configuration.
  160. device (`torch.device`):
  161. The device to use for initialization of the inverse frequencies.
  162. seq_len (`int`, *optional*):
  163. The current sequence length. Unused for this type of RoPE.
  164. layer_type (`str`, *optional*):
  165. The current layer type if the model has different RoPE parameters per type.
  166. Should not be used unless `config.layer_types is not None`
  167. Returns:
  168. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  169. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  170. """
  171. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  172. base = config.rope_parameters[layer_type]["rope_theta"]
  173. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  174. attention_factor = 1.0 # Unused in this type of RoPE
  175. # Compute the inverse frequencies
  176. inv_freq = 1.0 / (
  177. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  178. )
  179. return inv_freq, attention_factor
  180. @torch.no_grad()
  181. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  182. def forward(self, x, position_ids, layer_type=None):
  183. inv_freq = getattr(self, f"{layer_type}_inv_freq")
  184. attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
  185. inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  186. position_ids_expanded = position_ids[:, None, :].float()
  187. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  188. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  189. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  190. emb = torch.cat((freqs, freqs), dim=-1)
  191. cos = emb.cos() * attention_scaling
  192. sin = emb.sin() * attention_scaling
  193. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  194. def rotate_half(x):
  195. """Rotates half the hidden dims of the input."""
  196. x1 = x[..., : x.shape[-1] // 2]
  197. x2 = x[..., x.shape[-1] // 2 :]
  198. return torch.cat((-x2, x1), dim=-1)
  199. @use_kernel_func_from_hub("rotary_pos_emb")
  200. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  201. """Applies Rotary Position Embedding to the query and key tensors.
  202. Args:
  203. q (`torch.Tensor`): The query tensor.
  204. k (`torch.Tensor`): The key tensor.
  205. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  206. sin (`torch.Tensor`): The sine part of the rotary embedding.
  207. unsqueeze_dim (`int`, *optional*, defaults to 1):
  208. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  209. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  210. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  211. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  212. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  213. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  214. Returns:
  215. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  216. """
  217. cos = cos.unsqueeze(unsqueeze_dim)
  218. sin = sin.unsqueeze(unsqueeze_dim)
  219. q_embed = (q * cos) + (rotate_half(q) * sin)
  220. k_embed = (k * cos) + (rotate_half(k) * sin)
  221. return q_embed, k_embed
  222. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  223. """
  224. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  225. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  226. """
  227. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  228. if n_rep == 1:
  229. return hidden_states
  230. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  231. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  232. def eager_attention_forward(
  233. module: nn.Module,
  234. query: torch.Tensor,
  235. key: torch.Tensor,
  236. value: torch.Tensor,
  237. attention_mask: torch.Tensor | None,
  238. dropout: float | int = 0.0,
  239. scaling: float | None = None,
  240. softcap: float | None = None,
  241. **kwargs,
  242. ) -> tuple[torch.Tensor, torch.Tensor]:
  243. if scaling is None:
  244. scaling = module.head_dim**-0.5
  245. key_states = repeat_kv(key, module.num_key_value_groups)
  246. value_states = repeat_kv(value, module.num_key_value_groups)
  247. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  248. if softcap is not None:
  249. attn_weights = attn_weights / softcap
  250. attn_weights = torch.tanh(attn_weights)
  251. attn_weights = attn_weights * softcap
  252. if attention_mask is not None:
  253. attn_weights = attn_weights + attention_mask
  254. # upcast attention to fp32
  255. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  256. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  257. attn_output = torch.matmul(attn_weights, value_states)
  258. attn_output = attn_output.transpose(1, 2).contiguous()
  259. return attn_output, attn_weights
  260. @use_kernelized_func(apply_rotary_pos_emb)
  261. class Gemma3Attention(nn.Module):
  262. """Multi-headed attention from 'Attention Is All You Need' paper"""
  263. def __init__(self, config: Gemma3TextConfig, layer_idx: int):
  264. super().__init__()
  265. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  266. self.config = config
  267. self.layer_idx = layer_idx
  268. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  269. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  270. self.scaling = config.query_pre_attn_scalar**-0.5
  271. self.attention_dropout = self.config.attention_dropout
  272. self.is_causal = not self.config.use_bidirectional_attention
  273. self.q_proj = nn.Linear(
  274. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  275. )
  276. self.k_proj = nn.Linear(
  277. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  278. )
  279. self.v_proj = nn.Linear(
  280. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  281. )
  282. self.o_proj = nn.Linear(
  283. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  284. )
  285. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  286. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  287. self.is_sliding = self.layer_type == "sliding_attention"
  288. self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  289. self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  290. def forward(
  291. self,
  292. hidden_states: torch.Tensor,
  293. position_embeddings: torch.Tensor = None,
  294. attention_mask: torch.Tensor | None = None,
  295. past_key_values: Cache | None = None,
  296. **kwargs: Unpack[TransformersKwargs],
  297. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  298. input_shape = hidden_states.shape[:-1]
  299. hidden_shape = (*input_shape, -1, self.head_dim)
  300. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  301. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  302. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  303. query_states = self.q_norm(query_states)
  304. key_states = self.k_norm(key_states)
  305. cos, sin = position_embeddings
  306. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  307. if past_key_values is not None:
  308. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  309. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  310. self.config._attn_implementation, eager_attention_forward
  311. )
  312. attn_output, attn_weights = attention_interface(
  313. self,
  314. query_states,
  315. key_states,
  316. value_states,
  317. attention_mask,
  318. dropout=self.attention_dropout if self.training else 0.0,
  319. scaling=self.scaling,
  320. sliding_window=self.sliding_window,
  321. **kwargs,
  322. )
  323. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  324. attn_output = self.o_proj(attn_output)
  325. return attn_output, attn_weights
  326. class Gemma3DecoderLayer(GradientCheckpointingLayer):
  327. def __init__(self, config: Gemma3TextConfig, layer_idx: int):
  328. super().__init__()
  329. self.config = config
  330. self.hidden_size = config.hidden_size
  331. self.layer_idx = layer_idx
  332. self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
  333. self.mlp = Gemma3MLP(config)
  334. self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  335. self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  336. self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  337. self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  338. def forward(
  339. self,
  340. hidden_states: torch.Tensor,
  341. position_embeddings: torch.Tensor = None,
  342. attention_mask: torch.Tensor | None = None,
  343. position_ids: torch.LongTensor | None = None,
  344. past_key_values: Cache | None = None,
  345. **kwargs: Unpack[TransformersKwargs],
  346. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  347. residual = hidden_states
  348. hidden_states = self.input_layernorm(hidden_states)
  349. hidden_states, _ = self.self_attn(
  350. hidden_states=hidden_states,
  351. position_embeddings=position_embeddings,
  352. attention_mask=attention_mask,
  353. position_ids=position_ids,
  354. past_key_values=past_key_values,
  355. **kwargs,
  356. )
  357. hidden_states = self.post_attention_layernorm(hidden_states)
  358. hidden_states = residual + hidden_states
  359. residual = hidden_states
  360. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  361. hidden_states = self.mlp(hidden_states)
  362. hidden_states = self.post_feedforward_layernorm(hidden_states)
  363. hidden_states = residual + hidden_states
  364. return hidden_states
  365. @auto_docstring
  366. class Gemma3PreTrainedModel(PreTrainedModel):
  367. config: Gemma3Config
  368. base_model_prefix = "model"
  369. supports_gradient_checkpointing = True
  370. _no_split_modules = [
  371. "Gemma3DecoderLayer",
  372. "SiglipVisionEmbeddings",
  373. "SiglipEncoderLayer",
  374. "SiglipMultiheadAttentionPoolingHead",
  375. ]
  376. _skip_keys_device_placement = ["past_key_values"]
  377. _supports_flash_attn = True
  378. _supports_sdpa = True
  379. _supports_flex_attn = True
  380. _can_compile_fullgraph = True
  381. _supports_attention_backend = True
  382. _can_record_outputs = {
  383. "hidden_states": Gemma3DecoderLayer,
  384. "attentions": Gemma3Attention,
  385. }
  386. input_modalities = ("image", "text")
  387. @torch.no_grad()
  388. def _init_weights(self, module):
  389. super()._init_weights(module)
  390. if isinstance(module, Gemma3MultiModalProjector):
  391. init.zeros_(module.mm_input_projection_weight)
  392. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  393. elif "RMSNorm" in module.__class__.__name__:
  394. init.zeros_(module.weight)
  395. elif isinstance(module, Gemma3TextScaledWordEmbedding):
  396. init.constant_(module.embed_scale, module.scalar_embed_scale)
  397. elif isinstance(module, Gemma3RotaryEmbedding):
  398. for layer_type in module.layer_types:
  399. rope_init_fn = module.compute_default_rope_parameters
  400. if module.rope_type[layer_type] != "default":
  401. rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
  402. curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
  403. init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
  404. init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
  405. def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
  406. """
  407. Enables a bidirectional mask within the sliding window.
  408. """
  409. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  410. """A token can attend to any other token if their absolute distance is within
  411. the (exclusive) sliding window size (distance < sliding_window)."""
  412. return abs(q_idx - kv_idx) < sliding_window
  413. return inner_mask
  414. @auto_docstring
  415. class Gemma3TextModel(Gemma3PreTrainedModel):
  416. config: Gemma3TextConfig
  417. input_modalities = ("text",)
  418. def __init__(self, config: Gemma3TextConfig):
  419. super().__init__(config)
  420. self.padding_idx = config.pad_token_id
  421. self.vocab_size = config.vocab_size
  422. # Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  423. self.embed_tokens = Gemma3TextScaledWordEmbedding(
  424. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  425. )
  426. self.layers = nn.ModuleList(
  427. [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  428. )
  429. self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  430. self.rotary_emb = Gemma3RotaryEmbedding(config)
  431. self.gradient_checkpointing = False
  432. # Initialize weights and apply final processing
  433. self.post_init()
  434. @merge_with_config_defaults
  435. @capture_outputs
  436. @auto_docstring
  437. def forward(
  438. self,
  439. input_ids: torch.LongTensor | None = None,
  440. attention_mask: torch.Tensor | None = None,
  441. position_ids: torch.LongTensor | None = None,
  442. past_key_values: Cache | None = None,
  443. inputs_embeds: torch.FloatTensor | None = None,
  444. use_cache: bool | None = None,
  445. **kwargs: Unpack[TransformersKwargs],
  446. ) -> BaseModelOutputWithPast:
  447. if (input_ids is None) ^ (inputs_embeds is not None):
  448. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  449. if inputs_embeds is None:
  450. inputs_embeds = self.embed_tokens(input_ids)
  451. if use_cache and past_key_values is None:
  452. past_key_values = DynamicCache(config=self.config)
  453. if position_ids is None:
  454. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  455. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  456. position_ids = position_ids.unsqueeze(0)
  457. # It may already have been prepared by e.g. `generate`
  458. if not isinstance(causal_mask_mapping := attention_mask, dict):
  459. # Prepare mask arguments
  460. mask_kwargs = {
  461. "config": self.config,
  462. "inputs_embeds": inputs_embeds,
  463. "attention_mask": attention_mask,
  464. "past_key_values": past_key_values,
  465. "position_ids": position_ids,
  466. }
  467. sliding_mask_kwargs = mask_kwargs.copy()
  468. if self.config.use_bidirectional_attention:
  469. mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
  470. sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
  471. # Create the masks
  472. causal_mask_mapping = {
  473. "full_attention": create_causal_mask(**mask_kwargs),
  474. "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
  475. }
  476. # embed positions
  477. hidden_states = inputs_embeds
  478. position_embeddings = {}
  479. for layer_type in self.config.layer_types:
  480. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  481. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  482. hidden_states = decoder_layer(
  483. hidden_states,
  484. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  485. position_embeddings=position_embeddings[self.config.layer_types[i]],
  486. position_ids=position_ids,
  487. past_key_values=past_key_values,
  488. **kwargs,
  489. )
  490. hidden_states = self.norm(hidden_states)
  491. return BaseModelOutputWithPast(
  492. last_hidden_state=hidden_states,
  493. past_key_values=past_key_values,
  494. )
  495. @auto_docstring
  496. class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
  497. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  498. _tp_plan = {"lm_head": "colwise_gather_output"}
  499. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  500. config: Gemma3TextConfig
  501. def __init__(self, config: Gemma3TextConfig):
  502. super().__init__(config)
  503. self.model = Gemma3TextModel(config)
  504. self.vocab_size = config.vocab_size
  505. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  506. # Initialize weights and apply final processing
  507. self.post_init()
  508. @can_return_tuple
  509. @auto_docstring
  510. def forward(
  511. self,
  512. input_ids: torch.LongTensor | None = None,
  513. attention_mask: torch.Tensor | None = None,
  514. position_ids: torch.LongTensor | None = None,
  515. past_key_values: Cache | None = None,
  516. inputs_embeds: torch.FloatTensor | None = None,
  517. labels: torch.LongTensor | None = None,
  518. use_cache: bool | None = None,
  519. logits_to_keep: int | torch.Tensor = 0,
  520. **kwargs: Unpack[TransformersKwargs],
  521. ) -> CausalLMOutputWithPast:
  522. r"""
  523. Example:
  524. ```python
  525. >>> from transformers import AutoTokenizer, Gemma3ForCausalLM
  526. >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
  527. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  528. >>> prompt = "What is your favorite condiment?"
  529. >>> inputs = tokenizer(prompt, return_tensors="pt")
  530. >>> # Generate
  531. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  532. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  533. "What is your favorite condiment?"
  534. ```"""
  535. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  536. outputs: BaseModelOutputWithPast = self.model(
  537. input_ids=input_ids,
  538. attention_mask=attention_mask,
  539. position_ids=position_ids,
  540. past_key_values=past_key_values,
  541. inputs_embeds=inputs_embeds,
  542. use_cache=use_cache,
  543. **kwargs,
  544. )
  545. hidden_states = outputs.last_hidden_state
  546. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  547. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  548. logits = self.lm_head(hidden_states[:, slice_indices, :])
  549. if self.config.final_logit_softcapping is not None:
  550. logits = logits / self.config.final_logit_softcapping
  551. logits = torch.tanh(logits)
  552. logits = logits * self.config.final_logit_softcapping
  553. loss = None
  554. if labels is not None:
  555. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  556. return CausalLMOutputWithPast(
  557. loss=loss,
  558. logits=logits,
  559. past_key_values=outputs.past_key_values,
  560. hidden_states=outputs.hidden_states,
  561. attentions=outputs.attentions,
  562. )
  563. class Gemma3MultiModalProjector(nn.Module):
  564. def __init__(self, config: Gemma3Config):
  565. super().__init__()
  566. self.mm_input_projection_weight = nn.Parameter(
  567. torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
  568. )
  569. self.mm_soft_emb_norm = Gemma3RMSNorm(
  570. config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
  571. )
  572. self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
  573. self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
  574. self.kernel_size = self.patches_per_image // self.tokens_per_side
  575. self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
  576. def forward(self, vision_outputs: torch.Tensor):
  577. batch_size, _, hidden_size = vision_outputs.shape
  578. reshaped_vision_outputs = vision_outputs.transpose(1, 2)
  579. reshaped_vision_outputs = reshaped_vision_outputs.reshape(
  580. batch_size, hidden_size, self.patches_per_image, self.patches_per_image
  581. )
  582. reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
  583. pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
  584. pooled_vision_outputs = pooled_vision_outputs.flatten(2)
  585. pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
  586. normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
  587. projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
  588. return projected_vision_outputs.type_as(vision_outputs)
  589. def token_type_ids_mask_function(group_ids: torch.Tensor) -> Callable:
  590. """
  591. This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
  592. not start and end indices.
  593. Args:
  594. group_ids (`torch.Tensor`):
  595. A tensor of shape `(bs, len)` assigning each token to a vision group. Tokens with the same group
  596. come from the same input image. Text is denoted by `-1`.
  597. """
  598. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  599. seq_length = group_ids.shape[-1]
  600. # clamp indices because with static cache they can go beyond `group_ids.shape[-1]`
  601. q_idx_clamped = q_idx.clamp(max=seq_length - 1)
  602. kv_idx_clamped = kv_idx.clamp(max=seq_length - 1)
  603. # Unmask if the q and kv come from same group which is not -1 (i.e. non-text)
  604. q_group = group_ids[batch_idx, q_idx_clamped]
  605. kv_group = group_ids[batch_idx, kv_idx_clamped]
  606. q_group = torch.where(q_idx < seq_length, q_group, -1)
  607. kv_group = torch.where(kv_idx < seq_length, kv_group, -1)
  608. return (q_group == kv_group) & (q_group >= 0)
  609. return inner_mask
  610. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  611. def create_causal_mask_mapping(
  612. config: PreTrainedConfig,
  613. inputs_embeds: torch.Tensor,
  614. attention_mask: torch.Tensor | None,
  615. past_key_values: Cache | None,
  616. position_ids: torch.Tensor | None,
  617. token_type_ids: torch.Tensor | None = None,
  618. pixel_values: torch.FloatTensor | None = None,
  619. is_training: bool = False,
  620. is_first_iteration: bool | None = None,
  621. **kwargs,
  622. ) -> dict:
  623. """
  624. Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
  625. for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.
  626. Uses `pixel_values` as an optional input to disambiguate edge cases.
  627. """
  628. if is_training and token_type_ids is None:
  629. raise ValueError("`token_type_ids` is required as a model input when training")
  630. mask_kwargs = {
  631. "config": config.get_text_config(),
  632. "inputs_embeds": inputs_embeds,
  633. "attention_mask": attention_mask,
  634. "past_key_values": past_key_values,
  635. "position_ids": position_ids,
  636. }
  637. # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
  638. # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
  639. # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
  640. is_first_iteration = (
  641. is_first_iteration
  642. if is_first_iteration is not None
  643. else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
  644. )
  645. if token_type_ids is not None and is_first_iteration:
  646. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
  647. # undo the causal masking)
  648. # First find where a new image block starts: 1 if image and previous not image
  649. # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
  650. is_image = (token_type_ids == 1).to(inputs_embeds.device)
  651. is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
  652. new_image_start = is_image & ~is_previous_image
  653. group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
  654. group_ids = torch.where(is_image, group_ids, -1)
  655. mask_kwargs["or_mask_function"] = token_type_ids_mask_function(group_ids)
  656. return create_masks_for_generate(**mask_kwargs)
  657. @auto_docstring(
  658. custom_intro="""
  659. The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
  660. """
  661. )
  662. class Gemma3Model(Gemma3PreTrainedModel):
  663. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  664. accepts_loss_kwargs = False
  665. def __init__(self, config: Gemma3Config):
  666. super().__init__(config)
  667. self.vision_tower = AutoModel.from_config(config=config.vision_config)
  668. self.multi_modal_projector = Gemma3MultiModalProjector(config)
  669. self.vocab_size = config.text_config.vocab_size
  670. language_model = AutoModel.from_config(config=config.text_config)
  671. self.language_model = language_model
  672. self.post_init()
  673. def get_input_embeddings(self):
  674. return self.language_model.get_input_embeddings()
  675. def set_input_embeddings(self, value):
  676. self.language_model.set_input_embeddings(value)
  677. @can_return_tuple
  678. @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
  679. def get_image_features(
  680. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  681. ) -> tuple | BaseModelOutputWithPooling:
  682. vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
  683. last_hidden_state = vision_outputs.last_hidden_state
  684. vision_outputs.pooler_output = self.multi_modal_projector(last_hidden_state)
  685. return vision_outputs
  686. def get_placeholder_mask(
  687. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  688. ):
  689. """
  690. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  691. equal to the length of multimodal features. If the lengths are different, an error is raised.
  692. """
  693. if input_ids is None:
  694. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  695. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  696. )
  697. special_image_mask = special_image_mask.all(-1)
  698. else:
  699. special_image_mask = input_ids == self.config.image_token_id
  700. n_image_tokens = special_image_mask.sum()
  701. n_image_features = image_features.shape[0] * image_features.shape[1]
  702. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  703. torch_compilable_check(
  704. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  705. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  706. )
  707. return special_image_mask
  708. @can_return_tuple
  709. @auto_docstring
  710. def forward(
  711. self,
  712. input_ids: torch.LongTensor | None = None,
  713. pixel_values: torch.FloatTensor | None = None,
  714. attention_mask: torch.Tensor | None = None,
  715. position_ids: torch.LongTensor | None = None,
  716. past_key_values: Cache | None = None,
  717. token_type_ids: torch.LongTensor | None = None,
  718. inputs_embeds: torch.FloatTensor | None = None,
  719. labels: torch.LongTensor | None = None,
  720. use_cache: bool | None = None,
  721. **lm_kwargs: Unpack[TransformersKwargs],
  722. ) -> tuple | Gemma3ModelOutputWithPast:
  723. r"""
  724. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  725. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  726. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  727. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  728. Example:
  729. ```python
  730. >>> from PIL import Image
  731. >>> import httpx
  732. >>> from io import BytesIO
  733. >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  734. >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma32-3b-mix-224")
  735. >>> processor = AutoProcessor.from_pretrained("google/gemma32-3b-mix-224")
  736. >>> prompt = "Where is the cat standing?"
  737. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
  738. >>> with httpx.stream("GET", url) as response:
  739. ... image = Image.open(BytesIO(response.read()))
  740. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  741. >>> # Generate
  742. >>> generate_ids = model.generate(**inputs,)
  743. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  744. "Where is the cat standing?\nsnow"
  745. ```"""
  746. if (input_ids is None) ^ (inputs_embeds is not None):
  747. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  748. # Replace image id with PAD if the image token if OOV, to avoid index-errors
  749. if input_ids is not None and self.config.image_token_id >= self.vocab_size:
  750. special_image_mask = input_ids == self.config.image_token_id
  751. llm_input_ids = input_ids.clone()
  752. llm_input_ids[special_image_mask] = 0
  753. else:
  754. llm_input_ids = input_ids
  755. if inputs_embeds is None:
  756. inputs_embeds = self.get_input_embeddings()(llm_input_ids)
  757. # Merge text and images
  758. if pixel_values is not None:
  759. image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
  760. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  761. special_image_mask = self.get_placeholder_mask(
  762. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  763. )
  764. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  765. # It may already have been prepared by e.g. `generate`
  766. if not isinstance(causal_mask_mapping := attention_mask, dict):
  767. causal_mask_mapping = create_causal_mask_mapping(
  768. self.config,
  769. inputs_embeds,
  770. attention_mask,
  771. past_key_values,
  772. position_ids,
  773. token_type_ids,
  774. pixel_values,
  775. is_training=self.training,
  776. )
  777. outputs = self.language_model(
  778. attention_mask=causal_mask_mapping,
  779. position_ids=position_ids,
  780. past_key_values=past_key_values,
  781. inputs_embeds=inputs_embeds,
  782. use_cache=use_cache,
  783. return_dict=True,
  784. **lm_kwargs,
  785. )
  786. return Gemma3ModelOutputWithPast(
  787. last_hidden_state=outputs.last_hidden_state,
  788. past_key_values=outputs.past_key_values,
  789. hidden_states=outputs.hidden_states,
  790. attentions=outputs.attentions,
  791. image_hidden_states=image_features if pixel_values is not None else None,
  792. )
  793. @auto_docstring(
  794. custom_intro="""
  795. The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
  796. """
  797. )
  798. class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
  799. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  800. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  801. # Fix: https://github.com/huggingface/transformers/issues/40564
  802. accepts_loss_kwargs = False
  803. def __init__(self, config: Gemma3Config):
  804. super().__init__(config)
  805. self.model = Gemma3Model(config)
  806. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  807. self.post_init()
  808. def get_input_embeddings(self):
  809. return self.model.get_input_embeddings()
  810. def set_input_embeddings(self, value):
  811. self.model.set_input_embeddings(value)
  812. @auto_docstring
  813. def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
  814. return self.model.get_image_features(pixel_values, **kwargs)
  815. @can_return_tuple
  816. @auto_docstring
  817. def forward(
  818. self,
  819. input_ids: torch.LongTensor | None = None,
  820. pixel_values: torch.FloatTensor | None = None,
  821. attention_mask: torch.Tensor | None = None,
  822. position_ids: torch.LongTensor | None = None,
  823. past_key_values: Cache | None = None,
  824. token_type_ids: torch.LongTensor | None = None,
  825. inputs_embeds: torch.FloatTensor | None = None,
  826. labels: torch.LongTensor | None = None,
  827. use_cache: bool | None = None,
  828. logits_to_keep: int | torch.Tensor = 0,
  829. **lm_kwargs: Unpack[TransformersKwargs],
  830. ) -> tuple | Gemma3CausalLMOutputWithPast:
  831. r"""
  832. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  833. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  834. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  835. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  836. Example:
  837. ```python
  838. >>> from PIL import Image
  839. >>> import httpx
  840. >>> from io import BytesIO
  841. >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  842. >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
  843. >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
  844. >>> messages = [
  845. ... {
  846. ... "role": "system",
  847. ... "content": [
  848. ... {"type": "text", "text": "You are a helpful assistant."}
  849. ... ]
  850. ... },
  851. ... {
  852. ... "role": "user", "content": [
  853. ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
  854. ... {"type": "text", "text": "Where is the cat standing?"},
  855. ... ]
  856. ... },
  857. ... ]
  858. >>> inputs = processor.apply_chat_template(
  859. ... messages,
  860. ... tokenize=True,
  861. ... return_dict=True,
  862. ... return_tensors="pt",
  863. ... add_generation_prompt=True
  864. ... )
  865. >>> # Generate
  866. >>> generate_ids = model.generate(**inputs)
  867. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  868. "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
  869. ```
  870. """
  871. outputs = self.model(
  872. input_ids=input_ids,
  873. pixel_values=pixel_values,
  874. token_type_ids=token_type_ids,
  875. attention_mask=attention_mask,
  876. position_ids=position_ids,
  877. past_key_values=past_key_values,
  878. inputs_embeds=inputs_embeds,
  879. use_cache=use_cache,
  880. labels=labels,
  881. **lm_kwargs,
  882. )
  883. hidden_states = outputs[0]
  884. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  885. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  886. logits = self.lm_head(hidden_states[:, slice_indices, :])
  887. loss = None
  888. if labels is not None:
  889. # Upcast to float if we need to compute the loss to avoid potential precision issues
  890. logits = logits.float()
  891. shift_logits = logits[..., :-1, :]
  892. shift_labels = labels[..., 1:]
  893. if attention_mask is not None:
  894. # we use the input attention mask to shift the logits and labels, because it is 2D.
  895. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  896. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  897. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  898. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  899. else:
  900. shift_logits = shift_logits.contiguous()
  901. shift_labels = shift_labels.contiguous()
  902. # Flatten the tokens
  903. loss_fct = nn.CrossEntropyLoss()
  904. flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
  905. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  906. loss = loss_fct(flat_logits, flat_labels)
  907. return Gemma3CausalLMOutputWithPast(
  908. loss=loss,
  909. logits=logits,
  910. past_key_values=outputs.past_key_values,
  911. hidden_states=outputs.hidden_states,
  912. attentions=outputs.attentions,
  913. image_hidden_states=outputs.image_hidden_states,
  914. )
  915. def prepare_inputs_for_generation(
  916. self,
  917. input_ids,
  918. past_key_values=None,
  919. inputs_embeds=None,
  920. position_ids=None,
  921. pixel_values=None,
  922. attention_mask=None,
  923. token_type_ids=None,
  924. use_cache=True,
  925. logits_to_keep=None,
  926. labels=None,
  927. is_first_iteration=False,
  928. **kwargs,
  929. ):
  930. # Overwritten -- custom `pixel_values` handling
  931. model_inputs = super().prepare_inputs_for_generation(
  932. input_ids,
  933. past_key_values=past_key_values,
  934. inputs_embeds=inputs_embeds,
  935. attention_mask=attention_mask,
  936. position_ids=position_ids,
  937. use_cache=use_cache,
  938. logits_to_keep=logits_to_keep,
  939. token_type_ids=token_type_ids,
  940. is_first_iteration=is_first_iteration,
  941. **kwargs,
  942. )
  943. # Pixel values are used only in the first iteration if available
  944. # In subsequent iterations, they are already merged with text and cached
  945. # NOTE: first iteration doesn't have to be prefill, it can be the first
  946. # iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
  947. if is_first_iteration or not use_cache:
  948. model_inputs["pixel_values"] = pixel_values
  949. return model_inputs
  950. @staticmethod
  951. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  952. def create_masks_for_generate(
  953. config: PreTrainedConfig,
  954. inputs_embeds: torch.Tensor,
  955. attention_mask: torch.Tensor | None,
  956. past_key_values: Cache | None,
  957. position_ids: torch.Tensor | None,
  958. token_type_ids: torch.Tensor | None = None,
  959. is_first_iteration: bool | None = False,
  960. **kwargs,
  961. ) -> dict:
  962. # Uses the overwritten `create_masks_for_generate` with `token_type_ids` masking
  963. return create_causal_mask_mapping(
  964. config,
  965. inputs_embeds,
  966. attention_mask,
  967. past_key_values,
  968. position_ids,
  969. token_type_ids,
  970. is_first_iteration=is_first_iteration,
  971. **{k: v for k, v in kwargs.items() if k != "pixel_values"},
  972. )
  973. class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
  974. def __init__(self, config):
  975. super().__init__(config)
  976. self.num_labels = config.num_labels
  977. self.model = Gemma3Model(config)
  978. self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
  979. # Initialize weights and apply final processing
  980. self.post_init()
  981. def get_input_embeddings(self):
  982. return self.model.get_input_embeddings()
  983. def set_input_embeddings(self, value):
  984. self.model.set_input_embeddings(value)
  985. @can_return_tuple
  986. @auto_docstring
  987. def forward(
  988. self,
  989. input_ids: torch.LongTensor | None = None,
  990. pixel_values: torch.FloatTensor | None = None,
  991. attention_mask: torch.Tensor | None = None,
  992. position_ids: torch.LongTensor | None = None,
  993. past_key_values: Cache | None = None,
  994. inputs_embeds: torch.FloatTensor | None = None,
  995. token_type_ids: torch.LongTensor | None = None,
  996. labels: torch.LongTensor | None = None,
  997. use_cache: bool | None = None,
  998. **kwargs: Unpack[TransformersKwargs],
  999. ) -> SequenceClassifierOutputWithPast:
  1000. r"""
  1001. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1002. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1003. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1004. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1005. """
  1006. transformer_outputs = self.model(
  1007. input_ids,
  1008. attention_mask=attention_mask,
  1009. pixel_values=pixel_values,
  1010. position_ids=position_ids,
  1011. past_key_values=past_key_values,
  1012. inputs_embeds=inputs_embeds,
  1013. token_type_ids=token_type_ids,
  1014. use_cache=use_cache,
  1015. **kwargs,
  1016. )
  1017. hidden_states = transformer_outputs.last_hidden_state
  1018. logits = self.score(hidden_states)
  1019. if input_ids is not None:
  1020. batch_size = input_ids.shape[0]
  1021. else:
  1022. batch_size = inputs_embeds.shape[0]
  1023. if self.config.text_config.pad_token_id is None and batch_size != 1:
  1024. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1025. if self.config.text_config.pad_token_id is None:
  1026. last_non_pad_token = -1
  1027. elif input_ids is not None:
  1028. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  1029. non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
  1030. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  1031. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  1032. else:
  1033. last_non_pad_token = -1
  1034. logger.warning_once(
  1035. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  1036. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  1037. )
  1038. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  1039. loss = None
  1040. if labels is not None:
  1041. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1042. return SequenceClassifierOutputWithPast(
  1043. loss=loss,
  1044. logits=pooled_logits,
  1045. past_key_values=transformer_outputs.past_key_values,
  1046. hidden_states=transformer_outputs.hidden_states,
  1047. attentions=transformer_outputs.attentions,
  1048. )
  1049. class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
  1050. """
  1051. Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
  1052. It uses the generic sequence classification implementation for efficiency and consistency.
  1053. """
  1054. config: Gemma3TextConfig
  1055. input_modalities = ("text",)
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "Gemma3PreTrainedModel",
    "Gemma3TextModel",
    "Gemma3ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3Model",
    "Gemma3ForSequenceClassification",
    "Gemma3TextForSequenceClassification",
]