modeling_gemma4.py 117 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma4/modular_gemma4.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma4.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 the HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. from functools import cached_property
  24. from typing import Optional
  25. import torch
  26. from torch import nn
  27. from torch.nn import functional as F
  28. from ... import initialization as init
  29. from ...activations import ACT2FN
  30. from ...cache_utils import Cache, DynamicCache
  31. from ...configuration_utils import PreTrainedConfig
  32. from ...generation import GenerationMixin
  33. from ...integrations import use_experts_implementation, use_kernelized_func
  34. from ...masking_utils import (
  35. create_bidirectional_mask,
  36. create_causal_mask,
  37. create_masks_for_generate,
  38. create_sliding_window_causal_mask,
  39. )
  40. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  41. from ...modeling_layers import GradientCheckpointingLayer
  42. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
  43. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  44. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  45. from ...processing_utils import Unpack
  46. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
  47. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  48. from ...utils.output_capturing import OutputRecorder, capture_outputs
  49. from ..auto.modeling_auto import AutoModel
  50. from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
# Output container for the Gemma4 base model: everything `BaseModelOutputWithPast` carries, plus two
# optional fields holding the image and audio hidden states.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Gemma4 outputs, with hidden states and attentions.
    """
)
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # Both fields default to None, so the output is valid when a modality is absent.
    image_hidden_states: torch.FloatTensor | None = None
    # NOTE(review): the docstring above describes the audio tensor with a `num_images` axis, which looks
    # copied from the image entry -- confirm the real audio layout and update the text in the modular file.
    audio_hidden_states: torch.FloatTensor | None = None
# Output container for the Gemma4 causal-LM head: loss/logits, the KV cache, optional per-layer
# hidden states and attentions, plus the image and audio hidden states.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Gemma4 causal language model (or autoregressive) outputs.
    """
)
class Gemma4CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # Every field is optional and defaults to None.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
    # NOTE(review): the docstring above describes the audio tensor with a `num_images` axis, which looks
    # copied from the image entry -- confirm the real audio layout and update the text in the modular file.
    audio_hidden_states: torch.FloatTensor | None = None
# Audio-encoder output: the fields of `BaseModelOutputWithPooling` plus the frame-level validity mask.
@dataclass
@auto_docstring
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
    r"""
    attention_mask (`torch.BoolTensor`, *optional*):
        A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
    """

    # Defaults to None when no mask is returned.
    attention_mask: torch.BoolTensor | None = None
  110. class Gemma4ClippableLinear(nn.Module):
  111. def __init__(
  112. self,
  113. config: Gemma4VisionConfig | Gemma4AudioConfig,
  114. in_features: int,
  115. out_features: int,
  116. ) -> None:
  117. super().__init__()
  118. self.use_clipped_linears = config.use_clipped_linears
  119. self.linear = nn.Linear(in_features, out_features, bias=False)
  120. if self.use_clipped_linears:
  121. self.register_buffer("input_min", torch.tensor(-float("inf")))
  122. self.register_buffer("input_max", torch.tensor(float("inf")))
  123. self.register_buffer("output_min", torch.tensor(-float("inf")))
  124. self.register_buffer("output_max", torch.tensor(float("inf")))
  125. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  126. if self.use_clipped_linears:
  127. hidden_states = torch.clamp(hidden_states, self.input_min, self.input_max)
  128. hidden_states = self.linear(hidden_states)
  129. if self.use_clipped_linears:
  130. hidden_states = torch.clamp(hidden_states, self.output_min, self.output_max)
  131. return hidden_states
  132. class Gemma4RMSNorm(nn.Module):
  133. def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
  134. super().__init__()
  135. self.eps = eps
  136. self.with_scale = with_scale
  137. if self.with_scale:
  138. self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)
  139. def _norm(self, hidden_states: torch.Tensor):
  140. mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
  141. # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX
  142. return hidden_states * torch.pow(mean_squared, -0.5)
  143. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  144. normed_output = self._norm(hidden_states.float())
  145. if self.with_scale:
  146. normed_output = normed_output * self.weight.float()
  147. return normed_output.type_as(hidden_states)
  148. class Gemma4AudioRelPositionalEncoding(nn.Module):
  149. """Sinusoidal relative positional encoding for the audio encoder.
  150. Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
  151. concatenated [sin..., cos...] layout matching the original Gemma4 convention.
  152. """
  153. inv_timescales: torch.Tensor
  154. def __init__(self, config: Gemma4AudioConfig):
  155. super().__init__()
  156. self.hidden_size = config.hidden_size
  157. self.context_size = (
  158. config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
  159. )
  160. min_timescale = 1.0
  161. max_timescale = 10000.0
  162. num_timescales = self.hidden_size // 2
  163. log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
  164. inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
  165. self.register_buffer("inv_timescales", inv_timescales.unsqueeze(0).unsqueeze(0), persistent=False)
  166. @torch.no_grad()
  167. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  168. position_ids = torch.arange(12, -1, -1, device=hidden_states.device)
  169. position_ids = position_ids[..., None]
  170. scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device)
  171. pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
  172. return pos_embed.to(dtype=hidden_states.dtype)
