| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/gemma4/modular_gemma4.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_gemma4.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2026 the HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- from functools import cached_property
- from typing import Optional
- import torch
- from torch import nn
- from torch.nn import functional as F
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...configuration_utils import PreTrainedConfig
- from ...generation import GenerationMixin
- from ...integrations import use_experts_implementation, use_kernelized_func
- from ...masking_utils import (
- create_bidirectional_mask,
- create_causal_mask,
- create_masks_for_generate,
- create_sliding_window_causal_mask,
- )
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
- from ...utils.generic import maybe_autocast, merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from ..auto.modeling_auto import AutoModel
- from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
- @dataclass
- @auto_docstring(
- custom_intro="""
- Base class for Gemma4 outputs, with hidden states and attentions.
- """
- )
- class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
- r"""
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
- image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
- audio_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
- audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
- """
- image_hidden_states: torch.FloatTensor | None = None
- audio_hidden_states: torch.FloatTensor | None = None
- @dataclass
- @auto_docstring(
- custom_intro="""
- Base class for Gemma4 causal language model (or autoregressive) outputs.
- """
- )
- class Gemma4CausalLMOutputWithPast(ModelOutput):
- r"""
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
- image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
- audio_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
- audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
- """
- loss: torch.FloatTensor | None = None
- logits: torch.FloatTensor | None = None
- past_key_values: Cache | None = None
- hidden_states: tuple[torch.FloatTensor] | None = None
- attentions: tuple[torch.FloatTensor] | None = None
- image_hidden_states: torch.FloatTensor | None = None
- audio_hidden_states: torch.FloatTensor | None = None
- @dataclass
- @auto_docstring
- class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
- r"""
- attention_mask (`torch.BoolTensor`, *optional*):
- A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
- """
- attention_mask: torch.BoolTensor | None = None
- class Gemma4ClippableLinear(nn.Module):
- def __init__(
- self,
- config: Gemma4VisionConfig | Gemma4AudioConfig,
- in_features: int,
- out_features: int,
- ) -> None:
- super().__init__()
- self.use_clipped_linears = config.use_clipped_linears
- self.linear = nn.Linear(in_features, out_features, bias=False)
- if self.use_clipped_linears:
- self.register_buffer("input_min", torch.tensor(-float("inf")))
- self.register_buffer("input_max", torch.tensor(float("inf")))
- self.register_buffer("output_min", torch.tensor(-float("inf")))
- self.register_buffer("output_max", torch.tensor(float("inf")))
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- if self.use_clipped_linears:
- hidden_states = torch.clamp(hidden_states, self.input_min, self.input_max)
- hidden_states = self.linear(hidden_states)
- if self.use_clipped_linears:
- hidden_states = torch.clamp(hidden_states, self.output_min, self.output_max)
- return hidden_states
- class Gemma4RMSNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
- super().__init__()
- self.eps = eps
- self.with_scale = with_scale
- if self.with_scale:
- self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)
- def _norm(self, hidden_states: torch.Tensor):
- mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
- # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX
- return hidden_states * torch.pow(mean_squared, -0.5)
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- normed_output = self._norm(hidden_states.float())
- if self.with_scale:
- normed_output = normed_output * self.weight.float()
- return normed_output.type_as(hidden_states)
- class Gemma4AudioRelPositionalEncoding(nn.Module):
- """Sinusoidal relative positional encoding for the audio encoder.
- Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
- concatenated [sin..., cos...] layout matching the original Gemma4 convention.
- """
- inv_timescales: torch.Tensor
- def __init__(self, config: Gemma4AudioConfig):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.context_size = (
- config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
- )
- min_timescale = 1.0
- max_timescale = 10000.0
- num_timescales = self.hidden_size // 2
- log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
- inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
- self.register_buffer("inv_timescales", inv_timescales.unsqueeze(0).unsqueeze(0), persistent=False)
- @torch.no_grad()
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- position_ids = torch.arange(12, -1, -1, device=hidden_states.device)
- position_ids = position_ids[..., None]
- scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device)
- pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
- return pos_embed.to(dtype=hidden_states.dtype)
- class Gemma4AudioAttention(nn.Module):
- """Chunked local attention with relative position bias"""
- def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- self.attention_logits_soft_cap = config.attention_logit_cap
- self.head_dim = config.hidden_size // config.num_attention_heads
- self.num_heads = config.num_attention_heads
- self.q_scale = (self.head_dim**-0.5) / math.log(2)
- self.k_scale = math.log(1 + math.e) / math.log(2)
- self.chunk_size = config.attention_chunk_size
- self.max_past_horizon = config.attention_context_left - 1
- self.max_future_horizon = config.attention_context_right
- self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
- self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
- self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
- self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
- self.post = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
- self.relative_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
- self.per_dim_scale = nn.Parameter(torch.zeros(self.head_dim))
- self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap), persistent=False)
- def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """Splits a `(batch_size, seq_len, num_heads, head_dim)` tensor into non-overlapping blocks of `chunk_size` along the sequence dim."""
- batch_size, seq_len, num_heads, head_dim = hidden_states.shape
- num_blocks = (seq_len + self.chunk_size - 1) // self.chunk_size
- pad = num_blocks * self.chunk_size - seq_len
- hidden_states = F.pad(hidden_states, (0, 0, 0, 0, 0, pad))
- return hidden_states.reshape(batch_size, num_blocks, self.chunk_size, num_heads, head_dim).contiguous()
- def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """Extracts overlapping context windows of `context_size` for every block, strided by `chunk_size`."""
- batch_size, seq_len, num_heads, head_dim = hidden_states.shape
- hidden_states = F.pad(
- hidden_states, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)
- )
- hidden_states = hidden_states.unfold(1, self.context_size, self.chunk_size)
- hidden_states = torch.movedim(hidden_states, -1, 2)
- return hidden_states.contiguous()
- def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
- """Relative position shift for blocked attention. See appendix B of https://huggingface.co/papers/1901.02860."""
- batch_size, num_heads, num_blocks, block_size, position_length = x.shape
- context_size = self.context_size
- x = F.pad(x, (0, context_size + 1 - position_length))
- x = x.view(batch_size, num_heads, num_blocks, block_size * (context_size + 1))
- x = x[..., : block_size * context_size]
- return x.view(batch_size, num_heads, num_blocks, block_size, context_size)
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: torch.Tensor,
- attention_mask: torch.BoolTensor | None = None,
- ) -> tuple[torch.Tensor, None]:
- batch_size, seq_length, _ = hidden_states.shape
- hidden_shape = (batch_size, seq_length, self.num_heads, self.head_dim)
- query_states = self.q_proj(hidden_states).float().view(hidden_shape)
- key_states = self.k_proj(hidden_states).float().view(hidden_shape)
- value_states = self.v_proj(hidden_states).float().view(hidden_shape)
- query_states = query_states * self.q_scale * F.softplus(self.per_dim_scale)
- key_states = key_states * self.k_scale
- query_states = self._convert_to_block(query_states)
- key_states = self._extract_block_context(key_states)
- value_states = self._extract_block_context(value_states)
- num_blocks = query_states.shape[1]
- relative_key_states = self.relative_k_proj(position_embeddings)
- relative_key_states = relative_key_states.view(-1, self.num_heads, self.head_dim)
- relative_key_states = relative_key_states.to(dtype=query_states.dtype)
- queries = query_states.permute(0, 3, 1, 2, 4)
- matrix_ac = queries @ key_states.permute(0, 3, 1, 4, 2)
- queries_flat = queries.reshape(batch_size, self.num_heads, -1, self.head_dim)
- matrix_bd = queries_flat @ relative_key_states.permute(1, 2, 0)
- matrix_bd = matrix_bd.reshape(batch_size, self.num_heads, num_blocks, self.chunk_size, -1)
- matrix_bd = self._rel_shift(matrix_bd)
- attn_weights = matrix_ac + matrix_bd
- attn_weights = attn_weights / self.softcap
- attn_weights = torch.tanh(attn_weights)
- attn_weights = attn_weights * self.softcap
- if attention_mask is not None:
- attn_weights = attn_weights.masked_fill(
- attention_mask.logical_not(), self.config.attention_invalid_logits_value
- )
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
- attn_output = attn_weights @ value_states.permute(0, 3, 1, 2, 4)
- attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, num_blocks * self.chunk_size, -1)
- attn_output = attn_output[:, :seq_length].contiguous()
- attn_output = self.post(attn_output.to(dtype=self.post.linear.weight.dtype))
- return attn_output, attn_weights
- class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
- def __init__(self, in_channels, out_channels, norm_eps):
- super().__init__()
- self.conv = nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=(3, 3),
- stride=(2, 2),
- padding=1,
- bias=False,
- )
- self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
- self.act = nn.ReLU()
- def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None):
- if mask is not None:
- mask = mask.to(device=hidden_states.device)
- hidden_states = hidden_states * mask[:, None, :, None]
- hidden_states = self.conv(hidden_states.to(self.conv.weight.dtype))
- hidden_states = self.act(self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous())
- if mask is not None:
- mask = mask[:, ::2]
- return hidden_states, mask
- class Gemma4AudioSubSampleConvProjection(nn.Module):
- def __init__(self, config: Gemma4AudioConfig):
- super().__init__()
- self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
- in_channels=1,
- out_channels=config.subsampling_conv_channels[0],
- norm_eps=config.rms_norm_eps,
- )
- self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
- in_channels=config.subsampling_conv_channels[0],
- out_channels=config.subsampling_conv_channels[1],
- norm_eps=config.rms_norm_eps,
- )
- proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
- self.input_proj_linear = nn.Linear(proj_input_dim, config.hidden_size, bias=False)
- def forward(
- self,
- input_features: torch.Tensor,
- input_features_mask: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- hidden_states = input_features.unsqueeze(1)
- hidden_states, mask = self.layer0(hidden_states, input_features_mask)
- hidden_states, mask = self.layer1(hidden_states, mask)
- batch_size, _, seq_len, _ = hidden_states.shape
- hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1)
- return self.input_proj_linear(hidden_states), mask
- class Gemma4AudioFeedForward(nn.Module):
- def __init__(self, config: Gemma4AudioConfig):
- super().__init__()
- self.config = config
- self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
- self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
- self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
- self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
- self.act_fn = ACT2FN[config.hidden_act]
- self.gradient_clipping = config.gradient_clipping
- self.post_layer_scale = config.residual_weight
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- # This is needed to avoid any underflow/overflow issues when clipping
- gradient_clipping = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max)
- residual = hidden_states
- hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
- hidden_states = self.pre_layer_norm(hidden_states)
- hidden_states = self.ffw_layer_1(hidden_states)
- hidden_states = self.act_fn(hidden_states)
- hidden_states = self.ffw_layer_2(hidden_states)
- hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
- hidden_states = self.post_layer_norm(hidden_states)
- hidden_states *= self.post_layer_scale
- hidden_states += residual
- return hidden_states
- # TODO: this could be imported from Voxtral realtime
- class Gemma4AudioCausalConv1d(nn.Conv1d):
- # def __init__(
- # self,
- # in_channels: int,
- # out_channels: int,
- # kernel_size: int,
- # # cache_key: str,
- # stride: int = 1,
- # dilation: int = 1,
- # bias: bool = True,
- # ):
- # super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
- # self.cache_key = cache_key
- @cached_property
- def left_pad(self):
- effective_kernel_size = (self.kernel_size[0] - 1) * self.dilation[0] + 1
- return effective_kernel_size - self.stride[0]
- def forward(
- self,
- x: torch.Tensor,
- # padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, # TODO: we might want to add a cache?
- ) -> torch.Tensor:
- # if padding_cache is not None:
- # x = padding_cache.update(x, self.cache_key, self)
- # else:
- # x = nn.functional.pad(x, (self.left_pad, 0))
- x = nn.functional.pad(x, (self.left_pad, 0))
- return super().forward(x)
- class Gemma4AudioLightConv1d(nn.Module):
- def __init__(self, config: Gemma4AudioConfig):
- super().__init__()
- self.config = config
- self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
- self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
- self.depthwise_conv1d = Gemma4AudioCausalConv1d(
- in_channels=config.hidden_size,
- out_channels=config.hidden_size,
- kernel_size=config.conv_kernel_size,
- groups=config.hidden_size,
- bias=False,
- )
- self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True)
- self.conv_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True)
- self.act_fn = ACT2FN[config.hidden_act]
- self.gradient_clipping = config.gradient_clipping
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- residual = hidden_states
- hidden_states = self.pre_layer_norm(hidden_states)
- hidden_states = self.linear_start(hidden_states)
- hidden_states = nn.functional.glu(hidden_states, dim=-1)
- hidden_states = self.depthwise_conv1d(hidden_states.transpose(1, 2)).transpose(1, 2)
- # This is needed to avoid any underflow/overflow issues when clipping
- gradient_clipping = min(self.gradient_clipping, torch.finfo(self.linear_start.linear.weight.dtype).max)
- hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
- hidden_states = self.conv_norm(hidden_states)
- hidden_states = self.act_fn(hidden_states)
- hidden_states = self.linear_end(hidden_states)
- hidden_states += residual
- return hidden_states
- class Gemma4AudioLayer(nn.Module):
- def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
- super().__init__()
- self.config = config
- self.feed_forward1 = Gemma4AudioFeedForward(config)
- self.feed_forward2 = Gemma4AudioFeedForward(config)
- self.self_attn = Gemma4AudioAttention(config, layer_idx)
- self.lconv1d = Gemma4AudioLightConv1d(config)
- self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
- self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
- self.norm_out = Gemma4RMSNorm(config.hidden_size)
- self.gradient_clipping = config.gradient_clipping
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.BoolTensor | None,
- position_embeddings: torch.Tensor,
- **kwargs: Unpack[TransformersKwargs],
- ) -> torch.Tensor:
- # This is needed to avoid any underflow/overflow issues when clipping
- gradient_clipping = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max)
- hidden_states = self.feed_forward1(hidden_states)
- residual = hidden_states
- hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
- hidden_states = self.norm_pre_attn(hidden_states)
- hidden_states, _ = self.self_attn(
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- )
- hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
- hidden_states = self.norm_post_attn(hidden_states)
- hidden_states += residual
- hidden_states = self.lconv1d(hidden_states)
- hidden_states = self.feed_forward2(hidden_states)
- hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
- hidden_states = self.norm_out(hidden_states)
- return hidden_states
- # ---- Vision Encoder Layers ----
- class Gemma4VisionPatchEmbedder(nn.Module):
- def __init__(self, config: Gemma4VisionConfig):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.patch_size = config.patch_size
- self.position_embedding_size = config.position_embedding_size
- self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
- self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
- def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
- """Prepare patch positions map for matmul with positon embedding table."""
- # Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
- clamped_positions = pixel_position_ids.clamp(min=0)
- one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
- one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
- # Compute positional embeddings and sum across x and y.
- position_embeddings = one_hot @ self.position_embedding_table
- position_embeddings = position_embeddings.sum(dim=1)
- # Zero out embeddings for any padding patches.
- position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
- return position_embeddings
- def forward(
- self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor
- ) -> torch.Tensor:
- # Gemma4 applies no normalization and instead scales in model code
- pixel_values = 2 * (pixel_values - 0.5)
- hidden_states = self.input_proj(pixel_values.to(self.input_proj.weight.dtype))
- position_embeddings = self._position_embeddings(pixel_position_ids, padding_positions)
- return hidden_states + position_embeddings
- class Gemma4VisionPooler(nn.Module):
- """Scaling and optional spatial pooling for vision encodings"""
- def __init__(self, config: Gemma4VisionConfig):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.root_hidden_size = self.hidden_size**0.5
- def _avg_pool_by_positions(
- self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- 2D spatial pooling according to patch positions.
- Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
- input and output lengths
- """
- input_seq_len = hidden_states.shape[1]
- k = int((input_seq_len // length) ** 0.5)
- k_squared = k**2
- if k_squared * length != input_seq_len:
- raise ValueError(
- f"Cannot pool {hidden_states.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
- )
- # Clamp padding positions (which are -1) to 0 so they don't break one_hot.
- # Padding patches have zero hidden states so they contribute nothing to the average.
- clamped_positions = pixel_position_ids.clamp(min=0)
- max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
- kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
- kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
- weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
- output = weights.transpose(1, 2) @ hidden_states.float()
- mask = torch.logical_not((weights == 0).all(dim=1))
- return output.to(hidden_states.dtype), mask
- def forward(
- self,
- hidden_states: torch.Tensor,
- pixel_position_ids: torch.Tensor,
- padding_positions: torch.Tensor,
- output_length: int | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- if output_length > hidden_states.shape[1]:
- raise ValueError(
- f"Cannot output more soft tokens (requested {output_length}) than there are patches"
- f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing."
- )
- hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0)
- if hidden_states.shape[1] != output_length:
- hidden_states, padding_positions = self._avg_pool_by_positions(
- hidden_states, pixel_position_ids, output_length
- )
- hidden_states *= self.root_hidden_size
- return hidden_states, padding_positions
- class Gemma4VisionMLP(nn.Module):
- def __init__(self, config: Gemma4VisionConfig):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
- self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
- self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
- self.act_fn = ACT2FN[config.hidden_activation]
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
- class Gemma4VisionRotaryEmbedding(nn.Module):
- inv_freq: torch.Tensor # fix linting for `register_buffer`
- def __init__(self, config: Gemma4VisionConfig, device=None):
- super().__init__()
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
- self.config = config
- self.rope_type = self.config.rope_parameters["rope_type"]
- rope_init_fn: Callable = self.compute_default_rope_parameters
- if self.rope_type != "default":
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
- inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
- @staticmethod
- def compute_default_rope_parameters(
- config: Gemma4VisionConfig | None = None,
- device: torch.device | None = None,
- seq_len: int | None = None,
- ) -> tuple["torch.Tensor", float]:
- """
- Computes the inverse frequencies according to the original RoPE implementation
- Args:
- config ([`~transformers.PreTrainedConfig`]):
- The model configuration.
- device (`torch.device`):
- The device to use for initialization of the inverse frequencies.
- seq_len (`int`, *optional*):
- The current sequence length. Unused for this type of RoPE.
- Returns:
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
- """
- base = config.rope_parameters["rope_theta"]
- dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
- # The reference implementation computes RoPE frequencies INDEPENDENTLY
- # for each spatial dimension using the partitioned head_dim (head_dim // ndim),
- # so both x and y dimensions get identical frequency ranges.
- # This is different from splitting the global inv_freq between dimensions.
- spatial_dim = dim // 2
- attention_factor = 1.0 # Unused in this type of RoPE
- inv_freq = 1.0 / (
- base
- ** (torch.arange(0, spatial_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / spatial_dim)
- )
- return inv_freq, attention_factor
- @torch.no_grad()
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
- def forward(self, x, position_ids):
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- # Multidimensional positions: [batch, num_patches, ndim]. Apply rotations to each spatial dim separately
- all_cos, all_sin = [], []
- for i in range(2):
- dim_position_ids = position_ids[:, :, i]
- dim_position_ids_expanded = dim_position_ids[:, None, :].float()
- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
- freqs = (inv_freq_expanded.float() @ dim_position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos() * self.attention_scaling
- sin = emb.sin() * self.attention_scaling
- all_cos.append(cos)
- all_sin.append(sin)
- cos = torch.cat(all_cos, dim=-1).to(dtype=x.dtype)
- sin = torch.cat(all_sin, dim=-1).to(dtype=x.dtype)
- return cos, sin
- def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
- def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
- """Applies Rotary Position Embedding to the query and key tensors.
- Args:
- x (`torch.Tensor`): The tensor to embed.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- return (x * cos) + (rotate_half(x) * sin)
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
- def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- dropout: float | int = 0.0,
- scaling: float | None = None,
- softcap: float | None = None,
- **kwargs,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- if scaling is None:
- scaling = module.head_dim**-0.5
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if softcap is not None:
- attn_weights = attn_weights / softcap
- attn_weights = torch.tanh(attn_weights)
- attn_weights = attn_weights * softcap
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
- def apply_multidimensional_rope(
- x: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- position_ids: torch.Tensor,
- unsqueeze_dim: int = 2,
- ) -> torch.Tensor:
- """Applies multidimensional RoPE to inputs.
- Args:
- x (`torch.Tensor`): The tensor to embed.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- If position_ids.ndim + 2 == x.ndim, then this function passes through to `apply_rotary_pos_emb()`.
- Otherwise, position_ids is used to split the inputs, x, into multiple pieces, where each piece is fed to
- `apply_rotary_pos_emb()`, and then concatenated back together.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- Tensor of shape [B, L, N, H] with RoPE applied.
- """
- ndim = position_ids.shape[-1]
- num_input_channels = x.shape[-1]
- num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim))
- if num_rotated_channels_per_dim <= 0:
- raise ValueError(
- "Invalid configuration: num_rotated_channels_per_dim must be > 0, got"
- f" {num_rotated_channels_per_dim} (num_input_channels={num_input_channels},"
- f" ndim={ndim})"
- )
- # Correctly split the input tensor into ndim parts
- split_sizes = [num_rotated_channels_per_dim] * ndim
- x_parts = torch.split(x, split_sizes, dim=-1)
- cos_parts = torch.split(cos, split_sizes, dim=-1)
- sin_parts = torch.split(sin, split_sizes, dim=-1)
- y_parts = [
- apply_rotary_pos_emb(
- x=x_parts[k],
- cos=cos_parts[k],
- sin=sin_parts[k],
- unsqueeze_dim=unsqueeze_dim,
- )
- for k in range(ndim)
- ]
- return torch.cat(y_parts, dim=-1)
- @use_kernelized_func(apply_rotary_pos_emb)
- class Gemma4VisionAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
- super().__init__()
- self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
- self.config = config
- self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = 1.0
- self.attention_dropout = self.config.attention_dropout
- self.is_causal = False
- self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
- self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
- self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
- self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
- self.q_norm = Gemma4RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
- self.k_norm = Gemma4RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
- self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: torch.Tensor = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
- cos, sin = position_embeddings
- query_states = self.q_proj(hidden_states).view(hidden_shape)
- query_states = self.q_norm(query_states)
- query_states = apply_multidimensional_rope(query_states, cos, sin, position_ids)
- query_states = query_states.transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape)
- key_states = self.k_norm(key_states)
- key_states = apply_multidimensional_rope(key_states, cos, sin, position_ids)
- key_states = key_states.transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape)
- value_states = self.v_norm(value_states)
- value_states = value_states.transpose(1, 2)
- attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
- self.config._attn_implementation, eager_attention_forward
- )
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=self.attention_dropout if self.training else 0.0,
- scaling=self.scaling,
- **kwargs,
- )
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
- class Gemma4VisionEncoderLayer(GradientCheckpointingLayer):
- def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.layer_idx = layer_idx
- self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
- self.mlp = Gemma4VisionMLP(config)
- self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: torch.Tensor = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- hidden_states, _ = self.self_attn(
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- position_ids=position_ids,
- **kwargs,
- )
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = residual + hidden_states
- residual = hidden_states
- hidden_states = self.pre_feedforward_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = self.post_feedforward_layernorm(hidden_states)
- hidden_states = residual + hidden_states
- return hidden_states
- class Gemma4VisionEncoder(nn.Module):
- def __init__(self, config: Gemma4VisionConfig):
- super().__init__()
- self.config = config
- self.num_layers = config.num_hidden_layers
- self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
- self.layers = nn.ModuleList(
- [Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
- )
- def forward(
- self,
- inputs_embeds: torch.Tensor,
- attention_mask: torch.Tensor,
- pixel_position_ids: torch.LongTensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> BaseModelOutputWithPast:
- r"""
- pixel_position_ids (torch.Tensor):
- Patch positions as (x, y) coordinates in the image as [batch, num_patches, 2].
- """
- attention_mask = create_bidirectional_mask(
- config=self.config,
- inputs_embeds=inputs_embeds,
- attention_mask=attention_mask,
- )
- # embed positions
- hidden_states = inputs_embeds
- position_embeddings = self.rotary_emb(hidden_states, pixel_position_ids)
- # decoder layers
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
- hidden_states = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- position_embeddings=position_embeddings,
- position_ids=pixel_position_ids,
- **kwargs,
- )
- return BaseModelOutputWithPast(last_hidden_state=hidden_states)
- class Gemma4TextMLP(nn.Module):
- def __init__(self, config: Gemma4TextConfig, layer_idx: int):
- super().__init__()
- first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
- is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
- use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_activation]
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
- class Gemma4TextRotaryEmbedding(nn.Module):
- inv_freq: torch.Tensor # fix linting for `register_buffer`
- def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
- super().__init__()
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
- self.config = config
- self.layer_types = set(config.layer_types)
- self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
- self.rope_type: dict[str, str] = {}
- for layer_type in self.layer_types:
- rope_params = self.config.rope_parameters[layer_type]
- if rope_params is None:
- continue
- if (rope_type := rope_params["rope_type"]) != "default":
- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
- else:
- rope_init_fn = self.compute_default_rope_parameters
- self.rope_init_fns[layer_type] = rope_init_fn
- self.rope_type[layer_type] = rope_type
- rope_init_fn_kwargs = {"device": device, "layer_type": layer_type}
- if layer_type == "full_attention" and rope_type == "proportional":
- rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"
- curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, **rope_init_fn_kwargs)
- self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
- self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
- setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
- @staticmethod
- def compute_default_rope_parameters(
- config: Gemma4TextConfig | None = None,
- device: Optional["torch.device"] = None,
- seq_len: int | None = None,
- layer_type: str | None = None,
- ) -> tuple["torch.Tensor", float]:
- """
- Computes the inverse frequencies according to the original RoPE implementation
- Args:
- config ([`~transformers.PreTrainedConfig`]):
- The model configuration.
- device (`torch.device`):
- The device to use for initialization of the inverse frequencies.
- seq_len (`int`, *optional*):
- The current sequence length. Unused for this type of RoPE.
- layer_type (`str`, *optional*):
- The current layer type if the model has different RoPE parameters per type.
- Should not be used unless `config.layer_types is not None`
- Returns:
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
- """
- # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
- base = config.rope_parameters[layer_type]["rope_theta"]
- dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
- attention_factor = 1.0 # Unused in this type of RoPE
- # Compute the inverse frequencies
- inv_freq = 1.0 / (
- base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
- )
- return inv_freq, attention_factor
- @torch.no_grad()
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
- def forward(self, x, position_ids, layer_type=None):
- inv_freq = getattr(self, f"{layer_type}_inv_freq")
- attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
- inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
- position_ids_expanded = position_ids[:, None, :].float()
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos() * attention_scaling
- sin = emb.sin() * attention_scaling
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
- @use_kernelized_func(apply_rotary_pos_emb)
- class Gemma4TextAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: Gemma4TextConfig, layer_idx: int):
- super().__init__()
- self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
- self.config = config
- self.layer_idx = layer_idx
- self.is_sliding = self.layer_type == "sliding_attention"
- self.sliding_window = config.sliding_window if self.is_sliding else None
- self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
- self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
- num_key_value_heads = (
- config.num_global_key_value_heads if self.use_alternative_attention else config.num_key_value_heads
- )
- self.num_key_value_groups = config.num_attention_heads // num_key_value_heads
- self.scaling = 1.0
- self.attention_dropout = self.config.attention_dropout
- self.is_causal = config.use_bidirectional_attention != "all"
- # Shared kv cache
- first_kv_shared_layer_idx = self.config.num_hidden_layers - getattr(self.config, "num_kv_shared_layers", 0)
- self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
- prev_layers = config.layer_types[:first_kv_shared_layer_idx]
- if self.is_kv_shared_layer:
- # For shared layers, find the last non-shared layer of the same type before sharing starts
- self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
- self.store_full_length_kv = False
- else:
- self.kv_shared_layer_index = None
- # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
- self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
- config.layer_types[layer_idx]
- )
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
- # Layers sharing kv states don't need any weight matrices
- if not self.is_kv_shared_layer:
- self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
- self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
- self.k_proj = nn.Linear(
- config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = (
- nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
- if not self.use_alternative_attention
- else None
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
- )
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: torch.Tensor,
- attention_mask: torch.Tensor | None,
- shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
- past_key_values: Cache | None = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
- cos, sin = position_embeddings
- query_states = self.q_proj(hidden_states).view(hidden_shape)
- query_states = self.q_norm(query_states)
- query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
- query_states = query_states.transpose(1, 2)
- # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
- # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
- # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
- if self.is_kv_shared_layer:
- key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
- # Device of past layer may be different from current one
- key_states = key_states.to(query_states.device)
- value_states = value_states.to(query_states.device)
- else:
- key_states = self.k_proj(hidden_states).view(hidden_shape)
- value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
- key_states = self.k_norm(key_states)
- key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
- key_states = key_states.transpose(1, 2)
- value_states = self.v_norm(value_states)
- value_states = value_states.transpose(1, 2)
- if past_key_values is not None and not self.is_kv_shared_layer:
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
- if self.store_full_length_kv:
- shared_kv_states[self.layer_idx] = key_states, value_states
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=self.attention_dropout if self.training else 0.0,
- scaling=self.scaling,
- sliding_window=self.sliding_window,
- **kwargs,
- )
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
- @use_experts_implementation
- class Gemma4TextExperts(nn.Module):
- """Collection of expert weights stored as 3D tensors."""
- def __init__(self, config: Gemma4TextConfig):
- super().__init__()
- self.num_experts = config.num_experts
- self.hidden_dim = config.hidden_size
- self.intermediate_dim = config.moe_intermediate_size
- self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
- self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
- self.act_fn = ACT2FN[config.hidden_activation]
- def forward(
- self,
- hidden_states: torch.Tensor,
- top_k_index: torch.Tensor,
- top_k_weights: torch.Tensor,
- ) -> torch.Tensor:
- final_hidden_states = torch.zeros_like(hidden_states)
- with torch.no_grad():
- expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
- expert_mask = expert_mask.permute(2, 1, 0)
- expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
- for expert_idx in expert_hit:
- expert_idx = expert_idx[0]
- if expert_idx == self.num_experts:
- continue
- top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
- current_state = hidden_states[token_idx]
- gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
- current_hidden_states = self.act_fn(gate) * up
- current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
- current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
- final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
- return final_hidden_states
- class Gemma4TextRouter(nn.Module):
- def __init__(self, config: Gemma4TextConfig):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.scalar_root_size = self.hidden_size**-0.5
- self.eps = config.rms_norm_eps
- self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
- self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
- self.scale = nn.Parameter(torch.ones(self.hidden_size))
- self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
- def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
- hidden_states = self.norm(hidden_states)
- hidden_states = hidden_states * self.scale * self.scalar_root_size
- expert_scores = self.proj(hidden_states) # [B*S, E]
- router_probabilities = nn.functional.softmax(expert_scores, dim=-1)
- # topk returns both values (probabilities) and indices directly
- top_k_weights, top_k_index = torch.topk(
- router_probabilities,
- k=self.config.top_k_experts,
- dim=-1,
- ) # both [B*S, K]
- # Normalize the top-k weights so they sum to 1 per token
- top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
- # Apply per-expert scale directly to the weights
- top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]
- return router_probabilities, top_k_weights, top_k_index
- class Gemma4TextDecoderLayer(GradientCheckpointingLayer):
- def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.layer_idx = layer_idx
- self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
- self.mlp = Gemma4TextMLP(config, layer_idx)
- self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.register_buffer("layer_scalar", torch.ones(1))
- self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
- if self.hidden_size_per_layer_input:
- self.act_fn = ACT2FN[config.hidden_activation]
- self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
- self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
- self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.enable_moe_block = config.enable_moe_block
- if self.enable_moe_block:
- self.router = Gemma4TextRouter(config)
- self.experts = Gemma4TextExperts(config)
- self.post_feedforward_layernorm_1 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.post_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- self.pre_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
- def forward(
- self,
- hidden_states: torch.Tensor,
- per_layer_input: torch.Tensor = None,
- shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
- position_embeddings: torch.Tensor = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- **kwargs,
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- hidden_states, _ = self.self_attn(
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- shared_kv_states=shared_kv_states,
- position_ids=position_ids,
- past_key_values=past_key_values,
- **kwargs,
- )
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = residual + hidden_states
- residual = hidden_states
- hidden_states = self.pre_feedforward_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- if self.enable_moe_block:
- hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)
- # Take hidden states before MLP here
- hidden_states_flat = residual.reshape(-1, residual.shape[-1])
- _, top_k_weights, top_k_index = self.router(hidden_states_flat)
- hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states_flat)
- hidden_states_2 = self.experts(hidden_states_2, top_k_index, top_k_weights)
- hidden_states_2 = hidden_states_2.reshape(residual.shape)
- hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)
- # Combine mlp and moe outputs
- hidden_states = hidden_states_1 + hidden_states_2
- hidden_states = self.post_feedforward_layernorm(hidden_states)
- hidden_states = residual + hidden_states
- if self.hidden_size_per_layer_input:
- residual = hidden_states
- hidden_states = self.per_layer_input_gate(hidden_states)
- hidden_states = self.act_fn(hidden_states)
- hidden_states = hidden_states * per_layer_input
- hidden_states = self.per_layer_projection(hidden_states)
- hidden_states = self.post_per_layer_input_norm(hidden_states)
- hidden_states = residual + hidden_states
- hidden_states *= self.layer_scalar
- return hidden_states
- class Gemma4TextScaledWordEmbedding(nn.Embedding):
- """
- This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
- """
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
- super().__init__(num_embeddings, embedding_dim, padding_idx)
- self.scalar_embed_scale = embed_scale
- self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
- def forward(self, input_ids: torch.Tensor):
- return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
- # ---- Model Classes ----
- class Gemma4PreTrainedModel(PreTrainedModel):
- config: Gemma4Config
- supports_gradient_checkpointing = True
- _supports_flash_attn = True
- _supports_sdpa = True
- _supports_flex_attn = True
- _can_compile_fullgraph = True
- _supports_attention_backend = True
- _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
- _skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
- input_modalities = ("image", "text", "video", "audio")
- @torch.no_grad()
- def _init_weights(self, module):
- super()._init_weights(module)
- if isinstance(module, Gemma4VisionPatchEmbedder):
- init.ones_(module.position_embedding_table)
- elif isinstance(module, Gemma4AudioRelPositionalEncoding):
- min_timescale = 1.0
- max_timescale = 10000.0
- num_timescales = module.hidden_size // 2
- log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
- inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
- init.copy_(module.inv_timescales, inv_timescales.unsqueeze(0).unsqueeze(0))
- elif isinstance(module, Gemma4AudioAttention):
- init.constant_(module.softcap, module.attention_logits_soft_cap)
- init.zeros_(module.per_dim_scale)
- elif isinstance(module, Gemma4TextRotaryEmbedding):
- for layer_type, rope_init_fn in module.rope_init_fns.items():
- rope_init_fn_kwargs = {"layer_type": layer_type}
- if layer_type == "full_attention" and module.rope_type[layer_type] == "proportional":
- rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"
- curr_inv_freq, _ = rope_init_fn(module.config, **rope_init_fn_kwargs)
- init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
- init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
- elif isinstance(module, Gemma4VisionRotaryEmbedding):
- rope_fn = (
- ROPE_INIT_FUNCTIONS[module.rope_type]
- if module.rope_type != "default"
- else module.compute_default_rope_parameters
- )
- buffer_value, _ = rope_fn(module.config)
- init.copy_(module.inv_freq, buffer_value)
- init.copy_(module.original_inv_freq, buffer_value)
- elif isinstance(module, Gemma4TextScaledWordEmbedding):
- init.constant_(module.embed_scale, module.scalar_embed_scale)
- elif isinstance(module, Gemma4TextRouter):
- init.ones_(module.scale)
- init.ones_(module.per_expert_scale)
- elif isinstance(module, Gemma4TextExperts):
- std = self.config.initializer_range
- init.normal_(module.gate_up_proj, mean=0.0, std=std)
- init.normal_(module.down_proj, mean=0.0, std=std)
- elif isinstance(module, Gemma4TextDecoderLayer):
- init.ones_(module.layer_scalar)
- elif isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears:
- init.constant_(module.input_min, -float("inf"))
- init.constant_(module.input_max, float("inf"))
- init.constant_(module.output_min, -float("inf"))
- init.constant_(module.output_max, float("inf"))
- elif isinstance(module, Gemma4VisionModel) and module.config.standardize:
- init.zeros_(module.std_bias)
- init.ones_(module.std_scale)
- @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
- class Gemma4TextModel(Gemma4PreTrainedModel):
- config: Gemma4TextConfig
- input_modalities = ("text",)
- _can_record_outputs = {
- "router_logits": OutputRecorder(Gemma4TextRouter, index=0),
- "hidden_states": Gemma4TextDecoderLayer,
- "attentions": Gemma4TextAttention,
- }
- def __init__(self, config: Gemma4TextConfig):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- # Gemma4 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
- self.embed_tokens = Gemma4TextScaledWordEmbedding(
- config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
- )
- self.layers = nn.ModuleList(
- [Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self.norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = Gemma4TextRotaryEmbedding(config)
- self.gradient_checkpointing = False
- self.unique_layer_types = set(self.config.layer_types)
- self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
- if self.hidden_size_per_layer_input:
- self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding(
- config.vocab_size_per_layer_input,
- config.num_hidden_layers * config.hidden_size_per_layer_input,
- self.padding_idx,
- embed_scale=config.hidden_size_per_layer_input**0.5,
- )
- self.per_layer_input_scale = 2.0**-0.5
- self.per_layer_model_projection = nn.Linear(
- config.hidden_size,
- config.num_hidden_layers * config.hidden_size_per_layer_input,
- bias=False,
- )
- self.per_layer_model_projection_scale = config.hidden_size**-0.5
- self.per_layer_projection_norm = Gemma4RMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
- # Update `_keys_to_ignore_on_load_unexpected` to drop all k/v proj and norms for the shared layers
- self._keys_to_ignore_on_load_unexpected = []
- for i, layer in enumerate(self.layers):
- if layer.self_attn.is_kv_shared_layer:
- self._keys_to_ignore_on_load_unexpected.extend(
- [f"layers.{i}.self_attn.{name}" for name in ("k_proj", "v_proj", "k_norm", "v_norm")]
- )
- # Initialize weights and apply final processing
- self.post_init()
- @merge_with_config_defaults
- @capture_outputs
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- per_layer_inputs: torch.Tensor | None = None,
- use_cache: bool | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> BaseModelOutputWithPast:
- r"""
- per_layer_inputs (`torch.Tensor` of shape `(batch_size, sequence_length, num_hidden_layers, hidden_size_per_layer_input)`, *optional*):
- Pre-computed per-layer input embeddings. When provided, these are used directly instead of being
- computed from `input_ids` via `get_per_layer_inputs()`. This is primarily used by the multimodal
- model (`Gemma4Model`) which pre-computes per-layer inputs from the original `input_ids` *before*
- merging multimodal soft tokens into `inputs_embeds` — at which point the original token ids are
- no longer recoverable.
- """
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- if input_ids is not None:
- inputs_embeds = self.embed_tokens(input_ids)
- if self.hidden_size_per_layer_input:
- if per_layer_inputs is None:
- per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds)
- per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache(config=self.config)
- if position_ids is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
- position_ids = position_ids.unsqueeze(0)
- # It may already have been prepared by e.g. `generate`
- if not isinstance(causal_mask_mapping := attention_mask, dict):
- # Prepare mask arguments
- mask_kwargs = {
- "config": self.config,
- "inputs_embeds": inputs_embeds,
- "attention_mask": attention_mask,
- "past_key_values": past_key_values,
- "position_ids": position_ids,
- }
- # Create the masks
- causal_mask_mapping = {
- "full_attention": create_causal_mask(**mask_kwargs),
- "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
- }
- # embed positions
- hidden_states = inputs_embeds
- position_embeddings = {}
- for layer_type in self.unique_layer_types:
- position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
- # Initialize as empty dict - it will be filled in the right layers
- shared_kv_states = {}
- # decoder layers
- for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
- per_layer_input = per_layer_inputs[:, :, i, :] if per_layer_inputs is not None else None
- hidden_states = decoder_layer(
- hidden_states,
- per_layer_input,
- shared_kv_states=shared_kv_states,
- position_embeddings=position_embeddings[self.config.layer_types[i]],
- attention_mask=causal_mask_mapping[self.config.layer_types[i]],
- position_ids=position_ids,
- past_key_values=past_key_values,
- **kwargs,
- )
- hidden_states = self.norm(hidden_states)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=past_key_values,
- )
- def get_per_layer_inputs(self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None) -> torch.Tensor:
- if not self.hidden_size_per_layer_input:
- raise RuntimeError(
- "Attempting to call get_per_layer_inputs() from a model initialized with a config that does not support"
- f" per-layer embeddings. {self.config}"
- )
- # If only inputs_embeds are provided, reverse main embedding to find the input_ids - this allows to `generate`
- # from `inputs_embeds` only as other models (otherwise it would need the value from both embeddings)
- if input_ids is None:
- with torch.no_grad():
- input_ids = (
- (
- inputs_embeds[:, :, None, :]
- == self.embed_tokens.weight[None, None, :, :] * self.config.hidden_size**0.5
- )
- .all(dim=3)
- .nonzero()[:, 2]
- )
- try:
- input_ids = input_ids.view(inputs_embeds.shape[:2])
- except RuntimeError:
- raise RuntimeError(
- "It seems like you tried to call `forward` from `inputs_embeds` without providing `input_ids`, and that "
- "the `inputs_embeds` you provided do not exactly match the embedding weights. Since Gemma4 needs to reverse "
- "the embedding to compute another embedding, make sure you provide exact `inputs_embeds`"
- )
- return self.embed_tokens_per_layer(input_ids).reshape(
- *input_ids.shape,
- self.config.num_hidden_layers,
- self.hidden_size_per_layer_input,
- )
- def project_per_layer_inputs(
- self,
- inputs_embeds: torch.Tensor,
- per_layer_inputs: torch.Tensor | None = None,
- ) -> torch.Tensor:
- if not self.hidden_size_per_layer_input:
- raise RuntimeError(
- "Attempting to call project_per_layer_inputs() from a model initialized with a config that does not"
- f" support per-layer embeddings. {self.config}"
- )
- per_layer_projection = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale
- per_layer_projection = per_layer_projection.reshape(
- *inputs_embeds.shape[:-1],
- self.config.num_hidden_layers,
- self.hidden_size_per_layer_input,
- )
- per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
- if per_layer_inputs is None:
- return per_layer_projection
- return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
- @auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
- class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
- _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
- _tp_plan = {"lm_head": "colwise_gather_output"}
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
- config: Gemma4TextConfig
- base_model_prefix = "model"
- def __init__(self, config: Gemma4TextConfig):
- super().__init__(config)
- self.model = Gemma4TextModel(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- # Grab the ones from the child
- self._keys_to_ignore_on_load_unexpected = [
- f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
- ]
- # Initialize weights and apply final processing
- self.post_init()
- @can_return_tuple
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
- use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
- **kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
- r"""
- Example:
- ```python
- >>> from transformers import AutoTokenizer, Gemma4ForCausalLM
- >>> model = Gemma4ForCausalLM.from_pretrained("google/gemma-2-9b")
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
- >>> prompt = "What is your favorite condiment?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "What is your favorite condiment?"
- ```"""
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs: BaseModelOutputWithPast = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- **kwargs,
- )
- hidden_states = outputs.last_hidden_state
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
- if self.config.final_logit_softcapping is not None:
- logits = logits / self.config.final_logit_softcapping
- logits = torch.tanh(logits)
- logits = logits * self.config.final_logit_softcapping
- loss = None
- if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
- def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable:
- """
- This creates uni/bidirectional attention mask with sliding window.
- """
- def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
- left_window_size, right_window_size = sliding_window
- dist = q_idx - kv_idx
- left_mask = (dist >= 0) & (dist < left_window_size)
- right_mask = (dist < 0) & (-dist < right_window_size)
- return left_mask | right_mask
- return inner_mask
- class Gemma4AudioModel(Gemma4PreTrainedModel):
- """An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
- config: Gemma4AudioConfig
- main_input_name = "input_features"
- base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
- _can_record_outputs = {
- "hidden_states": Gemma4AudioLayer,
- "attentions": Gemma4AudioAttention,
- }
- def __init__(self, config: Gemma4AudioConfig):
- super().__init__(config)
- self.config = config
- self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(config)
- self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config)
- self.layers = nn.ModuleList(
- [Gemma4AudioLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self.output_proj = nn.Linear(config.hidden_size, config.output_proj_dims, bias=True)
- self.post_init()
- def _convert_4d_mask_to_blocked_5d(self, mask_4d: torch.Tensor) -> torch.Tensor:
- """
- Convert a standard 4D attention mask `[batch_size, 1, seq_len, seq_len]` to the 5D blocked format
- `[batch_size, 1, num_blocks, chunk_size, context_size]` expected by the chunked local attention,
- """
- batch_size, _, seq_len, _ = mask_4d.shape
- device = mask_4d.device
- chunk_size = self.config.attention_chunk_size
- max_past_horizon = self.config.attention_context_left - 1
- max_future_horizon = self.config.attention_context_right
- num_blocks = (seq_len + chunk_size - 1) // chunk_size
- padded_seq_len = num_blocks * chunk_size
- pad_amount = padded_seq_len - seq_len
- mask_4d = F.pad(mask_4d, (0, pad_amount, 0, pad_amount), value=False)
- mask_5d = mask_4d.reshape(batch_size, 1, num_blocks, chunk_size, padded_seq_len)
- mask_5d = F.pad(mask_5d, (max_past_horizon, max_future_horizon), value=False)
- block_starts = torch.arange(num_blocks, device=device) * chunk_size
- offsets = torch.arange(chunk_size + max_past_horizon + max_future_horizon, device=device)
- kv_indices = block_starts[:, None] + offsets[None, :]
- kv_indices = kv_indices[None, None, :, None, :].expand(batch_size, 1, -1, chunk_size, -1)
- return mask_5d.gather(-1, kv_indices)
- @merge_with_config_defaults
- @capture_outputs
- @auto_docstring(custom_intro="Encodes audio features to soft tokens.")
- def forward(
- self,
- input_features: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor, torch.BoolTensor]:
- hidden_states, output_mask = self.subsample_conv_projection(input_features, attention_mask)
- position_embeddings = self.rel_pos_enc(hidden_states)
- attention_mask = create_bidirectional_mask(
- config=self.config,
- inputs_embeds=hidden_states,
- attention_mask=output_mask,
- and_mask_function=sliding_window_mask_function(
- (self.config.attention_context_left - 1, self.config.attention_context_right)
- ),
- )
- attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask)
- for encoder_layer in self.layers[: self.config.num_hidden_layers]:
- hidden_states = encoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- position_embeddings=position_embeddings,
- **kwargs,
- )
- hidden_states = self.output_proj(hidden_states)
- return Gemma4AudioModelOutput(last_hidden_state=hidden_states, attention_mask=output_mask)
- class Gemma4VisionModel(Gemma4PreTrainedModel):
- """The Gemma 4 Vision Encoder."""
- config = Gemma4VisionConfig
- _can_record_outputs = {
- "hidden_states": Gemma4VisionEncoderLayer,
- "attentions": Gemma4VisionAttention,
- }
- def __init__(self, config: Gemma4VisionConfig):
- super().__init__(config)
- self.patch_embedder = Gemma4VisionPatchEmbedder(config)
- self.encoder = Gemma4VisionEncoder(config)
- self.pooler = Gemma4VisionPooler(config)
- if self.config.standardize:
- self.register_buffer("std_bias", torch.empty(self.config.hidden_size))
- self.register_buffer("std_scale", torch.empty(self.config.hidden_size))
- self.post_init()
- @merge_with_config_defaults
- @capture_outputs
- @auto_docstring(custom_intro="Encodes image pixels to soft tokens from patches.")
- def forward(
- self,
- pixel_values: torch.FloatTensor,
- pixel_position_ids: torch.LongTensor,
- **kwargs: Unpack[TransformersKwargs],
- ) -> BaseModelOutputWithPast:
- r"""
- pixel_values (`torch.FloatTensor` or `list[torch.FloatTensor]`):
- The images to encode. Either a single `[batch, channels, height, width]` tensor
- (all images same size) or a list of `[1, channels, height, width]` tensors (different sizes).
- pixel_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`):
- The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
- """
- pooling_kernel_size = self.config.pooling_kernel_size
- output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size)
- padding_positions = (pixel_position_ids == -1).all(dim=-1)
- inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions)
- output = self.encoder(
- inputs_embeds=inputs_embeds,
- attention_mask=~padding_positions, # encoder expects True=valid, padding_positions is True=padding
- pixel_position_ids=pixel_position_ids,
- **kwargs,
- )
- hidden_states, pooler_mask = self.pooler(
- hidden_states=output.last_hidden_state,
- pixel_position_ids=pixel_position_ids,
- padding_positions=padding_positions,
- output_length=output_length,
- )
- # Strip padding tokens. pooler_mask is True = valid, False = padding.
- hidden_states = hidden_states[pooler_mask]
- if self.config.standardize:
- hidden_states = (hidden_states - self.std_bias) * self.std_scale
- return BaseModelOutputWithPast(last_hidden_state=hidden_states)
- class Gemma4MultimodalEmbedder(nn.Module):
- """Embeds token ids or soft tokens for multimodal content into language model space."""
- def __init__(
- self,
- multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
- text_config: Gemma4TextConfig,
- ):
- super().__init__()
- self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size)
- self.eps = multimodal_config.rms_norm_eps
- self.text_hidden_size = text_config.hidden_size
- self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
- self.embedding_pre_projection_norm = Gemma4RMSNorm(self.multimodal_hidden_size, eps=self.eps, with_scale=False)
- def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
- """Embeds token ids or soft tokens for multimodal content into language model space.
- Args:
- inputs_embeds: A torch.Tensor containing the soft tokens to embed.
- Returns:
- A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
- """
- embs_normed = self.embedding_pre_projection_norm(inputs_embeds)
- return self.embedding_projection(embs_normed)
- # Identical as Gemma3 but modular can't resolve if we simply import. FIXME: @cyril
- def token_type_ids_mask_function(
- token_type_ids: torch.Tensor | None,
- image_group_ids: torch.Tensor | None,
- ) -> Callable | None:
- """
- This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
- not start and end indices.
- """
- # Do not return an additional mask in this case
- if token_type_ids is None:
- return None
- def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
- seq_length = image_group_ids.shape[-1]
- # clamp indices because with static cache they can go beyond `image_group_ids.shape[-1]`
- q_idx_clamped = q_idx.clamp(max=seq_length - 1)
- kv_idx_clamped = kv_idx.clamp(max=seq_length - 1)
- # Unmask if the q and kv come from same group which is not -1 (i.e. non-text)
- q_group = image_group_ids[batch_idx, q_idx_clamped]
- kv_group = image_group_ids[batch_idx, kv_idx_clamped]
- q_group = torch.where(q_idx < seq_length, q_group, -1)
- kv_group = torch.where(kv_idx < seq_length, kv_group, -1)
- return (q_group == kv_group) & (q_group >= 0)
- return inner_mask
- # Similar to Gemma3 but `sliding_mask_kwargs` and `mask_kwargs` are different and `token_type_ids->mm_token_type_ids`
- def create_causal_mask_mapping(
- config: PreTrainedConfig,
- inputs_embeds: torch.Tensor,
- attention_mask: torch.Tensor | None,
- past_key_values: Cache | None,
- position_ids: torch.Tensor | None,
- mm_token_type_ids: torch.Tensor | None = None,
- pixel_values: torch.FloatTensor | None = None,
- is_training: bool = False,
- is_first_iteration: bool | None = None,
- **kwargs,
- ) -> dict:
- """
- Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
- for all kinds of forward passes. Gemma4 uses a bidirectional mask for images.
- Uses `pixel_values` as an optional input to disambiguate edge cases.
- """
- if is_training and mm_token_type_ids is None:
- raise ValueError("`mm_token_type_ids` is required as a model input when training")
- mask_kwargs = {
- "config": config.get_text_config(),
- "inputs_embeds": inputs_embeds,
- "attention_mask": attention_mask,
- "past_key_values": past_key_values,
- "position_ids": position_ids,
- }
- sliding_mask_kwargs = mask_kwargs.copy()
- # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
- # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
- # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
- is_first_iteration = (
- is_first_iteration
- if is_first_iteration is not None
- else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
- )
- if mm_token_type_ids is not None and is_first_iteration:
- # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
- # undo the causal masking)
- # First find where a new vision block starts. Vision tokens cannot attend to
- # future vision tokens, but can attend to all prev tokens and to itself bidirectionally
- is_vision = (mm_token_type_ids == 1) | (mm_token_type_ids == 2)
- is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1)
- is_prev_vision[..., 0] = False
- new_vision_starts = is_vision & ~is_prev_vision
- vision_group_ids = torch.cumsum(new_vision_starts.int(), dim=1) - 1
- vision_group_ids = torch.where(is_vision, vision_group_ids, -1)
- sliding_mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
- mm_token_type_ids.to(inputs_embeds.device), vision_group_ids
- )
- return {
- "full_attention": create_causal_mask(**mask_kwargs),
- "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
- }
- @auto_docstring(
- custom_intro="""
- The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
- language modeling head.
- """
- )
- class Gemma4Model(Gemma4PreTrainedModel):
- # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
- accepts_loss_kwargs = False
- def __init__(self, config: Gemma4Config):
- super().__init__(config)
- self.vocab_size = config.text_config.vocab_size
- language_model = AutoModel.from_config(config=config.text_config)
- self.language_model = language_model
- self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
- self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None
- self.embed_vision = (
- Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
- if config.vision_config is not None
- else None
- )
- self.audio_tower = AutoModel.from_config(config.audio_config) if config.audio_config is not None else None
- self.embed_audio = (
- Gemma4MultimodalEmbedder(config.audio_config, config.text_config)
- if config.audio_config is not None
- else None
- )
- # Grab the ones from the child
- self._keys_to_ignore_on_load_unexpected = [
- f"language_model.{name}" for name in self.language_model._keys_to_ignore_on_load_unexpected
- ]
- self.post_init()
- def get_input_embeddings(self):
- return self.language_model.get_input_embeddings()
- def set_input_embeddings(self, value):
- self.language_model.set_input_embeddings(value)
- @can_return_tuple
- @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
- def get_image_features(
- self,
- pixel_values: torch.FloatTensor,
- image_position_ids: torch.LongTensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> BaseModelOutputWithPooling:
- r"""
- image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
- The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
- """
- vision_outputs = self.vision_tower(
- pixel_values=pixel_values,
- pixel_position_ids=image_position_ids,
- **kwargs,
- )
- last_hidden_state = vision_outputs.last_hidden_state
- vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
- return vision_outputs
- def get_placeholder_mask(
- self,
- input_ids: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- ) -> tuple[torch.BoolTensor, torch.BoolTensor, torch.BoolTensor]:
- """
- Obtains mask for multimodal placeholders (replaced by soft tokens) and hard text tokens.
- Masks will be obtained from `mm_token_type_ids`, `input_ids`, or `inputs_embeds` as available and in that
- precedence order. If passing `input_ids` or `inputs_embeds`, the image mask will be derived using
- `config.image_token_id`. Same goes for audio and video masks
- Args:
- input_ids: A tensor containing the hard token IDs from the text tokenizer.
- inputs_embeds: A tensor containing the embeddings for all hard text tokens.
- Returns:
- image_mask, video_mask, audio_mask
- """
- if input_ids is not None:
- special_image_mask = input_ids == self.config.image_token_id
- special_video_mask = input_ids == self.config.video_token_id
- special_audio_mask = input_ids == self.config.audio_token_id
- else:
- special_image_mask = (
- inputs_embeds
- == self.get_input_embeddings()(
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- ).all(-1)
- special_video_mask = (
- inputs_embeds
- == self.get_input_embeddings()(
- torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- ).all(-1)
- special_audio_mask = (
- inputs_embeds
- == self.get_input_embeddings()(
- torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- ).all(-1)
- return special_image_mask, special_video_mask, special_audio_mask
- @merge_with_config_defaults
- @can_return_tuple
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- pixel_values: torch.FloatTensor | None = None,
- pixel_values_videos: torch.FloatTensor | None = None,
- input_features: torch.FloatTensor | None = None,
- attention_mask: torch.Tensor | None = None,
- input_features_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- mm_token_type_ids: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- use_cache: bool | None = None,
- image_position_ids: torch.LongTensor | None = None,
- video_position_ids: torch.LongTensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> Gemma4ModelOutputWithPast:
- r"""
- input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
- The attention mask for the input audio.
- image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
- 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
- Passed through to the vision encoder for positional embedding computation.
- video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
- 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
- Passed through to the vision encoder for positional embedding computation.
- """
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- image_mask, video_mask, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds)
- multimodal_mask = image_mask | video_mask | audio_mask
- # Replace image id with PAD if the image token if OOV, to avoid index-errors
- llm_input_ids = None
- if inputs_embeds is None:
- llm_input_ids = input_ids.clone()
- llm_input_ids[multimodal_mask] = self.config.text_config.pad_token_id
- inputs_embeds = self.get_input_embeddings()(llm_input_ids)
- if self.config.get_text_config().hidden_size_per_layer_input:
- pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
- llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds)
- per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds)
- else:
- per_layer_inputs = None
- # Merge text and images
- if pixel_values is not None:
- image_features = self.get_image_features(pixel_values, image_position_ids, return_dict=True).pooler_output
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
- n_image_tokens = image_mask.sum()
- image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- torch_compilable_check(
- inputs_embeds[image_mask].numel() == image_features.numel(),
- f"Image features and image tokens do not match, tokens: {n_image_tokens}, features:"
- f" {image_features.shape[0]}",
- )
- inputs_embeds = inputs_embeds.masked_scatter(
- image_mask.to(inputs_embeds.device), image_features.to(inputs_embeds.device)
- )
- if pixel_values_videos is not None:
- video_features = self.get_video_features(
- pixel_values_videos, video_position_ids, return_dict=True
- ).pooler_output
- video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
- # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
- n_video_tokens = video_mask.sum()
- video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- torch_compilable_check(
- inputs_embeds[video_mask].numel() == video_features.numel(),
- f"Video features and video tokens do not match, tokens: {n_video_tokens}, features:"
- f" {video_features.shape[0]}",
- )
- inputs_embeds = inputs_embeds.masked_scatter(
- video_mask.to(inputs_embeds.device), video_features.to(inputs_embeds.device)
- )
- # Merge text and audio
- if input_features is not None and input_features_mask is not None:
- audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
- audio_features = audio_output.pooler_output
- audio_mask_from_encoder = audio_output.attention_mask # True = valid
- # Strip padding tokens: only keep real (non-padding) audio soft tokens.
- # audio_mask_from_encoder is True for valid positions, False for padding tokens.
- # This mirrors the vision encoder's padding stripping (see Gemma4VisionEncoder.forward).
- audio_features = audio_features[audio_mask_from_encoder]
- n_audio_tokens = audio_mask.sum()
- audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- torch_compilable_check(
- inputs_embeds[audio_mask].numel() == audio_features.numel(),
- f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features:"
- f" {audio_features.shape[0] * audio_features.shape[1]}",
- )
- inputs_embeds = inputs_embeds.masked_scatter(
- audio_mask.to(inputs_embeds.device), audio_features.to(inputs_embeds.device)
- )
- # It may already have been prepared by, e.g., `generate`
- if position_ids is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
- position_ids = position_ids.unsqueeze(0)
- if not isinstance(causal_mask_mapping := attention_mask, dict):
- if self.config.get_text_config().use_bidirectional_attention == "vision":
- # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
- causal_mask_mapping = create_causal_mask_mapping(
- self.config,
- inputs_embeds,
- attention_mask,
- past_key_values,
- position_ids,
- mm_token_type_ids,
- pixel_values,
- is_training=self.training,
- )
- else:
- # Smaller Gemma models use a conventional casual attention mask
- causal_mask_mapping = create_masks_for_generate(
- self.config,
- inputs_embeds,
- attention_mask,
- past_key_values,
- position_ids,
- )
- outputs = self.language_model(
- per_layer_inputs=per_layer_inputs,
- attention_mask=causal_mask_mapping,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- return_dict=True,
- **kwargs,
- )
- return Gemma4ModelOutputWithPast(
- last_hidden_state=outputs.last_hidden_state,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
- audio_hidden_states=audio_features if input_features is not None else None,
- )
- @can_return_tuple
- @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
- def get_audio_features(
- self,
- input_features: torch.Tensor,
- input_features_mask: torch.Tensor,
- **kwargs: Unpack[TransformersKwargs],
- ) -> tuple | Gemma4AudioModelOutput:
- r"""
- input_features (`torch.FloatTensor]` of shape `(num_images, seq_length, num_features)`):
- The tensors corresponding to the input audio.
- input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
- The attention mask for the input audio.
- """
- if self.audio_tower is None:
- raise ValueError(
- "Audio features were requested, but the model was initialized without an audio_config. "
- "Cannot process audio without an audio tower and audio embedder."
- )
- audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs)
- audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
- return audio_outputs
- @can_return_tuple
- @auto_docstring(custom_intro="Projects the last hidden state from the vision encoder into language model space.")
- def get_video_features(
- self,
- pixel_values_videos: torch.FloatTensor,
- video_position_ids: torch.LongTensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> BaseModelOutputWithPooling:
- r"""
- video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
- 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
- Passed through to the vision encoder for positional embedding computation.
- """
- pixel_values_videos = pixel_values_videos.flatten(0, 1)
- video_position_ids = video_position_ids.flatten(0, 1)
- vision_outputs = self.vision_tower(
- pixel_values=pixel_values_videos,
- pixel_position_ids=video_position_ids,
- **kwargs,
- )
- last_hidden_state = vision_outputs.last_hidden_state
- vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
- return vision_outputs
- @auto_docstring(
- custom_intro="""
- The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
- head.
- """
- )
- class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin):
- _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
- base_model_prefix = "model"
- def __init__(self, config: Gemma4Config):
- super().__init__(config)
- self.model = Gemma4Model(config)
- self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
- # Grab the ones from the child
- self._keys_to_ignore_on_load_unexpected = [
- f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
- ]
- self.post_init()
- def get_input_embeddings(self):
- return self.model.get_input_embeddings()
- def set_input_embeddings(self, value):
- self.model.set_input_embeddings(value)
- @auto_docstring
- def get_image_features(
- self,
- pixel_values: torch.FloatTensor,
- image_position_ids: torch.LongTensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ):
- r"""
- image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
- 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
- Passed through to the vision encoder for positional embedding computation.
- """
- return self.model.get_image_features(pixel_values, image_position_ids, **kwargs)
- @can_return_tuple
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- pixel_values: torch.FloatTensor | None = None,
- pixel_values_videos: torch.FloatTensor | None = None,
- input_features: torch.FloatTensor | None = None,
- attention_mask: torch.Tensor | None = None,
- input_features_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- image_position_ids: torch.LongTensor | None = None,
- video_position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- mm_token_type_ids: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
- use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
- **kwargs: Unpack[TransformersKwargs],
- ) -> Gemma4CausalLMOutputWithPast:
- r"""
- input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
- The attention mask for the input audio.
- image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
- 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
- Passed through to the vision encoder for positional embedding computation.
- video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
- 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
- Passed through to the vision encoder for positional embedding computation.
- """
- outputs = self.model(
- input_ids=input_ids,
- pixel_values=pixel_values,
- pixel_values_videos=pixel_values_videos,
- input_features=input_features,
- attention_mask=attention_mask,
- input_features_mask=input_features_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- mm_token_type_ids=mm_token_type_ids,
- inputs_embeds=inputs_embeds,
- labels=labels,
- use_cache=use_cache,
- image_position_ids=image_position_ids,
- video_position_ids=video_position_ids,
- return_dict=True,
- **kwargs,
- )
- hidden_states = outputs.last_hidden_state
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
- if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
- logits = logits / final_logit_softcapping
- logits = torch.tanh(logits)
- logits = logits * final_logit_softcapping
- loss = None
- if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- shift_logits = logits[..., :-1, :]
- shift_labels = labels[..., 1:]
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
- shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
- else:
- shift_logits = shift_logits.contiguous()
- shift_labels = shift_labels.contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size)
- flat_labels = shift_labels.view(-1).to(shift_logits.device)
- loss = loss_fct(flat_logits, flat_labels)
- return Gemma4CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- image_hidden_states=outputs.image_hidden_states,
- audio_hidden_states=outputs.audio_hidden_states,
- )
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- inputs_embeds=None,
- position_ids=None,
- pixel_values=None,
- pixel_values_videos=None,
- input_features=None,
- attention_mask=None,
- input_features_mask=None,
- token_type_ids=None,
- use_cache=True,
- logits_to_keep=None,
- labels=None,
- is_first_iteration=False,
- **kwargs,
- ):
- # Overwritten -- custom `position_ids` and `pixel_values` handling
- model_inputs = super().prepare_inputs_for_generation(
- input_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- attention_mask=attention_mask,
- position_ids=position_ids,
- use_cache=use_cache,
- logits_to_keep=logits_to_keep,
- token_type_ids=token_type_ids,
- is_first_iteration=is_first_iteration,
- **kwargs,
- )
- # If we're in cached decoding stage, multimodal inputs are already cached and can be dropped
- if is_first_iteration or not use_cache:
- model_inputs["pixel_values"] = pixel_values
- model_inputs["pixel_values_videos"] = pixel_values_videos
- model_inputs["input_features"] = input_features
- model_inputs["input_features_mask"] = input_features_mask
- return model_inputs
- @staticmethod
- def create_masks_for_generate(
- config: PreTrainedConfig,
- inputs_embeds: torch.Tensor,
- attention_mask: torch.Tensor | None,
- past_key_values: Cache | None,
- position_ids: torch.Tensor | None,
- mm_token_type_ids: torch.Tensor | None = None,
- is_first_iteration: bool | None = False,
- **kwargs,
- ) -> dict:
- if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision":
- # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
- return create_causal_mask_mapping(
- config,
- inputs_embeds,
- attention_mask,
- past_key_values,
- position_ids,
- mm_token_type_ids,
- is_first_iteration=is_first_iteration,
- **{k: v for k, v in kwargs.items() if k != "pixel_values"},
- )
- else:
- # Smaller Gemma models use a conventional casual attention mask
- return create_masks_for_generate(
- config, inputs_embeds, attention_mask, past_key_values, position_ids, **kwargs
- )
- __all__ = [
- "Gemma4AudioModel",
- "Gemma4ForCausalLM",
- "Gemma4ForConditionalGeneration",
- "Gemma4Model",
- "Gemma4PreTrainedModel",
- "Gemma4TextModel",
- "Gemma4VisionModel",
- ]
|