modular_gemma3.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011
  1. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from collections.abc import Callable
  16. from typing import Any, Optional
  17. import torch
  18. import torch.nn as nn
  19. from huggingface_hub.dataclasses import strict
  20. from ... import initialization as init
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...configuration_utils import PreTrainedConfig
  23. from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
  24. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, SequenceClassifierOutputWithPast
  26. from ...modeling_rope_utils import (
  27. ROPE_INIT_FUNCTIONS,
  28. dynamic_rope_update,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  33. from ...utils.deprecation import deprecate_kwarg
  34. from ...utils.generic import maybe_autocast
  35. from ..gemma2.configuration_gemma2 import Gemma2Config
  36. from ..gemma2.modeling_gemma2 import (
  37. Gemma2Attention,
  38. Gemma2ForCausalLM,
  39. Gemma2MLP,
  40. Gemma2Model,
  41. Gemma2PreTrainedModel,
  42. Gemma2RMSNorm,
  43. Gemma2RotaryEmbedding,
  44. apply_rotary_pos_emb,
  45. eager_attention_forward,
  46. )
  47. from ..paligemma.modeling_paligemma import (
  48. PaliGemmaCausalLMOutputWithPast,
  49. PaliGemmaForConditionalGeneration,
  50. PaliGemmaModel,
  51. PaligemmaModelOutputWithPast,
  52. token_type_ids_mask_function,
  53. )
  54. from ..siglip import SiglipVisionConfig
  55. logger = logging.get_logger(__name__)
  56. @auto_docstring(checkpoint="google/gemma-3-4b-it")
  57. @strict
  58. class Gemma3TextConfig(Gemma2Config, PreTrainedConfig):
  59. r"""
  60. query_pre_attn_scalar (`float`, *optional*, defaults to 256):
  61. scaling factor used on the attention scores
  62. final_logit_softcapping (`float`, *optional*):
  63. Scaling factor when applying tanh softcapping on the logits.
  64. attn_logit_softcapping (`float`, *optional*):
  65. Scaling factor when applying tanh softcapping on the attention scores.
  66. use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
  67. If True, the model will attend to all text tokens instead of using a causal mask. This does not change
  68. behavior for vision tokens.
  69. ```python
  70. >>> from transformers import Gemma3TextModel, Gemma3TextConfig
  71. >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
  72. >>> configuration = Gemma3TextConfig()
  73. >>> # Initializing a model from the gemma3_text-7b style configuration
  74. >>> model = Gemma3TextModel(configuration)
  75. >>> # Accessing the model configuration
  76. >>> configuration = model.config
  77. ```
  78. """
  79. model_type = "gemma3_text"
  80. base_model_tp_plan = {
  81. "layers.*.self_attn.q_proj": "colwise",
  82. "layers.*.self_attn.k_proj": "colwise",
  83. "layers.*.self_attn.v_proj": "colwise",
  84. "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
  85. "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
  86. "layers.*.self_attn.o_proj": "rowwise",
  87. "layers.*.mlp.gate_proj": "colwise",
  88. "layers.*.mlp.up_proj": "colwise",
  89. "layers.*.mlp.down_proj": "rowwise",
  90. }
  91. default_theta = {"global": 1_000_000.0, "local": 10_000.0}
  92. vocab_size: int = 262_208
  93. max_position_embeddings: int = 131_072
  94. layer_types: list[str] | None = None
  95. final_logit_softcapping: float | None = None
  96. attn_logit_softcapping: float | None = None
  97. rope_parameters: dict | None = None
  98. use_bidirectional_attention: bool | None = False
  99. def __post_init__(self, **kwargs):
  100. if self.use_bidirectional_attention:
  101. self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
  102. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  103. self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
  104. if self.layer_types is None:
  105. self.layer_types = [
  106. "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
  107. for i in range(self.num_hidden_layers)
  108. ]
  109. PreTrainedConfig.__post_init__(**kwargs)
  110. def convert_rope_params_to_dict(self, **kwargs):
  111. rope_scaling = kwargs.pop("rope_scaling", None)
  112. # Try to set `rope_scaling` if available, otherwise use `rope_parameters`. If we find `rope_parameters`
  113. # as arg in the inputs, we can safely assume that it is in the new format. New naming used -> new format
  114. default_rope_params = {
  115. "sliding_attention": {"rope_type": "default"},
  116. "full_attention": {"rope_type": "default"},
  117. }
  118. self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else default_rope_params
  119. if rope_scaling is not None:
  120. self.rope_parameters["full_attention"].update(rope_scaling)
  121. # Set default values if not present
  122. if self.rope_parameters.get("full_attention") is None:
  123. self.rope_parameters["full_attention"] = {"rope_type": "default"}
  124. self.rope_parameters["full_attention"].setdefault(
  125. "rope_theta", kwargs.pop("rope_theta", self.default_theta["global"])
  126. )
  127. if self.rope_parameters.get("sliding_attention") is None:
  128. self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
  129. self.rope_parameters["sliding_attention"].setdefault(
  130. "rope_theta", kwargs.pop("rope_local_base_freq", self.default_theta["local"])
  131. )
  132. # Standardize and validate the correctness of rotary position embeddings parameters
  133. self.standardize_rope_params()
  134. return kwargs
  135. @auto_docstring(checkpoint="google/gemma-3-4b-it")
  136. @strict
  137. class Gemma3Config(PreTrainedConfig):
  138. r"""
  139. mm_tokens_per_image (`int`, *optional*, defaults to 256):
  140. The number of tokens per image embedding.
  141. boi_token_index (`int`, *optional*, defaults to 255999):
  142. The begin-of-image token index to wrap the image prompt.
  143. eoi_token_index (`int`, *optional*, defaults to 256000):
  144. The end-of-image token index to wrap the image prompt.
  145. Example:
  146. ```python
  147. >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
  148. >>> # Initializing a Siglip-like vision config
  149. >>> vision_config = SiglipVisionConfig()
  150. >>> # Initializing a Gemma3 Text config
  151. >>> text_config = Gemma3TextConfig()
  152. >>> # Initializing a Gemma3 gemma-3-4b style configuration
  153. >>> configuration = Gemma3Config(vision_config, text_config)
  154. >>> # Initializing a model from the gemma-3-4b style configuration
  155. >>> model = Gemma3TextConfig(configuration)
  156. >>> # Accessing the model configuration
  157. >>> configuration = model.config
  158. ```"""
  159. model_type = "gemma3"
  160. attribute_map = {
  161. "image_token_id": "image_token_index",
  162. "boi_token_id": "boi_token_index",
  163. "eoi_token_id": "eoi_token_index",
  164. }
  165. sub_configs = {
  166. "text_config": Gemma3TextConfig,
  167. "vision_config": SiglipVisionConfig,
  168. }
  169. text_config: Gemma3TextConfig | dict[str, Any] | None = None
  170. vision_config: SiglipVisionConfig | dict[str, Any] | None = None
  171. mm_tokens_per_image: int | None = 256
  172. boi_token_index: int | None = 255_999
  173. eoi_token_index: int | None = 256_000
  174. image_token_index: int | None = 262_144
  175. initializer_range: float | None = 0.02
  176. tie_word_embeddings: bool | None = True
  177. def __post_init__(self, **kwargs):
  178. if self.text_config is None:
  179. self.text_config = Gemma3TextConfig()
  180. logger.info("text_config is None, using default Gemma3TextConfig text config.")
  181. elif isinstance(self.text_config, dict):
  182. self.text_config = Gemma3TextConfig(**self.text_config)
  183. if isinstance(self.vision_config, dict):
  184. self.vision_config = SiglipVisionConfig(**self.vision_config)
  185. elif self.vision_config is None:
  186. self.vision_config = SiglipVisionConfig()
  187. logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
  188. super().__post_init__(**kwargs)
  189. class Gemma3ModelOutputWithPast(PaligemmaModelOutputWithPast):
  190. pass
  191. class Gemma3CausalLMOutputWithPast(PaliGemmaCausalLMOutputWithPast):
  192. pass
  193. class Gemma3TextScaledWordEmbedding(nn.Embedding):
  194. """
  195. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  196. """
  197. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  198. super().__init__(num_embeddings, embedding_dim, padding_idx)
  199. self.scalar_embed_scale = embed_scale
  200. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  201. def forward(self, input_ids: torch.Tensor):
  202. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  203. class Gemma3MLP(Gemma2MLP):
  204. def __init__(self, config: Gemma3TextConfig):
  205. super().__init__(config)
  206. class Gemma3RMSNorm(Gemma2RMSNorm):
  207. def __init__(self, dim: int, eps: float = 1e-6):
  208. super().__init__(dim=dim, eps=eps)
  209. class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
  210. def __init__(self, config: Gemma3TextConfig, device=None, layer_type=None):
  211. nn.Module.__init__()
  212. self.max_seq_len_cached = config.max_position_embeddings
  213. self.original_max_seq_len = config.max_position_embeddings
  214. self.config = config
  215. self.layer_types = list(set(config.layer_types))
  216. self.rope_type = {}
  217. for layer_type in self.layer_types:
  218. rope_params = self.config.rope_parameters[layer_type]
  219. if rope_params is None:
  220. continue
  221. self.rope_type[layer_type] = rope_params["rope_type"]
  222. rope_init_fn: Callable = self.compute_default_rope_parameters
  223. if self.rope_type[layer_type] != "default":
  224. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
  225. curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
  226. self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
  227. self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
  228. setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
  229. @staticmethod
  230. def compute_default_rope_parameters(
  231. config: Gemma3TextConfig | None = None,
  232. device: Optional["torch.device"] = None,
  233. seq_len: int | None = None,
  234. layer_type: str | None = None,
  235. ) -> tuple["torch.Tensor", float]:
  236. """
  237. Computes the inverse frequencies according to the original RoPE implementation
  238. Args:
  239. config ([`~transformers.PreTrainedConfig`]):
  240. The model configuration.
  241. device (`torch.device`):
  242. The device to use for initialization of the inverse frequencies.
  243. seq_len (`int`, *optional*):
  244. The current sequence length. Unused for this type of RoPE.
  245. layer_type (`str`, *optional*):
  246. The current layer type if the model has different RoPE parameters per type.
  247. Should not be used unless `config.layer_types is not None`
  248. Returns:
  249. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  250. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  251. """
  252. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  253. base = config.rope_parameters[layer_type]["rope_theta"]
  254. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  255. attention_factor = 1.0 # Unused in this type of RoPE
  256. # Compute the inverse frequencies
  257. inv_freq = 1.0 / (
  258. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  259. )
  260. return inv_freq, attention_factor
  261. @torch.no_grad()
  262. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  263. def forward(self, x, position_ids, layer_type=None):
  264. inv_freq = getattr(self, f"{layer_type}_inv_freq")
  265. attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
  266. inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  267. position_ids_expanded = position_ids[:, None, :].float()
  268. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  269. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  270. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  271. emb = torch.cat((freqs, freqs), dim=-1)
  272. cos = emb.cos() * attention_scaling
  273. sin = emb.sin() * attention_scaling
  274. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  275. # Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding`
  276. class Gemma3Attention(Gemma2Attention):
  277. def __init__(self, config: Gemma3TextConfig, layer_idx: int):
  278. super().__init__(config, layer_idx)
  279. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  280. self.is_sliding = self.layer_type == "sliding_attention"
  281. self.is_causal = not self.config.use_bidirectional_attention
  282. self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  283. self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  284. def forward(
  285. self,
  286. hidden_states: torch.Tensor,
  287. position_embeddings: torch.Tensor = None,
  288. attention_mask: torch.Tensor | None = None,
  289. past_key_values: Cache | None = None,
  290. **kwargs: Unpack[TransformersKwargs],
  291. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  292. input_shape = hidden_states.shape[:-1]
  293. hidden_shape = (*input_shape, -1, self.head_dim)
  294. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  295. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  296. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  297. query_states = self.q_norm(query_states)
  298. key_states = self.k_norm(key_states)
  299. cos, sin = position_embeddings
  300. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  301. if past_key_values is not None:
  302. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  303. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  304. self.config._attn_implementation, eager_attention_forward
  305. )
  306. attn_output, attn_weights = attention_interface(
  307. self,
  308. query_states,
  309. key_states,
  310. value_states,
  311. attention_mask,
  312. dropout=self.attention_dropout if self.training else 0.0,
  313. scaling=self.scaling,
  314. sliding_window=self.sliding_window,
  315. **kwargs,
  316. )
  317. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  318. attn_output = self.o_proj(attn_output)
  319. return attn_output, attn_weights
  320. class Gemma3DecoderLayer(GradientCheckpointingLayer):
  321. def __init__(self, config: Gemma3TextConfig, layer_idx: int):
  322. super().__init__()
  323. self.config = config
  324. self.hidden_size = config.hidden_size
  325. self.layer_idx = layer_idx
  326. self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
  327. self.mlp = Gemma3MLP(config)
  328. self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  329. self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  330. self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  331. self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  332. def forward(
  333. self,
  334. hidden_states: torch.Tensor,
  335. position_embeddings: torch.Tensor = None,
  336. attention_mask: torch.Tensor | None = None,
  337. position_ids: torch.LongTensor | None = None,
  338. past_key_values: Cache | None = None,
  339. **kwargs: Unpack[TransformersKwargs],
  340. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  341. residual = hidden_states
  342. hidden_states = self.input_layernorm(hidden_states)
  343. hidden_states, _ = self.self_attn(
  344. hidden_states=hidden_states,
  345. position_embeddings=position_embeddings,
  346. attention_mask=attention_mask,
  347. position_ids=position_ids,
  348. past_key_values=past_key_values,
  349. **kwargs,
  350. )
  351. hidden_states = self.post_attention_layernorm(hidden_states)
  352. hidden_states = residual + hidden_states
  353. residual = hidden_states
  354. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  355. hidden_states = self.mlp(hidden_states)
  356. hidden_states = self.post_feedforward_layernorm(hidden_states)
  357. hidden_states = residual + hidden_states
  358. return hidden_states
  359. GEMMA3_START_DOCSTRING = None
  360. class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
  361. base_model_prefix = "model"
  362. input_modalities = ("image", "text")
  363. _no_split_modules = [
  364. "Gemma3DecoderLayer",
  365. "SiglipVisionEmbeddings",
  366. "SiglipEncoderLayer",
  367. "SiglipMultiheadAttentionPoolingHead",
  368. ]
  369. @torch.no_grad()
  370. def _init_weights(self, module):
  371. PreTrainedModel._init_weights(self, module)
  372. if isinstance(module, Gemma3MultiModalProjector):
  373. init.zeros_(module.mm_input_projection_weight)
  374. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  375. elif "RMSNorm" in module.__class__.__name__:
  376. init.zeros_(module.weight)
  377. elif isinstance(module, Gemma3TextScaledWordEmbedding):
  378. init.constant_(module.embed_scale, module.scalar_embed_scale)
  379. elif isinstance(module, Gemma3RotaryEmbedding):
  380. for layer_type in module.layer_types:
  381. rope_init_fn = module.compute_default_rope_parameters
  382. if module.rope_type[layer_type] != "default":
  383. rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
  384. curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
  385. init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
  386. init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
  387. def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
  388. """
  389. Enables a bidirectional mask within the sliding window.
  390. """
  391. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  392. """A token can attend to any other token if their absolute distance is within
  393. the (exclusive) sliding window size (distance < sliding_window)."""
  394. return abs(q_idx - kv_idx) < sliding_window
  395. return inner_mask
  396. class Gemma3TextModel(Gemma2Model):
  397. config: Gemma3TextConfig
  398. input_modalities = ("text",)
  399. def __init__(self, config: Gemma3TextConfig):
  400. super().__init__(config)
  401. # Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  402. self.embed_tokens = Gemma3TextScaledWordEmbedding(
  403. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  404. )
  405. def forward(
  406. self,
  407. input_ids: torch.LongTensor | None = None,
  408. attention_mask: torch.Tensor | None = None,
  409. position_ids: torch.LongTensor | None = None,
  410. past_key_values: Cache | None = None,
  411. inputs_embeds: torch.FloatTensor | None = None,
  412. use_cache: bool | None = None,
  413. **kwargs: Unpack[TransformersKwargs],
  414. ) -> BaseModelOutputWithPast:
  415. if (input_ids is None) ^ (inputs_embeds is not None):
  416. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  417. if inputs_embeds is None:
  418. inputs_embeds = self.embed_tokens(input_ids)
  419. if use_cache and past_key_values is None:
  420. past_key_values = DynamicCache(config=self.config)
  421. if position_ids is None:
  422. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  423. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  424. position_ids = position_ids.unsqueeze(0)
  425. # It may already have been prepared by e.g. `generate`
  426. if not isinstance(causal_mask_mapping := attention_mask, dict):
  427. # Prepare mask arguments
  428. mask_kwargs = {
  429. "config": self.config,
  430. "inputs_embeds": inputs_embeds,
  431. "attention_mask": attention_mask,
  432. "past_key_values": past_key_values,
  433. "position_ids": position_ids,
  434. }
  435. sliding_mask_kwargs = mask_kwargs.copy()
  436. if self.config.use_bidirectional_attention:
  437. mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
  438. sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
  439. # Create the masks
  440. causal_mask_mapping = {
  441. "full_attention": create_causal_mask(**mask_kwargs),
  442. "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
  443. }
  444. # embed positions
  445. hidden_states = inputs_embeds
  446. position_embeddings = {}
  447. for layer_type in self.config.layer_types:
  448. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  449. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  450. hidden_states = decoder_layer(
  451. hidden_states,
  452. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  453. position_embeddings=position_embeddings[self.config.layer_types[i]],
  454. position_ids=position_ids,
  455. past_key_values=past_key_values,
  456. **kwargs,
  457. )
  458. hidden_states = self.norm(hidden_states)
  459. return BaseModelOutputWithPast(
  460. last_hidden_state=hidden_states,
  461. past_key_values=past_key_values,
  462. )
  463. class Gemma3ForCausalLM(Gemma2ForCausalLM):
  464. config: Gemma3TextConfig
  465. def __init__(self, config: Gemma3TextConfig):
  466. super().__init__(config)
  467. self.model = Gemma3TextModel(config)
  468. class Gemma3MultiModalProjector(nn.Module):
  469. def __init__(self, config: Gemma3Config):
  470. super().__init__()
  471. self.mm_input_projection_weight = nn.Parameter(
  472. torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
  473. )
  474. self.mm_soft_emb_norm = Gemma3RMSNorm(
  475. config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
  476. )
  477. self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
  478. self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
  479. self.kernel_size = self.patches_per_image // self.tokens_per_side
  480. self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
  481. def forward(self, vision_outputs: torch.Tensor):
  482. batch_size, _, hidden_size = vision_outputs.shape
  483. reshaped_vision_outputs = vision_outputs.transpose(1, 2)
  484. reshaped_vision_outputs = reshaped_vision_outputs.reshape(
  485. batch_size, hidden_size, self.patches_per_image, self.patches_per_image
  486. )
  487. reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
  488. pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
  489. pooled_vision_outputs = pooled_vision_outputs.flatten(2)
  490. pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
  491. normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
  492. projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
  493. return projected_vision_outputs.type_as(vision_outputs)
  494. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  495. def create_causal_mask_mapping(
  496. config: PreTrainedConfig,
  497. inputs_embeds: torch.Tensor,
  498. attention_mask: torch.Tensor | None,
  499. past_key_values: Cache | None,
  500. position_ids: torch.Tensor | None,
  501. token_type_ids: torch.Tensor | None = None,
  502. pixel_values: torch.FloatTensor | None = None,
  503. is_training: bool = False,
  504. is_first_iteration: bool | None = None,
  505. **kwargs,
  506. ) -> dict:
  507. """
  508. Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
  509. for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.
  510. Uses `pixel_values` as an optional input to disambiguate edge cases.
  511. """
  512. if is_training and token_type_ids is None:
  513. raise ValueError("`token_type_ids` is required as a model input when training")
  514. mask_kwargs = {
  515. "config": config.get_text_config(),
  516. "inputs_embeds": inputs_embeds,
  517. "attention_mask": attention_mask,
  518. "past_key_values": past_key_values,
  519. "position_ids": position_ids,
  520. }
  521. # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
  522. # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
  523. # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
  524. is_first_iteration = (
  525. is_first_iteration
  526. if is_first_iteration is not None
  527. else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
  528. )
  529. if token_type_ids is not None and is_first_iteration:
  530. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
  531. # undo the causal masking)
  532. # First find where a new image block starts: 1 if image and previous not image
  533. # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
  534. is_image = (token_type_ids == 1).to(inputs_embeds.device)
  535. is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
  536. new_image_start = is_image & ~is_previous_image
  537. group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
  538. group_ids = torch.where(is_image, group_ids, -1)
  539. mask_kwargs["or_mask_function"] = token_type_ids_mask_function(group_ids)
  540. return create_masks_for_generate(**mask_kwargs)
  541. class Gemma3Model(PaliGemmaModel):
  542. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  543. accepts_loss_kwargs = False
  544. def __init__(self, config: Gemma3Config):
  545. super().__init__(config)
  546. del self.text_config_dtype
  547. @can_return_tuple
  548. @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
  549. def get_image_features(
  550. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  551. ) -> tuple | BaseModelOutputWithPooling:
  552. vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
  553. last_hidden_state = vision_outputs.last_hidden_state
  554. vision_outputs.pooler_output = self.multi_modal_projector(last_hidden_state)
  555. return vision_outputs
  556. @can_return_tuple
  557. @auto_docstring
  558. def forward(
  559. self,
  560. input_ids: torch.LongTensor | None = None,
  561. pixel_values: torch.FloatTensor | None = None,
  562. attention_mask: torch.Tensor | None = None,
  563. position_ids: torch.LongTensor | None = None,
  564. past_key_values: Cache | None = None,
  565. token_type_ids: torch.LongTensor | None = None,
  566. inputs_embeds: torch.FloatTensor | None = None,
  567. labels: torch.LongTensor | None = None,
  568. use_cache: bool | None = None,
  569. **lm_kwargs: Unpack[TransformersKwargs],
  570. ) -> tuple | Gemma3ModelOutputWithPast:
  571. if (input_ids is None) ^ (inputs_embeds is not None):
  572. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  573. # Replace image id with PAD if the image token if OOV, to avoid index-errors
  574. if input_ids is not None and self.config.image_token_id >= self.vocab_size:
  575. special_image_mask = input_ids == self.config.image_token_id
  576. llm_input_ids = input_ids.clone()
  577. llm_input_ids[special_image_mask] = 0
  578. else:
  579. llm_input_ids = input_ids
  580. if inputs_embeds is None:
  581. inputs_embeds = self.get_input_embeddings()(llm_input_ids)
  582. # Merge text and images
  583. if pixel_values is not None:
  584. image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
  585. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  586. special_image_mask = self.get_placeholder_mask(
  587. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  588. )
  589. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  590. # It may already have been prepared by e.g. `generate`
  591. if not isinstance(causal_mask_mapping := attention_mask, dict):
  592. causal_mask_mapping = create_causal_mask_mapping(
  593. self.config,
  594. inputs_embeds,
  595. attention_mask,
  596. past_key_values,
  597. position_ids,
  598. token_type_ids,
  599. pixel_values,
  600. is_training=self.training,
  601. )
  602. outputs = self.language_model(
  603. attention_mask=causal_mask_mapping,
  604. position_ids=position_ids,
  605. past_key_values=past_key_values,
  606. inputs_embeds=inputs_embeds,
  607. use_cache=use_cache,
  608. return_dict=True,
  609. **lm_kwargs,
  610. )
  611. return Gemma3ModelOutputWithPast(
  612. last_hidden_state=outputs.last_hidden_state,
  613. past_key_values=outputs.past_key_values,
  614. hidden_states=outputs.hidden_states,
  615. attentions=outputs.attentions,
  616. image_hidden_states=image_features if pixel_values is not None else None,
  617. )
  618. class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
  619. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  620. # Fix: https://github.com/huggingface/transformers/issues/40564
  621. accepts_loss_kwargs = False
  622. @can_return_tuple
  623. @auto_docstring
  624. def forward(
  625. self,
  626. input_ids: torch.LongTensor | None = None,
  627. pixel_values: torch.FloatTensor | None = None,
  628. attention_mask: torch.Tensor | None = None,
  629. position_ids: torch.LongTensor | None = None,
  630. past_key_values: Cache | None = None,
  631. token_type_ids: torch.LongTensor | None = None,
  632. inputs_embeds: torch.FloatTensor | None = None,
  633. labels: torch.LongTensor | None = None,
  634. use_cache: bool | None = None,
  635. logits_to_keep: int | torch.Tensor = 0,
  636. **lm_kwargs: Unpack[TransformersKwargs],
  637. ) -> tuple | Gemma3CausalLMOutputWithPast:
  638. r"""
  639. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  640. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  641. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  642. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  643. Example:
  644. ```python
  645. >>> from PIL import Image
  646. >>> import httpx
  647. >>> from io import BytesIO
  648. >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  649. >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
  650. >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
  651. >>> messages = [
  652. ... {
  653. ... "role": "system",
  654. ... "content": [
  655. ... {"type": "text", "text": "You are a helpful assistant."}
  656. ... ]
  657. ... },
  658. ... {
  659. ... "role": "user", "content": [
  660. ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
  661. ... {"type": "text", "text": "Where is the cat standing?"},
  662. ... ]
  663. ... },
  664. ... ]
  665. >>> inputs = processor.apply_chat_template(
  666. ... messages,
  667. ... tokenize=True,
  668. ... return_dict=True,
  669. ... return_tensors="pt",
  670. ... add_generation_prompt=True
  671. ... )
  672. >>> # Generate
  673. >>> generate_ids = model.generate(**inputs)
  674. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  675. "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
  676. ```
  677. """
  678. outputs = self.model(
  679. input_ids=input_ids,
  680. pixel_values=pixel_values,
  681. token_type_ids=token_type_ids,
  682. attention_mask=attention_mask,
  683. position_ids=position_ids,
  684. past_key_values=past_key_values,
  685. inputs_embeds=inputs_embeds,
  686. use_cache=use_cache,
  687. labels=labels,
  688. **lm_kwargs,
  689. )
  690. hidden_states = outputs[0]
  691. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  692. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  693. logits = self.lm_head(hidden_states[:, slice_indices, :])
  694. loss = None
  695. if labels is not None:
  696. # Upcast to float if we need to compute the loss to avoid potential precision issues
  697. logits = logits.float()
  698. shift_logits = logits[..., :-1, :]
  699. shift_labels = labels[..., 1:]
  700. if attention_mask is not None:
  701. # we use the input attention mask to shift the logits and labels, because it is 2D.
  702. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  703. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  704. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  705. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  706. else:
  707. shift_logits = shift_logits.contiguous()
  708. shift_labels = shift_labels.contiguous()
  709. # Flatten the tokens
  710. loss_fct = nn.CrossEntropyLoss()
  711. flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
  712. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  713. loss = loss_fct(flat_logits, flat_labels)
  714. return Gemma3CausalLMOutputWithPast(
  715. loss=loss,
  716. logits=logits,
  717. past_key_values=outputs.past_key_values,
  718. hidden_states=outputs.hidden_states,
  719. attentions=outputs.attentions,
  720. image_hidden_states=outputs.image_hidden_states,
  721. )
  722. def prepare_inputs_for_generation(
  723. self,
  724. input_ids,
  725. past_key_values=None,
  726. inputs_embeds=None,
  727. position_ids=None,
  728. pixel_values=None,
  729. attention_mask=None,
  730. token_type_ids=None,
  731. use_cache=True,
  732. logits_to_keep=None,
  733. labels=None,
  734. is_first_iteration=False,
  735. **kwargs,
  736. ):
  737. # Overwritten -- custom `pixel_values` handling
  738. model_inputs = super().prepare_inputs_for_generation(
  739. input_ids,
  740. past_key_values=past_key_values,
  741. inputs_embeds=inputs_embeds,
  742. attention_mask=attention_mask,
  743. position_ids=position_ids,
  744. use_cache=use_cache,
  745. logits_to_keep=logits_to_keep,
  746. token_type_ids=token_type_ids,
  747. is_first_iteration=is_first_iteration,
  748. **kwargs,
  749. )
  750. # Pixel values are used only in the first iteration if available
  751. # In subsequent iterations, they are already merged with text and cached
  752. # NOTE: first iteration doesn't have to be prefill, it can be the first
  753. # iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
  754. if is_first_iteration or not use_cache:
  755. model_inputs["pixel_values"] = pixel_values
  756. return model_inputs
  757. class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
  758. def __init__(self, config):
  759. super().__init__(config)
  760. self.num_labels = config.num_labels
  761. self.model = Gemma3Model(config)
  762. self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
  763. # Initialize weights and apply final processing
  764. self.post_init()
  765. def get_input_embeddings(self):
  766. return self.model.get_input_embeddings()
  767. def set_input_embeddings(self, value):
  768. self.model.set_input_embeddings(value)
  769. @can_return_tuple
  770. @auto_docstring
  771. def forward(
  772. self,
  773. input_ids: torch.LongTensor | None = None,
  774. pixel_values: torch.FloatTensor | None = None,
  775. attention_mask: torch.Tensor | None = None,
  776. position_ids: torch.LongTensor | None = None,
  777. past_key_values: Cache | None = None,
  778. inputs_embeds: torch.FloatTensor | None = None,
  779. token_type_ids: torch.LongTensor | None = None,
  780. labels: torch.LongTensor | None = None,
  781. use_cache: bool | None = None,
  782. **kwargs: Unpack[TransformersKwargs],
  783. ) -> SequenceClassifierOutputWithPast:
  784. r"""
  785. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  786. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  787. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  788. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  789. """
  790. transformer_outputs = self.model(
  791. input_ids,
  792. attention_mask=attention_mask,
  793. pixel_values=pixel_values,
  794. position_ids=position_ids,
  795. past_key_values=past_key_values,
  796. inputs_embeds=inputs_embeds,
  797. token_type_ids=token_type_ids,
  798. use_cache=use_cache,
  799. **kwargs,
  800. )
  801. hidden_states = transformer_outputs.last_hidden_state
  802. logits = self.score(hidden_states)
  803. if input_ids is not None:
  804. batch_size = input_ids.shape[0]
  805. else:
  806. batch_size = inputs_embeds.shape[0]
  807. if self.config.text_config.pad_token_id is None and batch_size != 1:
  808. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  809. if self.config.text_config.pad_token_id is None:
  810. last_non_pad_token = -1
  811. elif input_ids is not None:
  812. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  813. non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
  814. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  815. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  816. else:
  817. last_non_pad_token = -1
  818. logger.warning_once(
  819. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  820. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  821. )
  822. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  823. loss = None
  824. if labels is not None:
  825. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  826. return SequenceClassifierOutputWithPast(
  827. loss=loss,
  828. logits=pooled_logits,
  829. past_key_values=transformer_outputs.past_key_values,
  830. hidden_states=transformer_outputs.hidden_states,
  831. attentions=transformer_outputs.attentions,
  832. )
  833. class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
  834. """
  835. Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
  836. It uses the generic sequence classification implementation for efficiency and consistency.
  837. """
  838. config: Gemma3TextConfig
  839. input_modalities = ("text",)
# Public names exported by this module.
# NOTE(review): in a `modular_*.py` file this list is presumably read statically by the modular converter,
# so keep it a plain list literal — confirm before restructuring.
__all__ = [
    "Gemma3Config",
    "Gemma3TextConfig",
    "Gemma3PreTrainedModel",
    "Gemma3TextModel",
    "Gemma3ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3Model",
    "Gemma3ForSequenceClassification",
    "Gemma3TextForSequenceClassification",
]