class Gemma4AudioAttention(nn.Module):
    """Chunked local attention with relative position bias.

    Queries are split into non-overlapping blocks of `chunk_size` frames. Each block attends to a
    window of `context_size = chunk_size + max_past_horizon + max_future_horizon` key/value frames.
    The logits are the sum of a content term and a relative-position term, soft-capped with
    `softcap * tanh(logits / softcap)` before the softmax.
    """

    def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_logits_soft_cap = config.attention_logit_cap
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        # Fixed query/key scaling constants: 1/sqrt(head_dim)/ln(2) and ln(1 + e)/ln(2).
        # NOTE(review): presumably softplus-based normalisation inherited from the reference
        # implementation -- confirm before changing.
        self.q_scale = (self.head_dim**-0.5) / math.log(2)
        self.k_scale = math.log(1 + math.e) / math.log(2)
        self.chunk_size = config.attention_chunk_size
        self.max_past_horizon = config.attention_context_left - 1
        self.max_future_horizon = config.attention_context_right
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
        self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.post = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
        # Projects the positional table into per-head key space for the relative-position term.
        self.relative_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        # Learnable per-feature query scale; passed through softplus in `forward`.
        self.per_dim_scale = nn.Parameter(torch.zeros(self.head_dim))
        # Non-persistent buffer: follows the module across devices but is not saved in the state dict.
        self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap), persistent=False)

    def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Splits a `(batch_size, seq_len, num_heads, head_dim)` tensor into non-overlapping blocks of `chunk_size` along the sequence dim.

        The sequence is zero-padded at the end up to a multiple of `chunk_size`; the result has
        shape `(batch_size, num_blocks, chunk_size, num_heads, head_dim)`.
        """
        batch_size, seq_len, num_heads, head_dim = hidden_states.shape
        # Ceil division: number of blocks needed to cover seq_len.
        num_blocks = (seq_len + self.chunk_size - 1) // self.chunk_size
        pad = num_blocks * self.chunk_size - seq_len
        # F.pad lists pairs from the last dim backwards: (head_dim, num_heads, seq_len).
        hidden_states = F.pad(hidden_states, (0, 0, 0, 0, 0, pad))
        return hidden_states.reshape(batch_size, num_blocks, self.chunk_size, num_heads, head_dim).contiguous()

    def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Extracts overlapping context windows of `context_size` for every block, strided by `chunk_size`.

        Input is `(batch_size, seq_len, num_heads, head_dim)`; output is
        `(batch_size, num_windows, context_size, num_heads, head_dim)`.
        """
        # The unpacking is otherwise unused; it fails early if the input is not 4-D.
        batch_size, seq_len, num_heads, head_dim = hidden_states.shape
        # Pad the sequence dim: `max_past_horizon` frames on the left, and on the right the future
        # horizon plus `chunk_size - 1` extra frames.
        hidden_states = F.pad(
            hidden_states, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)
        )
        # unfold appends the window axis as the last dim; move it next to the block axis.
        hidden_states = hidden_states.unfold(1, self.context_size, self.chunk_size)
        hidden_states = torch.movedim(hidden_states, -1, 2)
        return hidden_states.contiguous()

    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Relative position shift for blocked attention. See appendix B of https://huggingface.co/papers/1901.02860.

        Input is `(batch_size, num_heads, num_blocks, block_size, position_length)`; output is
        `(batch_size, num_heads, num_blocks, block_size, context_size)`.
        """
        batch_size, num_heads, num_blocks, block_size, position_length = x.shape
        context_size = self.context_size
        # Pad each row to `context_size + 1`, flatten rows, drop the tail and re-split into rows of
        # `context_size`: row i of the result then reads the padded row i starting at column -i,
        # i.e. every query row is shifted right by its index inside the block.
        x = F.pad(x, (0, context_size + 1 - position_length))
        x = x.view(batch_size, num_heads, num_blocks, block_size * (context_size + 1))
        x = x[..., : block_size * context_size]
        return x.view(batch_size, num_heads, num_blocks, block_size, context_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.BoolTensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run blocked local attention.

        Args:
            hidden_states: `(batch_size, seq_length, hidden_size)` input.
            position_embeddings: Relative position table fed to `relative_k_proj`.
            attention_mask: Optional boolean mask, True where attention is allowed. It must
                broadcast against the logits of shape
                `(batch_size, num_heads, num_blocks, chunk_size, context_size)`.

        Returns:
            `(attn_output, attn_weights)`: the projected output of shape
            `(batch_size, seq_length, hidden_size)` and the post-softmax weights.
        """
        batch_size, seq_length, _ = hidden_states.shape
        hidden_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

        # Projections are upcast to float32 for the attention arithmetic.
        query_states = self.q_proj(hidden_states).float().view(hidden_shape)
        key_states = self.k_proj(hidden_states).float().view(hidden_shape)
        value_states = self.v_proj(hidden_states).float().view(hidden_shape)

        query_states = query_states * self.q_scale * F.softplus(self.per_dim_scale)
        key_states = key_states * self.k_scale

        # Queries: (B, num_blocks, chunk, H, D). Keys/values: (B, num_blocks, context, H, D).
        query_states = self._convert_to_block(query_states)
        key_states = self._extract_block_context(key_states)
        value_states = self._extract_block_context(value_states)
        num_blocks = query_states.shape[1]

        # Relative keys: (num_positions, H, D).
        relative_key_states = self.relative_k_proj(position_embeddings)
        relative_key_states = relative_key_states.view(-1, self.num_heads, self.head_dim)
        relative_key_states = relative_key_states.to(dtype=query_states.dtype)

        # Content term: (B, H, num_blocks, chunk, D) @ (B, H, num_blocks, D, context).
        queries = query_states.permute(0, 3, 1, 2, 4)
        matrix_ac = queries @ key_states.permute(0, 3, 1, 4, 2)

        # Position term: scores against every relative offset, then shifted onto the context axis.
        queries_flat = queries.reshape(batch_size, self.num_heads, -1, self.head_dim)
        matrix_bd = queries_flat @ relative_key_states.permute(1, 2, 0)
        matrix_bd = matrix_bd.reshape(batch_size, self.num_heads, num_blocks, self.chunk_size, -1)
        matrix_bd = self._rel_shift(matrix_bd)

        attn_weights = matrix_ac + matrix_bd

        # Soft cap: keeps logits within (-softcap, softcap).
        attn_weights = attn_weights / self.softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * self.softcap

        if attention_mask is not None:
            # Disallowed positions get the configured "invalid" logit value before the softmax.
            attn_weights = attn_weights.masked_fill(
                attention_mask.logical_not(), self.config.attention_invalid_logits_value
            )

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        attn_output = attn_weights @ value_states.permute(0, 3, 1, 2, 4)
        # Back to (B, num_blocks * chunk, H * D), then drop the frames added by block padding.
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, num_blocks * self.chunk_size, -1)
        attn_output = attn_output[:, :seq_length].contiguous()
        # Cast back to the output projection's weight dtype after the float32 attention math.
        attn_output = self.post(attn_output.to(dtype=self.post.linear.weight.dtype))
        return attn_output, attn_weights
  259. class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
  260. def __init__(self, in_channels, out_channels, norm_eps):
  261. super().__init__()
  262. self.conv = nn.Conv2d(
  263. in_channels=in_channels,
  264. out_channels=out_channels,
  265. kernel_size=(3, 3),
  266. stride=(2, 2),
  267. padding=1,
  268. bias=False,
  269. )
  270. self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
  271. self.act = nn.ReLU()
  272. def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None):
  273. if mask is not None:
  274. mask = mask.to(device=hidden_states.device)
  275. hidden_states = hidden_states * mask[:, None, :, None]
  276. hidden_states = self.conv(hidden_states.to(self.conv.weight.dtype))
  277. hidden_states = self.act(self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous())
  278. if mask is not None:
  279. mask = mask[:, ::2]
  280. return hidden_states, mask
  281. class Gemma4AudioSubSampleConvProjection(nn.Module):
  282. def __init__(self, config: Gemma4AudioConfig):
  283. super().__init__()
  284. self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
  285. in_channels=1,
  286. out_channels=config.subsampling_conv_channels[0],
  287. norm_eps=config.rms_norm_eps,
  288. )
  289. self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
  290. in_channels=config.subsampling_conv_channels[0],
  291. out_channels=config.subsampling_conv_channels[1],
  292. norm_eps=config.rms_norm_eps,
  293. )
  294. proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
  295. self.input_proj_linear = nn.Linear(proj_input_dim, config.hidden_size, bias=False)
  296. def forward(
  297. self,
  298. input_features: torch.Tensor,
  299. input_features_mask: torch.Tensor | None = None,
  300. ) -> tuple[torch.Tensor, torch.Tensor]:
  301. hidden_states = input_features.unsqueeze(1)
  302. hidden_states, mask = self.layer0(hidden_states, input_features_mask)
  303. hidden_states, mask = self.layer1(hidden_states, mask)
  304. batch_size, _, seq_len, _ = hidden_states.shape
  305. hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1)
  306. return self.input_proj_linear(hidden_states), mask
  307. class Gemma4AudioFeedForward(nn.Module):
  308. def __init__(self, config: Gemma4AudioConfig):
  309. super().__init__()
  310. self.config = config
  311. self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
  312. self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
  313. self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
  314. self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
  315. self.act_fn = ACT2FN[config.hidden_act]
  316. self.gradient_clipping = config.gradient_clipping
  317. self.post_layer_scale = config.residual_weight
  318. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  319. # This is needed to avoid any underflow/overflow issues when clipping
  320. gradient_clipping = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max)
  321. residual = hidden_states
  322. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  323. hidden_states = self.pre_layer_norm(hidden_states)
  324. hidden_states = self.ffw_layer_1(hidden_states)
  325. hidden_states = self.act_fn(hidden_states)
  326. hidden_states = self.ffw_layer_2(hidden_states)
  327. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  328. hidden_states = self.post_layer_norm(hidden_states)
  329. hidden_states *= self.post_layer_scale
  330. hidden_states += residual
  331. return hidden_states
  332. # TODO: this could be imported from Voxtral realtime
  333. class Gemma4AudioCausalConv1d(nn.Conv1d):
  334. # def __init__(
  335. # self,
  336. # in_channels: int,
  337. # out_channels: int,
  338. # kernel_size: int,
  339. # # cache_key: str,
  340. # stride: int = 1,
  341. # dilation: int = 1,
  342. # bias: bool = True,
  343. # ):
  344. # super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
  345. # self.cache_key = cache_key
  346. @cached_property
  347. def left_pad(self):
  348. effective_kernel_size = (self.kernel_size[0] - 1) * self.dilation[0] + 1
  349. return effective_kernel_size - self.stride[0]
  350. def forward(
  351. self,
  352. x: torch.Tensor,
  353. # padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, # TODO: we might want to add a cache?
  354. ) -> torch.Tensor:
  355. # if padding_cache is not None:
  356. # x = padding_cache.update(x, self.cache_key, self)
  357. # else:
  358. # x = nn.functional.pad(x, (self.left_pad, 0))
  359. x = nn.functional.pad(x, (self.left_pad, 0))
  360. return super().forward(x)
  361. class Gemma4AudioLightConv1d(nn.Module):
  362. def __init__(self, config: Gemma4AudioConfig):
  363. super().__init__()
  364. self.config = config
  365. self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
  366. self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
  367. self.depthwise_conv1d = Gemma4AudioCausalConv1d(
  368. in_channels=config.hidden_size,
  369. out_channels=config.hidden_size,
  370. kernel_size=config.conv_kernel_size,
  371. groups=config.hidden_size,
  372. bias=False,
  373. )
  374. self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True)
  375. self.conv_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True)
  376. self.act_fn = ACT2FN[config.hidden_act]
  377. self.gradient_clipping = config.gradient_clipping
  378. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  379. residual = hidden_states
  380. hidden_states = self.pre_layer_norm(hidden_states)
  381. hidden_states = self.linear_start(hidden_states)
  382. hidden_states = nn.functional.glu(hidden_states, dim=-1)
  383. hidden_states = self.depthwise_conv1d(hidden_states.transpose(1, 2)).transpose(1, 2)
  384. # This is needed to avoid any underflow/overflow issues when clipping
  385. gradient_clipping = min(self.gradient_clipping, torch.finfo(self.linear_start.linear.weight.dtype).max)
  386. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  387. hidden_states = self.conv_norm(hidden_states)
  388. hidden_states = self.act_fn(hidden_states)
  389. hidden_states = self.linear_end(hidden_states)
  390. hidden_states += residual
  391. return hidden_states
  392. class Gemma4AudioLayer(nn.Module):
  393. def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
  394. super().__init__()
  395. self.config = config
  396. self.feed_forward1 = Gemma4AudioFeedForward(config)
  397. self.feed_forward2 = Gemma4AudioFeedForward(config)
  398. self.self_attn = Gemma4AudioAttention(config, layer_idx)
  399. self.lconv1d = Gemma4AudioLightConv1d(config)
  400. self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
  401. self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
  402. self.norm_out = Gemma4RMSNorm(config.hidden_size)
  403. self.gradient_clipping = config.gradient_clipping
  404. def forward(
  405. self,
  406. hidden_states: torch.Tensor,
  407. attention_mask: torch.BoolTensor | None,
  408. position_embeddings: torch.Tensor,
  409. **kwargs: Unpack[TransformersKwargs],
  410. ) -> torch.Tensor:
  411. # This is needed to avoid any underflow/overflow issues when clipping
  412. gradient_clipping = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max)
  413. hidden_states = self.feed_forward1(hidden_states)
  414. residual = hidden_states
  415. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  416. hidden_states = self.norm_pre_attn(hidden_states)
  417. hidden_states, _ = self.self_attn(
  418. hidden_states=hidden_states,
  419. position_embeddings=position_embeddings,
  420. attention_mask=attention_mask,
  421. )
  422. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  423. hidden_states = self.norm_post_attn(hidden_states)
  424. hidden_states += residual
  425. hidden_states = self.lconv1d(hidden_states)
  426. hidden_states = self.feed_forward2(hidden_states)
  427. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  428. hidden_states = self.norm_out(hidden_states)
  429. return hidden_states
  430. # ---- Vision Encoder Layers ----
  431. class Gemma4VisionPatchEmbedder(nn.Module):
  432. def __init__(self, config: Gemma4VisionConfig):
  433. super().__init__()
  434. self.config = config
  435. self.hidden_size = config.hidden_size
  436. self.patch_size = config.patch_size
  437. self.position_embedding_size = config.position_embedding_size
  438. self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
  439. self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
  440. def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
  441. """Prepare patch positions map for matmul with positon embedding table."""
  442. # Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
  443. clamped_positions = pixel_position_ids.clamp(min=0)
  444. one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
  445. one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
  446. # Compute positional embeddings and sum across x and y.
  447. position_embeddings = one_hot @ self.position_embedding_table
  448. position_embeddings = position_embeddings.sum(dim=1)
  449. # Zero out embeddings for any padding patches.
  450. position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
  451. return position_embeddings
  452. def forward(
  453. self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor
  454. ) -> torch.Tensor:
  455. # Gemma4 applies no normalization and instead scales in model code
  456. pixel_values = 2 * (pixel_values - 0.5)
  457. hidden_states = self.input_proj(pixel_values.to(self.input_proj.weight.dtype))
  458. position_embeddings = self._position_embeddings(pixel_position_ids, padding_positions)
  459. return hidden_states + position_embeddings
  460. class Gemma4VisionPooler(nn.Module):
  461. """Scaling and optional spatial pooling for vision encodings"""
  462. def __init__(self, config: Gemma4VisionConfig):
  463. super().__init__()
  464. self.hidden_size = config.hidden_size
  465. self.root_hidden_size = self.hidden_size**0.5
  466. def _avg_pool_by_positions(
  467. self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
  468. ) -> tuple[torch.Tensor, torch.Tensor]:
  469. """
  470. 2D spatial pooling according to patch positions.
  471. Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
  472. input and output lengths
  473. """
  474. input_seq_len = hidden_states.shape[1]
  475. k = int((input_seq_len // length) ** 0.5)
  476. k_squared = k**2
  477. if k_squared * length != input_seq_len:
  478. raise ValueError(
  479. f"Cannot pool {hidden_states.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
  480. )
  481. # Clamp padding positions (which are -1) to 0 so they don't break one_hot.
  482. # Padding patches have zero hidden states so they contribute nothing to the average.
  483. clamped_positions = pixel_position_ids.clamp(min=0)
  484. max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
  485. kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
  486. kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
  487. weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
  488. output = weights.transpose(1, 2) @ hidden_states.float()
  489. mask = torch.logical_not((weights == 0).all(dim=1))
  490. return output.to(hidden_states.dtype), mask
  491. def forward(
  492. self,
  493. hidden_states: torch.Tensor,
  494. pixel_position_ids: torch.Tensor,
  495. padding_positions: torch.Tensor,
  496. output_length: int | None = None,
  497. ) -> tuple[torch.Tensor, torch.Tensor]:
  498. if output_length > hidden_states.shape[1]:
  499. raise ValueError(
  500. f"Cannot output more soft tokens (requested {output_length}) than there are patches"
  501. f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing."
  502. )
  503. hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0)
  504. if hidden_states.shape[1] != output_length:
  505. hidden_states, padding_positions = self._avg_pool_by_positions(
  506. hidden_states, pixel_position_ids, output_length
  507. )
  508. hidden_states *= self.root_hidden_size
  509. return hidden_states, padding_positions
  510. class Gemma4VisionMLP(nn.Module):
  511. def __init__(self, config: Gemma4VisionConfig):
  512. super().__init__()
  513. self.config = config
  514. self.hidden_size = config.hidden_size
  515. self.intermediate_size = config.intermediate_size
  516. self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
  517. self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
  518. self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
  519. self.act_fn = ACT2FN[config.hidden_activation]
  520. def forward(self, x):
  521. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  522. return down_proj
class Gemma4VisionRotaryEmbedding(nn.Module):
    """Rotary position embedding for patch positions with two coordinates.

    The cos/sin tables are computed independently for each of the two position coordinates from the same
    `inv_freq` buffer and concatenated along the last dimension.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Gemma4VisionConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Use the local default initializer unless the config asks for another RoPE type.
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Both buffers are non-persistent, i.e. they are not part of the state dict. A clone of the initial
        # frequencies is kept under `original_inv_freq`.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: Gemma4VisionConfig | None = None,
        device: torch.device | None = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation, using half of the head
        dimension (one half per spatial coordinate).

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. Although annotated as optional, it is dereferenced unconditionally.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # Fall back to hidden_size // num_attention_heads when `head_dim` is missing or falsy.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # The reference implementation computes RoPE frequencies INDEPENDENTLY
        # for each spatial dimension using the partitioned head_dim (head_dim // ndim),
        # so both x and y dimensions get identical frequency ranges.
        # This is different from splitting the global inv_freq between dimensions.
        spatial_dim = dim // 2

        attention_factor = 1.0  # Unused in this type of RoPE
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, spatial_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / spatial_dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Compute the cos/sin tables for the given patch positions.

        Args:
            x (`torch.Tensor`):
                Only its device and dtype are read: the frequencies are moved to `x.device` and the results are
                cast to `x.dtype`.
            position_ids (`torch.Tensor`):
                Positions indexed as `[batch, num_patches, coordinate]`; coordinates 0 and 1 are used.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: `cos` and `sin`, each the concatenation (last dimension) of the
            per-coordinate tables, so their last dimension is four times the length of `inv_freq`.
        """
        # (num_freqs,) -> (batch, num_freqs, 1) so it can be batch-matmul'ed with the positions.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        # Autocast is addressed by device type; "mps" (and non-string types) fall back to "cpu".
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"

        # Multidimensional positions: [batch, num_patches, ndim]. Apply rotations to each spatial dim separately
        all_cos, all_sin = [], []
        for i in range(2):
            dim_position_ids = position_ids[:, :, i]
            # (batch, num_patches) -> (batch, 1, num_patches)
            dim_position_ids_expanded = dim_position_ids[:, None, :].float()

            with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
                # (batch, num_freqs, 1) @ (batch, 1, num_patches) -> transposed to (batch, num_patches, num_freqs)
                freqs = (inv_freq_expanded.float() @ dim_position_ids_expanded.float()).transpose(1, 2)
                # Frequencies are duplicated to match the half-split layout used by `rotate_half`.
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos() * self.attention_scaling
                sin = emb.sin() * self.attention_scaling
            all_cos.append(cos)
            all_sin.append(sin)

        cos = torch.cat(all_cos, dim=-1).to(dtype=x.dtype)
        sin = torch.cat(all_sin, dim=-1).to(dtype=x.dtype)
        return cos, sin
  589. def rotate_half(x):
  590. """Rotates half the hidden dims of the input."""
  591. x1 = x[..., : x.shape[-1] // 2]
  592. x2 = x[..., x.shape[-1] // 2 :]
  593. return torch.cat((-x2, x1), dim=-1)
  594. def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
  595. """Applies Rotary Position Embedding to the query and key tensors.
  596. Args:
  597. x (`torch.Tensor`): The tensor to embed.
  598. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  599. sin (`torch.Tensor`): The sine part of the rotary embedding.
  600. unsqueeze_dim (`int`, *optional*, defaults to 1):
  601. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  602. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  603. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  604. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  605. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  606. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  607. Returns:
  608. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  609. """
  610. cos = cos.unsqueeze(unsqueeze_dim)
  611. sin = sin.unsqueeze(unsqueeze_dim)
  612. return (x * cos) + (rotate_half(x) * sin)
  613. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  614. """
  615. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  616. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  617. """
  618. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  619. if n_rep == 1:
  620. return hidden_states
  621. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  622. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  623. def eager_attention_forward(
  624. module: nn.Module,
  625. query: torch.Tensor,
  626. key: torch.Tensor,
  627. value: torch.Tensor,
  628. attention_mask: torch.Tensor | None,
  629. dropout: float | int = 0.0,
  630. scaling: float | None = None,
  631. softcap: float | None = None,
  632. **kwargs,
  633. ) -> tuple[torch.Tensor, torch.Tensor]:
  634. if scaling is None:
  635. scaling = module.head_dim**-0.5
  636. key_states = repeat_kv(key, module.num_key_value_groups)
  637. value_states = repeat_kv(value, module.num_key_value_groups)
  638. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  639. if softcap is not None:
  640. attn_weights = attn_weights / softcap
  641. attn_weights = torch.tanh(attn_weights)
  642. attn_weights = attn_weights * softcap
  643. if attention_mask is not None:
  644. attn_weights = attn_weights + attention_mask
  645. # upcast attention to fp32
  646. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  647. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  648. attn_output = torch.matmul(attn_weights, value_states)
  649. attn_output = attn_output.transpose(1, 2).contiguous()
  650. return attn_output, attn_weights
  651. def apply_multidimensional_rope(
  652. x: torch.Tensor,
  653. cos: torch.Tensor,
  654. sin: torch.Tensor,
  655. position_ids: torch.Tensor,
  656. unsqueeze_dim: int = 2,
  657. ) -> torch.Tensor:
  658. """Applies multidimensional RoPE to inputs.
  659. Args:
  660. x (`torch.Tensor`): The tensor to embed.
  661. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  662. sin (`torch.Tensor`): The sine part of the rotary embedding.
  663. position_ids (`torch.Tensor`, *optional*):
  664. If position_ids.ndim + 2 == x.ndim, then this function passes through to `apply_rotary_pos_emb()`.
  665. Otherwise, position_ids is used to split the inputs, x, into multiple pieces, where each piece is fed to
  666. `apply_rotary_pos_emb()`, and then concatenated back together.
  667. unsqueeze_dim (`int`, *optional*, defaults to 1):
  668. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  669. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  670. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  671. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  672. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  673. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  674. Returns:
  675. Tensor of shape [B, L, N, H] with RoPE applied.
  676. """
  677. ndim = position_ids.shape[-1]
  678. num_input_channels = x.shape[-1]
  679. num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim))
  680. if num_rotated_channels_per_dim <= 0:
  681. raise ValueError(
  682. "Invalid configuration: num_rotated_channels_per_dim must be > 0, got"
  683. f" {num_rotated_channels_per_dim} (num_input_channels={num_input_channels},"
  684. f" ndim={ndim})"
  685. )
  686. # Correctly split the input tensor into ndim parts
  687. split_sizes = [num_rotated_channels_per_dim] * ndim
  688. x_parts = torch.split(x, split_sizes, dim=-1)
  689. cos_parts = torch.split(cos, split_sizes, dim=-1)
  690. sin_parts = torch.split(sin, split_sizes, dim=-1)
  691. y_parts = [
  692. apply_rotary_pos_emb(
  693. x=x_parts[k],
  694. cos=cos_parts[k],
  695. sin=sin_parts[k],
  696. unsqueeze_dim=unsqueeze_dim,
  697. )
  698. for k in range(ndim)
  699. ]
  700. return torch.cat(y_parts, dim=-1)
  701. @use_kernelized_func(apply_rotary_pos_emb)
  702. class Gemma4VisionAttention(nn.Module):
  703. """Multi-headed attention from 'Attention Is All You Need' paper"""
  704. def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
  705. super().__init__()
  706. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  707. self.config = config
  708. self.layer_idx = layer_idx
  709. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  710. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  711. self.scaling = 1.0
  712. self.attention_dropout = self.config.attention_dropout
  713. self.is_causal = False
  714. self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
  715. self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
  716. self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
  717. self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
  718. self.q_norm = Gemma4RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  719. self.k_norm = Gemma4RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  720. self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
  721. def forward(
  722. self,
  723. hidden_states: torch.Tensor,
  724. position_embeddings: torch.Tensor = None,
  725. attention_mask: torch.Tensor | None = None,
  726. position_ids: torch.LongTensor | None = None,
  727. **kwargs: Unpack[TransformersKwargs],
  728. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  729. input_shape = hidden_states.shape[:-1]
  730. hidden_shape = (*input_shape, -1, self.head_dim)
  731. cos, sin = position_embeddings
  732. query_states = self.q_proj(hidden_states).view(hidden_shape)
  733. query_states = self.q_norm(query_states)
  734. query_states = apply_multidimensional_rope(query_states, cos, sin, position_ids)
  735. query_states = query_states.transpose(1, 2)
  736. key_states = self.k_proj(hidden_states).view(hidden_shape)
  737. key_states = self.k_norm(key_states)
  738. key_states = apply_multidimensional_rope(key_states, cos, sin, position_ids)
  739. key_states = key_states.transpose(1, 2)
  740. value_states = self.v_proj(hidden_states).view(hidden_shape)
  741. value_states = self.v_norm(value_states)
  742. value_states = value_states.transpose(1, 2)
  743. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  744. self.config._attn_implementation, eager_attention_forward
  745. )
  746. attn_output, attn_weights = attention_interface(
  747. self,
  748. query_states,
  749. key_states,
  750. value_states,
  751. attention_mask,
  752. dropout=self.attention_dropout if self.training else 0.0,
  753. scaling=self.scaling,
  754. **kwargs,
  755. )
  756. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  757. attn_output = self.o_proj(attn_output)
  758. return attn_output, attn_weights
  759. class Gemma4VisionEncoderLayer(GradientCheckpointingLayer):
  760. def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
  761. super().__init__()
  762. self.config = config
  763. self.hidden_size = config.hidden_size
  764. self.layer_idx = layer_idx
  765. self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
  766. self.mlp = Gemma4VisionMLP(config)
  767. self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  768. self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  769. self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  770. self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  771. def forward(
  772. self,
  773. hidden_states: torch.Tensor,
  774. position_embeddings: torch.Tensor = None,
  775. attention_mask: torch.Tensor | None = None,
  776. position_ids: torch.LongTensor | None = None,
  777. **kwargs: Unpack[TransformersKwargs],
  778. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  779. residual = hidden_states
  780. hidden_states = self.input_layernorm(hidden_states)
  781. hidden_states, _ = self.self_attn(
  782. hidden_states=hidden_states,
  783. position_embeddings=position_embeddings,
  784. attention_mask=attention_mask,
  785. position_ids=position_ids,
  786. **kwargs,
  787. )
  788. hidden_states = self.post_attention_layernorm(hidden_states)
  789. hidden_states = residual + hidden_states
  790. residual = hidden_states
  791. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  792. hidden_states = self.mlp(hidden_states)
  793. hidden_states = self.post_feedforward_layernorm(hidden_states)
  794. hidden_states = residual + hidden_states
  795. return hidden_states
  796. class Gemma4VisionEncoder(nn.Module):
  797. def __init__(self, config: Gemma4VisionConfig):
  798. super().__init__()
  799. self.config = config
  800. self.num_layers = config.num_hidden_layers
  801. self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
  802. self.layers = nn.ModuleList(
  803. [Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
  804. )
  805. def forward(
  806. self,
  807. inputs_embeds: torch.Tensor,
  808. attention_mask: torch.Tensor,
  809. pixel_position_ids: torch.LongTensor | None = None,
  810. **kwargs: Unpack[TransformersKwargs],
  811. ) -> BaseModelOutputWithPast:
  812. r"""
  813. pixel_position_ids (torch.Tensor):
  814. Patch positions as (x, y) coordinates in the image as [batch, num_patches, 2].
  815. """
  816. attention_mask = create_bidirectional_mask(
  817. config=self.config,
  818. inputs_embeds=inputs_embeds,
  819. attention_mask=attention_mask,
  820. )
  821. # embed positions
  822. hidden_states = inputs_embeds
  823. position_embeddings = self.rotary_emb(hidden_states, pixel_position_ids)
  824. # decoder layers
  825. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  826. hidden_states = decoder_layer(
  827. hidden_states,
  828. attention_mask=attention_mask,
  829. position_embeddings=position_embeddings,
  830. position_ids=pixel_position_ids,
  831. **kwargs,
  832. )
  833. return BaseModelOutputWithPast(last_hidden_state=hidden_states)
  834. class Gemma4TextMLP(nn.Module):
  835. def __init__(self, config: Gemma4TextConfig, layer_idx: int):
  836. super().__init__()
  837. first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
  838. is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
  839. use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
  840. self.config = config
  841. self.hidden_size = config.hidden_size
  842. self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
  843. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  844. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  845. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  846. self.act_fn = ACT2FN[config.hidden_activation]
  847. def forward(self, x):
  848. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  849. return down_proj
class Gemma4TextRotaryEmbedding(nn.Module):
    """Rotary position embedding that keeps a separate set of inverse frequencies per layer type.

    For every layer type in `config.layer_types` that has RoPE parameters, the buffers
    `{layer_type}_inv_freq` and `{layer_type}_original_inv_freq` and the attribute
    `{layer_type}_attention_scaling` are created.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Unique layer types; one set of RoPE buffers is registered per type.
        self.layer_types = set(config.layer_types)
        self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
        self.rope_type: dict[str, str] = {}

        # NOTE: the loop variable shadows the `layer_type` argument, whose value is therefore never read.
        for layer_type in self.layer_types:
            rope_params = self.config.rope_parameters[layer_type]
            # Layer types without RoPE parameters get no buffers at all.
            if rope_params is None:
                continue

            if (rope_type := rope_params["rope_type"]) != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
            else:
                rope_init_fn = self.compute_default_rope_parameters

            self.rope_init_fns[layer_type] = rope_init_fn
            self.rope_type[layer_type] = rope_type

            rope_init_fn_kwargs = {"device": device, "layer_type": layer_type}
            # Proportional RoPE on full-attention layers is told to read the head dim from `global_head_dim`.
            if layer_type == "full_attention" and rope_type == "proportional":
                rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"

            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, **rope_init_fn_kwargs)
            # Non-persistent buffers: they are not part of the state dict.
            self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
            setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: Gemma4TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. Although annotated as optional, it is dereferenced unconditionally.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The current layer type if the model has different RoPE parameters per type.
                Should not be used unless `config.layer_types is not None`

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        # The RoPE base is looked up in the per-layer-type entry of `rope_parameters`.
        base = config.rope_parameters[layer_type]["rope_theta"]
        # Fall back to hidden_size // num_attention_heads when `head_dim` is missing or falsy.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids, layer_type=None):
        """Compute the cos/sin tables for `position_ids` using the frequencies of `layer_type`.

        Args:
            x (`torch.Tensor`):
                Only its device and dtype are read: the frequencies are moved to `x.device` and the results are
                cast to `x.dtype`.
            position_ids (`torch.Tensor`):
                Positions indexed as `[batch, seq_len]`.
            layer_type (`str`):
                Selects the `{layer_type}_inv_freq` buffer and `{layer_type}_attention_scaling` attribute. The
                lookup uses `getattr` without a default, so a layer type for which no buffers were registered
                (including the default `None`) raises `AttributeError`.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: `cos` and `sin`, whose last dimension is twice the length of the
            selected `inv_freq`.
        """
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        # (num_freqs,) -> (batch, num_freqs, 1) and (batch, seq_len) -> (batch, 1, seq_len) for the batch matmul.
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is addressed by device type; "mps" (and non-string types) fall back to "cpu".
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Frequencies are duplicated to match the half-split layout used by `rotate_half`.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4TextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Text decoder attention with per-head RMS norms on queries, keys and values, a score scaling of 1.0 and
    optional key/value sharing: the last `num_kv_shared_layers` layers own no key/value weights and instead reuse
    the key/value states published by an earlier layer of the same layer type through `shared_kv_states`.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Non-sliding layers use `global_head_dim` when it is set (truthy); everything else uses `head_dim`.
        self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
        # "Alternative" attention: non-sliding layers with `attention_k_eq_v` have no value projection and take
        # their values from the key projection output (see `forward`).
        self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
        num_key_value_heads = (
            config.num_global_key_value_heads if self.use_alternative_attention else config.num_key_value_heads
        )
        self.num_key_value_groups = config.num_attention_heads // num_key_value_heads
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = config.use_bidirectional_attention != "all"

        # Shared kv cache
        first_kv_shared_layer_idx = self.config.num_hidden_layers - getattr(self.config, "num_kv_shared_layers", 0)
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        # Layer types of all layers that compute their own key/value states.
        # NOTE: `list.index` below raises `ValueError` if this layer's type does not occur among them.
        prev_layers = config.layer_types[:first_kv_shared_layer_idx]
        if self.is_kv_shared_layer:
            # For shared layers, find the last non-shared layer of the same type before sharing starts
            self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
            self.store_full_length_kv = False
        else:
            self.kv_shared_layer_index = None
            # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
            self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
                config.layer_types[layer_idx]
            )

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

        # Layers sharing kv states don't need any weight matrices
        if not self.is_kv_shared_layer:
            self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
            self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
            self.k_proj = nn.Linear(
                config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias
            )
            # No value projection with alternative attention; `forward` checks for `None`.
            self.v_proj = (
                nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
                if not self.use_alternative_attention
                else None
            )

        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Compute self-attention for one decoder layer.

        Args:
            hidden_states (`torch.Tensor`):
                Input activations; the last dimension is the hidden size.
            position_embeddings:
                The `(cos, sin)` pair used for RoPE (unpacked into two values).
            attention_mask (`torch.Tensor`, *optional*):
                Mask forwarded to the attention implementation.
            shared_kv_states (`dict[int, tuple[torch.Tensor, torch.Tensor]]`):
                Maps a layer index to that layer's `(key_states, value_states)`. KV-sharing layers read from it;
                layers with `store_full_length_kv` write their states into it (the dict is mutated in place).
            past_key_values (`Cache`, *optional*):
                Cache that is updated with this layer's key/value states unless the layer shares KV states.
            **kwargs:
                Forwarded to the attention implementation.

        Returns:
            `tuple(torch.Tensor, torch.Tensor | None)`: The projected attention output and the attention weights
            as returned by the selected attention implementation.
        """
        input_shape = hidden_states.shape[:-1]
        # (..., num_heads, head_dim); the number of heads is inferred from the projection width.
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings

        # RoPE is applied while the head axis is still at dim 2; heads are moved to dim 1 afterwards.
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
        # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
        # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
        if self.is_kv_shared_layer:
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            # Device of past layer may be different from current one
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            # Without a value projection the values start from the raw key projection output, taken before the
            # key norm and RoPE are applied below.
            value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states

            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
            key_states = key_states.transpose(1, 2)

            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)

        if past_key_values is not None and not self.is_kv_shared_layer:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
        # Publish the states (as returned by the cache update, if one happened) for later KV-sharing layers.
        if self.store_full_length_kv:
            shared_kv_states[self.layer_idx] = key_states, value_states

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
  1028. @use_experts_implementation
  1029. class Gemma4TextExperts(nn.Module):
  1030. """Collection of expert weights stored as 3D tensors."""
  1031. def __init__(self, config: Gemma4TextConfig):
  1032. super().__init__()
  1033. self.num_experts = config.num_experts
  1034. self.hidden_dim = config.hidden_size
  1035. self.intermediate_dim = config.moe_intermediate_size
  1036. self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  1037. self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  1038. self.act_fn = ACT2FN[config.hidden_activation]
  1039. def forward(
  1040. self,
  1041. hidden_states: torch.Tensor,
  1042. top_k_index: torch.Tensor,
  1043. top_k_weights: torch.Tensor,
  1044. ) -> torch.Tensor:
  1045. final_hidden_states = torch.zeros_like(hidden_states)
  1046. with torch.no_grad():
  1047. expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
  1048. expert_mask = expert_mask.permute(2, 1, 0)
  1049. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  1050. for expert_idx in expert_hit:
  1051. expert_idx = expert_idx[0]
  1052. if expert_idx == self.num_experts:
  1053. continue
  1054. top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
  1055. current_state = hidden_states[token_idx]
  1056. gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
  1057. current_hidden_states = self.act_fn(gate) * up
  1058. current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
  1059. current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
  1060. final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
  1061. return final_hidden_states
  1062. class Gemma4TextRouter(nn.Module):
  1063. def __init__(self, config: Gemma4TextConfig):
  1064. super().__init__()
  1065. self.config = config
  1066. self.hidden_size = config.hidden_size
  1067. self.scalar_root_size = self.hidden_size**-0.5
  1068. self.eps = config.rms_norm_eps
  1069. self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
  1070. self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
  1071. self.scale = nn.Parameter(torch.ones(self.hidden_size))
  1072. self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
  1073. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  1074. hidden_states = self.norm(hidden_states)
  1075. hidden_states = hidden_states * self.scale * self.scalar_root_size
  1076. expert_scores = self.proj(hidden_states) # [B*S, E]
  1077. router_probabilities = nn.functional.softmax(expert_scores, dim=-1)
  1078. # topk returns both values (probabilities) and indices directly
  1079. top_k_weights, top_k_index = torch.topk(
  1080. router_probabilities,
  1081. k=self.config.top_k_experts,
  1082. dim=-1,
  1083. ) # both [B*S, K]
  1084. # Normalize the top-k weights so they sum to 1 per token
  1085. top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
  1086. # Apply per-expert scale directly to the weights
  1087. top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]
  1088. return router_probabilities, top_k_weights, top_k_index
  1089. class Gemma4TextDecoderLayer(GradientCheckpointingLayer):
  1090. def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
  1091. super().__init__()
  1092. self.config = config
  1093. self.hidden_size = config.hidden_size
  1094. self.layer_idx = layer_idx
  1095. self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
  1096. self.mlp = Gemma4TextMLP(config, layer_idx)
  1097. self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1098. self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1099. self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1100. self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1101. self.register_buffer("layer_scalar", torch.ones(1))
  1102. self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
  1103. if self.hidden_size_per_layer_input:
  1104. self.act_fn = ACT2FN[config.hidden_activation]
  1105. self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
  1106. self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
  1107. self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1108. self.enable_moe_block = config.enable_moe_block
  1109. if self.enable_moe_block:
  1110. self.router = Gemma4TextRouter(config)
  1111. self.experts = Gemma4TextExperts(config)
  1112. self.post_feedforward_layernorm_1 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1113. self.post_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1114. self.pre_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1115. def forward(
  1116. self,
  1117. hidden_states: torch.Tensor,
  1118. per_layer_input: torch.Tensor = None,
  1119. shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
  1120. position_embeddings: torch.Tensor = None,
  1121. attention_mask: torch.Tensor | None = None,
  1122. position_ids: torch.LongTensor | None = None,
  1123. past_key_values: Cache | None = None,
  1124. **kwargs,
  1125. ) -> torch.Tensor:
  1126. residual = hidden_states
  1127. hidden_states = self.input_layernorm(hidden_states)
  1128. hidden_states, _ = self.self_attn(
  1129. hidden_states=hidden_states,
  1130. position_embeddings=position_embeddings,
  1131. attention_mask=attention_mask,
  1132. shared_kv_states=shared_kv_states,
  1133. position_ids=position_ids,
  1134. past_key_values=past_key_values,
  1135. **kwargs,
  1136. )
  1137. hidden_states = self.post_attention_layernorm(hidden_states)
  1138. hidden_states = residual + hidden_states
  1139. residual = hidden_states
  1140. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  1141. hidden_states = self.mlp(hidden_states)
  1142. if self.enable_moe_block:
  1143. hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)
  1144. # Take hidden states before MLP here
  1145. hidden_states_flat = residual.reshape(-1, residual.shape[-1])
  1146. _, top_k_weights, top_k_index = self.router(hidden_states_flat)
  1147. hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states_flat)
  1148. hidden_states_2 = self.experts(hidden_states_2, top_k_index, top_k_weights)
  1149. hidden_states_2 = hidden_states_2.reshape(residual.shape)
  1150. hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)
  1151. # Combine mlp and moe outputs
  1152. hidden_states = hidden_states_1 + hidden_states_2
  1153. hidden_states = self.post_feedforward_layernorm(hidden_states)
  1154. hidden_states = residual + hidden_states
  1155. if self.hidden_size_per_layer_input:
  1156. residual = hidden_states
  1157. hidden_states = self.per_layer_input_gate(hidden_states)
  1158. hidden_states = self.act_fn(hidden_states)
  1159. hidden_states = hidden_states * per_layer_input
  1160. hidden_states = self.per_layer_projection(hidden_states)
  1161. hidden_states = self.post_per_layer_input_norm(hidden_states)
  1162. hidden_states = residual + hidden_states
  1163. hidden_states *= self.layer_scalar
  1164. return hidden_states
  1165. class Gemma4TextScaledWordEmbedding(nn.Embedding):
  1166. """
  1167. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  1168. """
  1169. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  1170. super().__init__(num_embeddings, embedding_dim, padding_idx)
  1171. self.scalar_embed_scale = embed_scale
  1172. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  1173. def forward(self, input_ids: torch.Tensor):
  1174. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  1175. # ---- Model Classes ----
class Gemma4PreTrainedModel(PreTrainedModel):
    """Common base class for the Gemma 4 model classes defined in this file.

    Holds the class-level flags shared by its subclasses and `_init_weights`, which adds
    initialization for the Gemma 4 specific parameters and buffers on top of the parent's.
    """

    config: Gemma4Config
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
    _skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
    input_modalities = ("image", "text", "video", "audio")

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize `module`: run the parent initialization, then one type-specific branch.

        The `isinstance` checks form a single `if`/`elif` chain, so at most one branch runs
        per module, chosen in the order written below.
        """
        super()._init_weights(module)
        if isinstance(module, Gemma4VisionPatchEmbedder):
            init.ones_(module.position_embedding_table)
        elif isinstance(module, Gemma4AudioRelPositionalEncoding):
            # Inverse timescales spaced geometrically from `min_timescale` (index 0) down to
            # `min_timescale**2 / max_timescale` (last index).
            min_timescale = 1.0
            max_timescale = 10000.0
            num_timescales = module.hidden_size // 2
            # `max(..., 1)` keeps the division well-defined when there is a single timescale.
            log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
            inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
            # Two leading singleton dims are added before copying into the buffer.
            init.copy_(module.inv_timescales, inv_timescales.unsqueeze(0).unsqueeze(0))
        elif isinstance(module, Gemma4AudioAttention):
            init.constant_(module.softcap, module.attention_logits_soft_cap)
            init.zeros_(module.per_dim_scale)
        elif isinstance(module, Gemma4TextRotaryEmbedding):
            # One inverse-frequency buffer pair per layer type, each recomputed from the config.
            for layer_type, rope_init_fn in module.rope_init_fns.items():
                rope_init_fn_kwargs = {"layer_type": layer_type}
                # Full-attention layers with "proportional" RoPE read their head dim from a different config key.
                if layer_type == "full_attention" and module.rope_type[layer_type] == "proportional":
                    rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"
                curr_inv_freq, _ = rope_init_fn(module.config, **rope_init_fn_kwargs)
                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
        elif isinstance(module, Gemma4VisionRotaryEmbedding):
            # "default" RoPE uses the module's own parameter function; other types come from the registry.
            rope_fn = (
                ROPE_INIT_FUNCTIONS[module.rope_type]
                if module.rope_type != "default"
                else module.compute_default_rope_parameters
            )
            buffer_value, _ = rope_fn(module.config)
            init.copy_(module.inv_freq, buffer_value)
            init.copy_(module.original_inv_freq, buffer_value)
        elif isinstance(module, Gemma4TextScaledWordEmbedding):
            # Refill the scale buffer from the Python scalar stored on the module.
            init.constant_(module.embed_scale, module.scalar_embed_scale)
        elif isinstance(module, Gemma4TextRouter):
            init.ones_(module.scale)
            init.ones_(module.per_expert_scale)
        elif isinstance(module, Gemma4TextExperts):
            std = self.config.initializer_range
            init.normal_(module.gate_up_proj, mean=0.0, std=std)
            init.normal_(module.down_proj, mean=0.0, std=std)
        elif isinstance(module, Gemma4TextDecoderLayer):
            init.ones_(module.layer_scalar)
        elif isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears:
            # Infinite bounds: presumably makes the clipping a no-op until real bounds are loaded — TODO confirm.
            init.constant_(module.input_min, -float("inf"))
            init.constant_(module.input_max, float("inf"))
            init.constant_(module.output_min, -float("inf"))
            init.constant_(module.output_max, float("inf"))
        elif isinstance(module, Gemma4VisionModel) and module.config.standardize:
            # Zero bias and unit scale leave `(x - std_bias) * std_scale` equal to `x`.
            init.zeros_(module.std_bias)
            init.ones_(module.std_scale)
  1238. @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
  1239. class Gemma4TextModel(Gemma4PreTrainedModel):
  1240. config: Gemma4TextConfig
  1241. input_modalities = ("text",)
  1242. _can_record_outputs = {
  1243. "router_logits": OutputRecorder(Gemma4TextRouter, index=0),
  1244. "hidden_states": Gemma4TextDecoderLayer,
  1245. "attentions": Gemma4TextAttention,
  1246. }
  1247. def __init__(self, config: Gemma4TextConfig):
  1248. super().__init__(config)
  1249. self.padding_idx = config.pad_token_id
  1250. self.vocab_size = config.vocab_size
  1251. # Gemma4 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  1252. self.embed_tokens = Gemma4TextScaledWordEmbedding(
  1253. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  1254. )
  1255. self.layers = nn.ModuleList(
  1256. [Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  1257. )
  1258. self.norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1259. self.rotary_emb = Gemma4TextRotaryEmbedding(config)
  1260. self.gradient_checkpointing = False
  1261. self.unique_layer_types = set(self.config.layer_types)
  1262. self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
  1263. if self.hidden_size_per_layer_input:
  1264. self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding(
  1265. config.vocab_size_per_layer_input,
  1266. config.num_hidden_layers * config.hidden_size_per_layer_input,
  1267. self.padding_idx,
  1268. embed_scale=config.hidden_size_per_layer_input**0.5,
  1269. )
  1270. self.per_layer_input_scale = 2.0**-0.5
  1271. self.per_layer_model_projection = nn.Linear(
  1272. config.hidden_size,
  1273. config.num_hidden_layers * config.hidden_size_per_layer_input,
  1274. bias=False,
  1275. )
  1276. self.per_layer_model_projection_scale = config.hidden_size**-0.5
  1277. self.per_layer_projection_norm = Gemma4RMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
  1278. # Update `_keys_to_ignore_on_load_unexpected` to drop all k/v proj and norms for the shared layers
  1279. self._keys_to_ignore_on_load_unexpected = []
  1280. for i, layer in enumerate(self.layers):
  1281. if layer.self_attn.is_kv_shared_layer:
  1282. self._keys_to_ignore_on_load_unexpected.extend(
  1283. [f"layers.{i}.self_attn.{name}" for name in ("k_proj", "v_proj", "k_norm", "v_norm")]
  1284. )
  1285. # Initialize weights and apply final processing
  1286. self.post_init()
  1287. @merge_with_config_defaults
  1288. @capture_outputs
  1289. @auto_docstring
  1290. def forward(
  1291. self,
  1292. input_ids: torch.LongTensor | None = None,
  1293. attention_mask: torch.Tensor | None = None,
  1294. position_ids: torch.LongTensor | None = None,
  1295. past_key_values: Cache | None = None,
  1296. inputs_embeds: torch.FloatTensor | None = None,
  1297. per_layer_inputs: torch.Tensor | None = None,
  1298. use_cache: bool | None = None,
  1299. **kwargs: Unpack[TransformersKwargs],
  1300. ) -> BaseModelOutputWithPast:
  1301. r"""
  1302. per_layer_inputs (`torch.Tensor` of shape `(batch_size, sequence_length, num_hidden_layers, hidden_size_per_layer_input)`, *optional*):
  1303. Pre-computed per-layer input embeddings. When provided, these are used directly instead of being
  1304. computed from `input_ids` via `get_per_layer_inputs()`. This is primarily used by the multimodal
  1305. model (`Gemma4Model`) which pre-computes per-layer inputs from the original `input_ids` *before*
  1306. merging multimodal soft tokens into `inputs_embeds` — at which point the original token ids are
  1307. no longer recoverable.
  1308. """
  1309. if (input_ids is None) ^ (inputs_embeds is not None):
  1310. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1311. if input_ids is not None:
  1312. inputs_embeds = self.embed_tokens(input_ids)
  1313. if self.hidden_size_per_layer_input:
  1314. if per_layer_inputs is None:
  1315. per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds)
  1316. per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
  1317. if use_cache and past_key_values is None:
  1318. past_key_values = DynamicCache(config=self.config)
  1319. if position_ids is None:
  1320. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1321. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  1322. position_ids = position_ids.unsqueeze(0)
  1323. # It may already have been prepared by e.g. `generate`
  1324. if not isinstance(causal_mask_mapping := attention_mask, dict):
  1325. # Prepare mask arguments
  1326. mask_kwargs = {
  1327. "config": self.config,
  1328. "inputs_embeds": inputs_embeds,
  1329. "attention_mask": attention_mask,
  1330. "past_key_values": past_key_values,
  1331. "position_ids": position_ids,
  1332. }
  1333. # Create the masks
  1334. causal_mask_mapping = {
  1335. "full_attention": create_causal_mask(**mask_kwargs),
  1336. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  1337. }
  1338. # embed positions
  1339. hidden_states = inputs_embeds
  1340. position_embeddings = {}
  1341. for layer_type in self.unique_layer_types:
  1342. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  1343. # Initialize as empty dict - it will be filled in the right layers
  1344. shared_kv_states = {}
  1345. # decoder layers
  1346. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  1347. per_layer_input = per_layer_inputs[:, :, i, :] if per_layer_inputs is not None else None
  1348. hidden_states = decoder_layer(
  1349. hidden_states,
  1350. per_layer_input,
  1351. shared_kv_states=shared_kv_states,
  1352. position_embeddings=position_embeddings[self.config.layer_types[i]],
  1353. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  1354. position_ids=position_ids,
  1355. past_key_values=past_key_values,
  1356. **kwargs,
  1357. )
  1358. hidden_states = self.norm(hidden_states)
  1359. return BaseModelOutputWithPast(
  1360. last_hidden_state=hidden_states,
  1361. past_key_values=past_key_values,
  1362. )
  1363. def get_per_layer_inputs(self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None) -> torch.Tensor:
  1364. if not self.hidden_size_per_layer_input:
  1365. raise RuntimeError(
  1366. "Attempting to call get_per_layer_inputs() from a model initialized with a config that does not support"
  1367. f" per-layer embeddings. {self.config}"
  1368. )
  1369. # If only inputs_embeds are provided, reverse main embedding to find the input_ids - this allows to `generate`
  1370. # from `inputs_embeds` only as other models (otherwise it would need the value from both embeddings)
  1371. if input_ids is None:
  1372. with torch.no_grad():
  1373. input_ids = (
  1374. (
  1375. inputs_embeds[:, :, None, :]
  1376. == self.embed_tokens.weight[None, None, :, :] * self.config.hidden_size**0.5
  1377. )
  1378. .all(dim=3)
  1379. .nonzero()[:, 2]
  1380. )
  1381. try:
  1382. input_ids = input_ids.view(inputs_embeds.shape[:2])
  1383. except RuntimeError:
  1384. raise RuntimeError(
  1385. "It seems like you tried to call `forward` from `inputs_embeds` without providing `input_ids`, and that "
  1386. "the `inputs_embeds` you provided do not exactly match the embedding weights. Since Gemma4 needs to reverse "
  1387. "the embedding to compute another embedding, make sure you provide exact `inputs_embeds`"
  1388. )
  1389. return self.embed_tokens_per_layer(input_ids).reshape(
  1390. *input_ids.shape,
  1391. self.config.num_hidden_layers,
  1392. self.hidden_size_per_layer_input,
  1393. )
  1394. def project_per_layer_inputs(
  1395. self,
  1396. inputs_embeds: torch.Tensor,
  1397. per_layer_inputs: torch.Tensor | None = None,
  1398. ) -> torch.Tensor:
  1399. if not self.hidden_size_per_layer_input:
  1400. raise RuntimeError(
  1401. "Attempting to call project_per_layer_inputs() from a model initialized with a config that does not"
  1402. f" support per-layer embeddings. {self.config}"
  1403. )
  1404. per_layer_projection = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale
  1405. per_layer_projection = per_layer_projection.reshape(
  1406. *inputs_embeds.shape[:-1],
  1407. self.config.num_hidden_layers,
  1408. self.hidden_size_per_layer_input,
  1409. )
  1410. per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
  1411. if per_layer_inputs is None:
  1412. return per_layer_projection
  1413. return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
  1414. @auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
  1415. class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
  1416. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  1417. _tp_plan = {"lm_head": "colwise_gather_output"}
  1418. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  1419. config: Gemma4TextConfig
  1420. base_model_prefix = "model"
  1421. def __init__(self, config: Gemma4TextConfig):
  1422. super().__init__(config)
  1423. self.model = Gemma4TextModel(config)
  1424. self.vocab_size = config.vocab_size
  1425. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1426. # Grab the ones from the child
  1427. self._keys_to_ignore_on_load_unexpected = [
  1428. f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
  1429. ]
  1430. # Initialize weights and apply final processing
  1431. self.post_init()
  1432. @can_return_tuple
  1433. @auto_docstring
  1434. def forward(
  1435. self,
  1436. input_ids: torch.LongTensor | None = None,
  1437. attention_mask: torch.Tensor | None = None,
  1438. position_ids: torch.LongTensor | None = None,
  1439. past_key_values: Cache | None = None,
  1440. inputs_embeds: torch.FloatTensor | None = None,
  1441. labels: torch.LongTensor | None = None,
  1442. use_cache: bool | None = None,
  1443. logits_to_keep: int | torch.Tensor = 0,
  1444. **kwargs: Unpack[TransformersKwargs],
  1445. ) -> CausalLMOutputWithPast:
  1446. r"""
  1447. Example:
  1448. ```python
  1449. >>> from transformers import AutoTokenizer, Gemma4ForCausalLM
  1450. >>> model = Gemma4ForCausalLM.from_pretrained("google/gemma-2-9b")
  1451. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  1452. >>> prompt = "What is your favorite condiment?"
  1453. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1454. >>> # Generate
  1455. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1456. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1457. "What is your favorite condiment?"
  1458. ```"""
  1459. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1460. outputs: BaseModelOutputWithPast = self.model(
  1461. input_ids=input_ids,
  1462. attention_mask=attention_mask,
  1463. position_ids=position_ids,
  1464. past_key_values=past_key_values,
  1465. inputs_embeds=inputs_embeds,
  1466. use_cache=use_cache,
  1467. **kwargs,
  1468. )
  1469. hidden_states = outputs.last_hidden_state
  1470. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1471. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1472. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1473. if self.config.final_logit_softcapping is not None:
  1474. logits = logits / self.config.final_logit_softcapping
  1475. logits = torch.tanh(logits)
  1476. logits = logits * self.config.final_logit_softcapping
  1477. loss = None
  1478. if labels is not None:
  1479. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  1480. return CausalLMOutputWithPast(
  1481. loss=loss,
  1482. logits=logits,
  1483. past_key_values=outputs.past_key_values,
  1484. hidden_states=outputs.hidden_states,
  1485. attentions=outputs.attentions,
  1486. )
  1487. def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable:
  1488. """
  1489. This creates uni/bidirectional attention mask with sliding window.
  1490. """
  1491. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  1492. left_window_size, right_window_size = sliding_window
  1493. dist = q_idx - kv_idx
  1494. left_mask = (dist >= 0) & (dist < left_window_size)
  1495. right_mask = (dist < 0) & (-dist < right_window_size)
  1496. return left_mask | right_mask
  1497. return inner_mask
  1498. class Gemma4AudioModel(Gemma4PreTrainedModel):
  1499. """An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
  1500. config: Gemma4AudioConfig
  1501. main_input_name = "input_features"
  1502. base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
  1503. _can_record_outputs = {
  1504. "hidden_states": Gemma4AudioLayer,
  1505. "attentions": Gemma4AudioAttention,
  1506. }
  1507. def __init__(self, config: Gemma4AudioConfig):
  1508. super().__init__(config)
  1509. self.config = config
  1510. self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(config)
  1511. self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config)
  1512. self.layers = nn.ModuleList(
  1513. [Gemma4AudioLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  1514. )
  1515. self.output_proj = nn.Linear(config.hidden_size, config.output_proj_dims, bias=True)
  1516. self.post_init()
  1517. def _convert_4d_mask_to_blocked_5d(self, mask_4d: torch.Tensor) -> torch.Tensor:
  1518. """
  1519. Convert a standard 4D attention mask `[batch_size, 1, seq_len, seq_len]` to the 5D blocked format
  1520. `[batch_size, 1, num_blocks, chunk_size, context_size]` expected by the chunked local attention,
  1521. """
  1522. batch_size, _, seq_len, _ = mask_4d.shape
  1523. device = mask_4d.device
  1524. chunk_size = self.config.attention_chunk_size
  1525. max_past_horizon = self.config.attention_context_left - 1
  1526. max_future_horizon = self.config.attention_context_right
  1527. num_blocks = (seq_len + chunk_size - 1) // chunk_size
  1528. padded_seq_len = num_blocks * chunk_size
  1529. pad_amount = padded_seq_len - seq_len
  1530. mask_4d = F.pad(mask_4d, (0, pad_amount, 0, pad_amount), value=False)
  1531. mask_5d = mask_4d.reshape(batch_size, 1, num_blocks, chunk_size, padded_seq_len)
  1532. mask_5d = F.pad(mask_5d, (max_past_horizon, max_future_horizon), value=False)
  1533. block_starts = torch.arange(num_blocks, device=device) * chunk_size
  1534. offsets = torch.arange(chunk_size + max_past_horizon + max_future_horizon, device=device)
  1535. kv_indices = block_starts[:, None] + offsets[None, :]
  1536. kv_indices = kv_indices[None, None, :, None, :].expand(batch_size, 1, -1, chunk_size, -1)
  1537. return mask_5d.gather(-1, kv_indices)
  1538. @merge_with_config_defaults
  1539. @capture_outputs
  1540. @auto_docstring(custom_intro="Encodes audio features to soft tokens.")
  1541. def forward(
  1542. self,
  1543. input_features: torch.Tensor,
  1544. attention_mask: torch.Tensor | None = None,
  1545. **kwargs: Unpack[TransformersKwargs],
  1546. ) -> tuple[torch.Tensor, torch.BoolTensor]:
  1547. hidden_states, output_mask = self.subsample_conv_projection(input_features, attention_mask)
  1548. position_embeddings = self.rel_pos_enc(hidden_states)
  1549. attention_mask = create_bidirectional_mask(
  1550. config=self.config,
  1551. inputs_embeds=hidden_states,
  1552. attention_mask=output_mask,
  1553. and_mask_function=sliding_window_mask_function(
  1554. (self.config.attention_context_left - 1, self.config.attention_context_right)
  1555. ),
  1556. )
  1557. attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask)
  1558. for encoder_layer in self.layers[: self.config.num_hidden_layers]:
  1559. hidden_states = encoder_layer(
  1560. hidden_states,
  1561. attention_mask=attention_mask,
  1562. position_embeddings=position_embeddings,
  1563. **kwargs,
  1564. )
  1565. hidden_states = self.output_proj(hidden_states)
  1566. return Gemma4AudioModelOutput(last_hidden_state=hidden_states, attention_mask=output_mask)
  1567. class Gemma4VisionModel(Gemma4PreTrainedModel):
  1568. """The Gemma 4 Vision Encoder."""
  1569. config = Gemma4VisionConfig
  1570. _can_record_outputs = {
  1571. "hidden_states": Gemma4VisionEncoderLayer,
  1572. "attentions": Gemma4VisionAttention,
  1573. }
  1574. def __init__(self, config: Gemma4VisionConfig):
  1575. super().__init__(config)
  1576. self.patch_embedder = Gemma4VisionPatchEmbedder(config)
  1577. self.encoder = Gemma4VisionEncoder(config)
  1578. self.pooler = Gemma4VisionPooler(config)
  1579. if self.config.standardize:
  1580. self.register_buffer("std_bias", torch.empty(self.config.hidden_size))
  1581. self.register_buffer("std_scale", torch.empty(self.config.hidden_size))
  1582. self.post_init()
  1583. @merge_with_config_defaults
  1584. @capture_outputs
  1585. @auto_docstring(custom_intro="Encodes image pixels to soft tokens from patches.")
  1586. def forward(
  1587. self,
  1588. pixel_values: torch.FloatTensor,
  1589. pixel_position_ids: torch.LongTensor,
  1590. **kwargs: Unpack[TransformersKwargs],
  1591. ) -> BaseModelOutputWithPast:
  1592. r"""
  1593. pixel_values (`torch.FloatTensor` or `list[torch.FloatTensor]`):
  1594. The images to encode. Either a single `[batch, channels, height, width]` tensor
  1595. (all images same size) or a list of `[1, channels, height, width]` tensors (different sizes).
  1596. pixel_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`):
  1597. The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
  1598. """
  1599. pooling_kernel_size = self.config.pooling_kernel_size
  1600. output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size)
  1601. padding_positions = (pixel_position_ids == -1).all(dim=-1)
  1602. inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions)
  1603. output = self.encoder(
  1604. inputs_embeds=inputs_embeds,
  1605. attention_mask=~padding_positions, # encoder expects True=valid, padding_positions is True=padding
  1606. pixel_position_ids=pixel_position_ids,
  1607. **kwargs,
  1608. )
  1609. hidden_states, pooler_mask = self.pooler(
  1610. hidden_states=output.last_hidden_state,
  1611. pixel_position_ids=pixel_position_ids,
  1612. padding_positions=padding_positions,
  1613. output_length=output_length,
  1614. )
  1615. # Strip padding tokens. pooler_mask is True = valid, False = padding.
  1616. hidden_states = hidden_states[pooler_mask]
  1617. if self.config.standardize:
  1618. hidden_states = (hidden_states - self.std_bias) * self.std_scale
  1619. return BaseModelOutputWithPast(last_hidden_state=hidden_states)
  1620. class Gemma4MultimodalEmbedder(nn.Module):
  1621. """Embeds token ids or soft tokens for multimodal content into language model space."""
  1622. def __init__(
  1623. self,
  1624. multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
  1625. text_config: Gemma4TextConfig,
  1626. ):
  1627. super().__init__()
  1628. self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size)
  1629. self.eps = multimodal_config.rms_norm_eps
  1630. self.text_hidden_size = text_config.hidden_size
  1631. self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
  1632. self.embedding_pre_projection_norm = Gemma4RMSNorm(self.multimodal_hidden_size, eps=self.eps, with_scale=False)
  1633. def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
  1634. """Embeds token ids or soft tokens for multimodal content into language model space.
  1635. Args:
  1636. inputs_embeds: A torch.Tensor containing the soft tokens to embed.
  1637. Returns:
  1638. A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
  1639. """
  1640. embs_normed = self.embedding_pre_projection_norm(inputs_embeds)
  1641. return self.embedding_projection(embs_normed)
# Identical to Gemma3, but modular can't resolve it if we simply import it. FIXME: @cyril
  1643. def token_type_ids_mask_function(
  1644. token_type_ids: torch.Tensor | None,
  1645. image_group_ids: torch.Tensor | None,
  1646. ) -> Callable | None:
  1647. """
  1648. This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
  1649. not start and end indices.
  1650. """
  1651. # Do not return an additional mask in this case
  1652. if token_type_ids is None:
  1653. return None
  1654. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  1655. seq_length = image_group_ids.shape[-1]
  1656. # clamp indices because with static cache they can go beyond `image_group_ids.shape[-1]`
  1657. q_idx_clamped = q_idx.clamp(max=seq_length - 1)
  1658. kv_idx_clamped = kv_idx.clamp(max=seq_length - 1)
  1659. # Unmask if the q and kv come from same group which is not -1 (i.e. non-text)
  1660. q_group = image_group_ids[batch_idx, q_idx_clamped]
  1661. kv_group = image_group_ids[batch_idx, kv_idx_clamped]
  1662. q_group = torch.where(q_idx < seq_length, q_group, -1)
  1663. kv_group = torch.where(kv_idx < seq_length, kv_group, -1)
  1664. return (q_group == kv_group) & (q_group >= 0)
  1665. return inner_mask
  1666. # Similar to Gemma3 but `sliding_mask_kwargs` and `mask_kwargs` are different and `token_type_ids->mm_token_type_ids`
  1667. def create_causal_mask_mapping(
  1668. config: PreTrainedConfig,
  1669. inputs_embeds: torch.Tensor,
  1670. attention_mask: torch.Tensor | None,
  1671. past_key_values: Cache | None,
  1672. position_ids: torch.Tensor | None,
  1673. mm_token_type_ids: torch.Tensor | None = None,
  1674. pixel_values: torch.FloatTensor | None = None,
  1675. is_training: bool = False,
  1676. is_first_iteration: bool | None = None,
  1677. **kwargs,
  1678. ) -> dict:
  1679. """
  1680. Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
  1681. for all kinds of forward passes. Gemma4 uses a bidirectional mask for images.
  1682. Uses `pixel_values` as an optional input to disambiguate edge cases.
  1683. """
  1684. if is_training and mm_token_type_ids is None:
  1685. raise ValueError("`mm_token_type_ids` is required as a model input when training")
  1686. mask_kwargs = {
  1687. "config": config.get_text_config(),
  1688. "inputs_embeds": inputs_embeds,
  1689. "attention_mask": attention_mask,
  1690. "past_key_values": past_key_values,
  1691. "position_ids": position_ids,
  1692. }
  1693. sliding_mask_kwargs = mask_kwargs.copy()
  1694. # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
  1695. # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
  1696. # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
  1697. is_first_iteration = (
  1698. is_first_iteration
  1699. if is_first_iteration is not None
  1700. else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
  1701. )
  1702. if mm_token_type_ids is not None and is_first_iteration:
  1703. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
  1704. # undo the causal masking)
  1705. # First find where a new vision block starts. Vision tokens cannot attend to
  1706. # future vision tokens, but can attend to all prev tokens and to itself bidirectionally
  1707. is_vision = (mm_token_type_ids == 1) | (mm_token_type_ids == 2)
  1708. is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1)
  1709. is_prev_vision[..., 0] = False
  1710. new_vision_starts = is_vision & ~is_prev_vision
  1711. vision_group_ids = torch.cumsum(new_vision_starts.int(), dim=1) - 1
  1712. vision_group_ids = torch.where(is_vision, vision_group_ids, -1)
  1713. sliding_mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
  1714. mm_token_type_ids.to(inputs_embeds.device), vision_group_ids
  1715. )
  1716. return {
  1717. "full_attention": create_causal_mask(**mask_kwargs),
  1718. "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
  1719. }
  1720. @auto_docstring(
  1721. custom_intro="""
  1722. The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
  1723. language modeling head.
  1724. """
  1725. )
  1726. class Gemma4Model(Gemma4PreTrainedModel):
  1727. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  1728. accepts_loss_kwargs = False
  1729. def __init__(self, config: Gemma4Config):
  1730. super().__init__(config)
  1731. self.vocab_size = config.text_config.vocab_size
  1732. language_model = AutoModel.from_config(config=config.text_config)
  1733. self.language_model = language_model
  1734. self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
  1735. self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None
  1736. self.embed_vision = (
  1737. Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
  1738. if config.vision_config is not None
  1739. else None
  1740. )
  1741. self.audio_tower = AutoModel.from_config(config.audio_config) if config.audio_config is not None else None
  1742. self.embed_audio = (
  1743. Gemma4MultimodalEmbedder(config.audio_config, config.text_config)
  1744. if config.audio_config is not None
  1745. else None
  1746. )
  1747. # Grab the ones from the child
  1748. self._keys_to_ignore_on_load_unexpected = [
  1749. f"language_model.{name}" for name in self.language_model._keys_to_ignore_on_load_unexpected
  1750. ]
  1751. self.post_init()
  1752. def get_input_embeddings(self):
  1753. return self.language_model.get_input_embeddings()
  1754. def set_input_embeddings(self, value):
  1755. self.language_model.set_input_embeddings(value)
  1756. @can_return_tuple
  1757. @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
  1758. def get_image_features(
  1759. self,
  1760. pixel_values: torch.FloatTensor,
  1761. image_position_ids: torch.LongTensor | None = None,
  1762. **kwargs: Unpack[TransformersKwargs],
  1763. ) -> BaseModelOutputWithPooling:
  1764. r"""
  1765. image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
  1766. The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
  1767. """
  1768. vision_outputs = self.vision_tower(
  1769. pixel_values=pixel_values,
  1770. pixel_position_ids=image_position_ids,
  1771. **kwargs,
  1772. )
  1773. last_hidden_state = vision_outputs.last_hidden_state
  1774. vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
  1775. return vision_outputs
  1776. def get_placeholder_mask(
  1777. self,
  1778. input_ids: torch.LongTensor | None = None,
  1779. inputs_embeds: torch.FloatTensor | None = None,
  1780. ) -> tuple[torch.BoolTensor, torch.BoolTensor, torch.BoolTensor]:
  1781. """
  1782. Obtains mask for multimodal placeholders (replaced by soft tokens) and hard text tokens.
  1783. Masks will be obtained from `mm_token_type_ids`, `input_ids`, or `inputs_embeds` as available and in that
  1784. precedence order. If passing `input_ids` or `inputs_embeds`, the image mask will be derived using
  1785. `config.image_token_id`. Same goes for audio and video masks
  1786. Args:
  1787. input_ids: A tensor containing the hard token IDs from the text tokenizer.
  1788. inputs_embeds: A tensor containing the embeddings for all hard text tokens.
  1789. Returns:
  1790. image_mask, video_mask, audio_mask
  1791. """
  1792. if input_ids is not None:
  1793. special_image_mask = input_ids == self.config.image_token_id
  1794. special_video_mask = input_ids == self.config.video_token_id
  1795. special_audio_mask = input_ids == self.config.audio_token_id
  1796. else:
  1797. special_image_mask = (
  1798. inputs_embeds
  1799. == self.get_input_embeddings()(
  1800. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  1801. )
  1802. ).all(-1)
  1803. special_video_mask = (
  1804. inputs_embeds
  1805. == self.get_input_embeddings()(
  1806. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  1807. )
  1808. ).all(-1)
  1809. special_audio_mask = (
  1810. inputs_embeds
  1811. == self.get_input_embeddings()(
  1812. torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
  1813. )
  1814. ).all(-1)
  1815. return special_image_mask, special_video_mask, special_audio_mask
  1816. @merge_with_config_defaults
  1817. @can_return_tuple
  1818. @auto_docstring
  1819. def forward(
  1820. self,
  1821. input_ids: torch.LongTensor | None = None,
  1822. pixel_values: torch.FloatTensor | None = None,
  1823. pixel_values_videos: torch.FloatTensor | None = None,
  1824. input_features: torch.FloatTensor | None = None,
  1825. attention_mask: torch.Tensor | None = None,
  1826. input_features_mask: torch.Tensor | None = None,
  1827. position_ids: torch.LongTensor | None = None,
  1828. past_key_values: Cache | None = None,
  1829. mm_token_type_ids: torch.LongTensor | None = None,
  1830. inputs_embeds: torch.FloatTensor | None = None,
  1831. use_cache: bool | None = None,
  1832. image_position_ids: torch.LongTensor | None = None,
  1833. video_position_ids: torch.LongTensor | None = None,
  1834. **kwargs: Unpack[TransformersKwargs],
  1835. ) -> Gemma4ModelOutputWithPast:
  1836. r"""
  1837. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  1838. The attention mask for the input audio.
  1839. image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
  1840. 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
  1841. Passed through to the vision encoder for positional embedding computation.
  1842. video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
  1843. 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
  1844. Passed through to the vision encoder for positional embedding computation.
  1845. """
  1846. if (input_ids is None) ^ (inputs_embeds is not None):
  1847. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1848. image_mask, video_mask, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds)
  1849. multimodal_mask = image_mask | video_mask | audio_mask
  1850. # Replace image id with PAD if the image token if OOV, to avoid index-errors
  1851. llm_input_ids = None
  1852. if inputs_embeds is None:
  1853. llm_input_ids = input_ids.clone()
  1854. llm_input_ids[multimodal_mask] = self.config.text_config.pad_token_id
  1855. inputs_embeds = self.get_input_embeddings()(llm_input_ids)
  1856. if self.config.get_text_config().hidden_size_per_layer_input:
  1857. pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
  1858. llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds)
  1859. per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds)
  1860. else:
  1861. per_layer_inputs = None
  1862. # Merge text and images
  1863. if pixel_values is not None:
  1864. image_features = self.get_image_features(pixel_values, image_position_ids, return_dict=True).pooler_output
  1865. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1866. # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
  1867. n_image_tokens = image_mask.sum()
  1868. image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1869. torch_compilable_check(
  1870. inputs_embeds[image_mask].numel() == image_features.numel(),
  1871. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features:"
  1872. f" {image_features.shape[0]}",
  1873. )
  1874. inputs_embeds = inputs_embeds.masked_scatter(
  1875. image_mask.to(inputs_embeds.device), image_features.to(inputs_embeds.device)
  1876. )
  1877. if pixel_values_videos is not None:
  1878. video_features = self.get_video_features(
  1879. pixel_values_videos, video_position_ids, return_dict=True
  1880. ).pooler_output
  1881. video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1882. # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
  1883. n_video_tokens = video_mask.sum()
  1884. video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1885. torch_compilable_check(
  1886. inputs_embeds[video_mask].numel() == video_features.numel(),
  1887. f"Video features and video tokens do not match, tokens: {n_video_tokens}, features:"
  1888. f" {video_features.shape[0]}",
  1889. )
  1890. inputs_embeds = inputs_embeds.masked_scatter(
  1891. video_mask.to(inputs_embeds.device), video_features.to(inputs_embeds.device)
  1892. )
  1893. # Merge text and audio
  1894. if input_features is not None and input_features_mask is not None:
  1895. audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
  1896. audio_features = audio_output.pooler_output
  1897. audio_mask_from_encoder = audio_output.attention_mask # True = valid
  1898. # Strip padding tokens: only keep real (non-padding) audio soft tokens.
  1899. # audio_mask_from_encoder is True for valid positions, False for padding tokens.
  1900. # This mirrors the vision encoder's padding stripping (see Gemma4VisionEncoder.forward).
  1901. audio_features = audio_features[audio_mask_from_encoder]
  1902. n_audio_tokens = audio_mask.sum()
  1903. audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1904. torch_compilable_check(
  1905. inputs_embeds[audio_mask].numel() == audio_features.numel(),
  1906. f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features:"
  1907. f" {audio_features.shape[0] * audio_features.shape[1]}",
  1908. )
  1909. inputs_embeds = inputs_embeds.masked_scatter(
  1910. audio_mask.to(inputs_embeds.device), audio_features.to(inputs_embeds.device)
  1911. )
  1912. # It may already have been prepared by, e.g., `generate`
  1913. if position_ids is None:
  1914. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1915. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  1916. position_ids = position_ids.unsqueeze(0)
  1917. if not isinstance(causal_mask_mapping := attention_mask, dict):
  1918. if self.config.get_text_config().use_bidirectional_attention == "vision":
  1919. # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
  1920. causal_mask_mapping = create_causal_mask_mapping(
  1921. self.config,
  1922. inputs_embeds,
  1923. attention_mask,
  1924. past_key_values,
  1925. position_ids,
  1926. mm_token_type_ids,
  1927. pixel_values,
  1928. is_training=self.training,
  1929. )
  1930. else:
  1931. # Smaller Gemma models use a conventional casual attention mask
  1932. causal_mask_mapping = create_masks_for_generate(
  1933. self.config,
  1934. inputs_embeds,
  1935. attention_mask,
  1936. past_key_values,
  1937. position_ids,
  1938. )
  1939. outputs = self.language_model(
  1940. per_layer_inputs=per_layer_inputs,
  1941. attention_mask=causal_mask_mapping,
  1942. position_ids=position_ids,
  1943. past_key_values=past_key_values,
  1944. inputs_embeds=inputs_embeds,
  1945. use_cache=use_cache,
  1946. return_dict=True,
  1947. **kwargs,
  1948. )
  1949. return Gemma4ModelOutputWithPast(
  1950. last_hidden_state=outputs.last_hidden_state,
  1951. past_key_values=outputs.past_key_values,
  1952. hidden_states=outputs.hidden_states,
  1953. attentions=outputs.attentions,
  1954. image_hidden_states=image_features if pixel_values is not None else None,
  1955. audio_hidden_states=audio_features if input_features is not None else None,
  1956. )
  1957. @can_return_tuple
  1958. @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
  1959. def get_audio_features(
  1960. self,
  1961. input_features: torch.Tensor,
  1962. input_features_mask: torch.Tensor,
  1963. **kwargs: Unpack[TransformersKwargs],
  1964. ) -> tuple | Gemma4AudioModelOutput:
  1965. r"""
  1966. input_features (`torch.FloatTensor]` of shape `(num_images, seq_length, num_features)`):
  1967. The tensors corresponding to the input audio.
  1968. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  1969. The attention mask for the input audio.
  1970. """
  1971. if self.audio_tower is None:
  1972. raise ValueError(
  1973. "Audio features were requested, but the model was initialized without an audio_config. "
  1974. "Cannot process audio without an audio tower and audio embedder."
  1975. )
  1976. audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs)
  1977. audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
  1978. return audio_outputs
  1979. @can_return_tuple
  1980. @auto_docstring(custom_intro="Projects the last hidden state from the vision encoder into language model space.")
  1981. def get_video_features(
  1982. self,
  1983. pixel_values_videos: torch.FloatTensor,
  1984. video_position_ids: torch.LongTensor | None = None,
  1985. **kwargs: Unpack[TransformersKwargs],
  1986. ) -> BaseModelOutputWithPooling:
  1987. r"""
  1988. video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
  1989. 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
  1990. Passed through to the vision encoder for positional embedding computation.
  1991. """
  1992. pixel_values_videos = pixel_values_videos.flatten(0, 1)
  1993. video_position_ids = video_position_ids.flatten(0, 1)
  1994. vision_outputs = self.vision_tower(
  1995. pixel_values=pixel_values_videos,
  1996. pixel_position_ids=video_position_ids,
  1997. **kwargs,
  1998. )
  1999. last_hidden_state = vision_outputs.last_hidden_state
  2000. vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
  2001. return vision_outputs
  2002. @auto_docstring(
  2003. custom_intro="""
  2004. The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
  2005. head.
  2006. """
  2007. )
  2008. class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin):
  2009. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  2010. base_model_prefix = "model"
  2011. def __init__(self, config: Gemma4Config):
  2012. super().__init__(config)
  2013. self.model = Gemma4Model(config)
  2014. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  2015. # Grab the ones from the child
  2016. self._keys_to_ignore_on_load_unexpected = [
  2017. f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
  2018. ]
  2019. self.post_init()
  2020. def get_input_embeddings(self):
  2021. return self.model.get_input_embeddings()
  2022. def set_input_embeddings(self, value):
  2023. self.model.set_input_embeddings(value)
  2024. @auto_docstring
  2025. def get_image_features(
  2026. self,
  2027. pixel_values: torch.FloatTensor,
  2028. image_position_ids: torch.LongTensor | None = None,
  2029. **kwargs: Unpack[TransformersKwargs],
  2030. ):
  2031. r"""
  2032. image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
  2033. 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
  2034. Passed through to the vision encoder for positional embedding computation.
  2035. """
  2036. return self.model.get_image_features(pixel_values, image_position_ids, **kwargs)
  2037. @can_return_tuple
  2038. @auto_docstring
  2039. def forward(
  2040. self,
  2041. input_ids: torch.LongTensor | None = None,
  2042. pixel_values: torch.FloatTensor | None = None,
  2043. pixel_values_videos: torch.FloatTensor | None = None,
  2044. input_features: torch.FloatTensor | None = None,
  2045. attention_mask: torch.Tensor | None = None,
  2046. input_features_mask: torch.Tensor | None = None,
  2047. position_ids: torch.LongTensor | None = None,
  2048. image_position_ids: torch.LongTensor | None = None,
  2049. video_position_ids: torch.LongTensor | None = None,
  2050. past_key_values: Cache | None = None,
  2051. mm_token_type_ids: torch.LongTensor | None = None,
  2052. inputs_embeds: torch.FloatTensor | None = None,
  2053. labels: torch.LongTensor | None = None,
  2054. use_cache: bool | None = None,
  2055. logits_to_keep: int | torch.Tensor = 0,
  2056. **kwargs: Unpack[TransformersKwargs],
  2057. ) -> Gemma4CausalLMOutputWithPast:
  2058. r"""
  2059. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  2060. The attention mask for the input audio.
  2061. image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
  2062. 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
  2063. Passed through to the vision encoder for positional embedding computation.
  2064. video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
  2065. 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
  2066. Passed through to the vision encoder for positional embedding computation.
  2067. """
  2068. outputs = self.model(
  2069. input_ids=input_ids,
  2070. pixel_values=pixel_values,
  2071. pixel_values_videos=pixel_values_videos,
  2072. input_features=input_features,
  2073. attention_mask=attention_mask,
  2074. input_features_mask=input_features_mask,
  2075. position_ids=position_ids,
  2076. past_key_values=past_key_values,
  2077. mm_token_type_ids=mm_token_type_ids,
  2078. inputs_embeds=inputs_embeds,
  2079. labels=labels,
  2080. use_cache=use_cache,
  2081. image_position_ids=image_position_ids,
  2082. video_position_ids=video_position_ids,
  2083. return_dict=True,
  2084. **kwargs,
  2085. )
  2086. hidden_states = outputs.last_hidden_state
  2087. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  2088. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  2089. logits = self.lm_head(hidden_states[:, slice_indices, :])
  2090. if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
  2091. logits = logits / final_logit_softcapping
  2092. logits = torch.tanh(logits)
  2093. logits = logits * final_logit_softcapping
  2094. loss = None
  2095. if labels is not None:
  2096. # Upcast to float if we need to compute the loss to avoid potential precision issues
  2097. logits = logits.float()
  2098. shift_logits = logits[..., :-1, :]
  2099. shift_labels = labels[..., 1:]
  2100. if attention_mask is not None:
  2101. # we use the input attention mask to shift the logits and labels, because it is 2D.
  2102. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  2103. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  2104. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  2105. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  2106. else:
  2107. shift_logits = shift_logits.contiguous()
  2108. shift_labels = shift_labels.contiguous()
  2109. # Flatten the tokens
  2110. loss_fct = nn.CrossEntropyLoss()
  2111. flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size)
  2112. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  2113. loss = loss_fct(flat_logits, flat_labels)
  2114. return Gemma4CausalLMOutputWithPast(
  2115. loss=loss,
  2116. logits=logits,
  2117. past_key_values=outputs.past_key_values,
  2118. hidden_states=outputs.hidden_states,
  2119. attentions=outputs.attentions,
  2120. image_hidden_states=outputs.image_hidden_states,
  2121. audio_hidden_states=outputs.audio_hidden_states,
  2122. )
  2123. def prepare_inputs_for_generation(
  2124. self,
  2125. input_ids,
  2126. past_key_values=None,
  2127. inputs_embeds=None,
  2128. position_ids=None,
  2129. pixel_values=None,
  2130. pixel_values_videos=None,
  2131. input_features=None,
  2132. attention_mask=None,
  2133. input_features_mask=None,
  2134. token_type_ids=None,
  2135. use_cache=True,
  2136. logits_to_keep=None,
  2137. labels=None,
  2138. is_first_iteration=False,
  2139. **kwargs,
  2140. ):
  2141. # Overwritten -- custom `position_ids` and `pixel_values` handling
  2142. model_inputs = super().prepare_inputs_for_generation(
  2143. input_ids,
  2144. past_key_values=past_key_values,
  2145. inputs_embeds=inputs_embeds,
  2146. attention_mask=attention_mask,
  2147. position_ids=position_ids,
  2148. use_cache=use_cache,
  2149. logits_to_keep=logits_to_keep,
  2150. token_type_ids=token_type_ids,
  2151. is_first_iteration=is_first_iteration,
  2152. **kwargs,
  2153. )
  2154. # If we're in cached decoding stage, multimodal inputs are already cached and can be dropped
  2155. if is_first_iteration or not use_cache:
  2156. model_inputs["pixel_values"] = pixel_values
  2157. model_inputs["pixel_values_videos"] = pixel_values_videos
  2158. model_inputs["input_features"] = input_features
  2159. model_inputs["input_features_mask"] = input_features_mask
  2160. return model_inputs
  2161. @staticmethod
  2162. def create_masks_for_generate(
  2163. config: PreTrainedConfig,
  2164. inputs_embeds: torch.Tensor,
  2165. attention_mask: torch.Tensor | None,
  2166. past_key_values: Cache | None,
  2167. position_ids: torch.Tensor | None,
  2168. mm_token_type_ids: torch.Tensor | None = None,
  2169. is_first_iteration: bool | None = False,
  2170. **kwargs,
  2171. ) -> dict:
  2172. if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision":
  2173. # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
  2174. return create_causal_mask_mapping(
  2175. config,
  2176. inputs_embeds,
  2177. attention_mask,
  2178. past_key_values,
  2179. position_ids,
  2180. mm_token_type_ids,
  2181. is_first_iteration=is_first_iteration,
  2182. **{k: v for k, v in kwargs.items() if k != "pixel_values"},
  2183. )
  2184. else:
  2185. # Smaller Gemma models use a conventional casual attention mask
  2186. return create_masks_for_generate(
  2187. config, inputs_embeds, attention_mask, past_key_values, position_ids, **kwargs
  2188. )
# Names exported by `from <this module> import *`.
__all__ = [
    "Gemma4AudioModel",
    "Gemma4ForCausalLM",
    "Gemma4ForConditionalGeneration",
    "Gemma4Model",
    "Gemma4PreTrainedModel",
    "Gemma4TextModel",
    "Gemma4VisionModel",
]