modular_gemma4.py 97 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196
  1. # Copyright 2026 the HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. from functools import cached_property
  18. import torch
  19. from torch import nn
  20. from torch.nn import functional as F
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache
  24. from ...configuration_utils import PreTrainedConfig
  25. from ...integrations import use_kernelized_func
  26. from ...masking_utils import (
  27. create_bidirectional_mask,
  28. create_causal_mask,
  29. create_masks_for_generate,
  30. create_sliding_window_causal_mask,
  31. )
  32. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  33. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
  34. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  35. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  36. from ...processing_utils import Unpack
  37. from ...utils import (
  38. TransformersKwargs,
  39. auto_docstring,
  40. can_return_tuple,
  41. logging,
  42. torch_compilable_check,
  43. )
  44. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  45. from ...utils.output_capturing import OutputRecorder, capture_outputs
  46. from ..auto.modeling_auto import AutoModel
  47. from ..gemma3.modeling_gemma3 import (
  48. Gemma3Attention,
  49. Gemma3DecoderLayer,
  50. Gemma3ForCausalLM,
  51. Gemma3MLP,
  52. Gemma3RotaryEmbedding,
  53. Gemma3TextModel,
  54. Gemma3TextScaledWordEmbedding,
  55. )
  56. from ..gemma3n.modeling_gemma3n import (
  57. Gemma3nCausalLMOutputWithPast,
  58. Gemma3nForConditionalGeneration,
  59. Gemma3nModel,
  60. Gemma3nModelOutputWithPast,
  61. Gemma3nMultimodalEmbedder,
  62. Gemma3nRMSNorm,
  63. apply_rotary_pos_emb,
  64. eager_attention_forward,
  65. )
  66. from ..llama.modeling_llama import LlamaRotaryEmbedding
  67. from ..mixtral.modeling_mixtral import MixtralExperts
  68. from ..moonshine_streaming.modeling_moonshine_streaming import sliding_window_mask_function
  69. from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
# Module-level logger obtained from the library's `logging` utilities, named after this module.
logger = logging.get_logger(__name__)
  71. class Gemma4ModelOutputWithPast(Gemma3nModelOutputWithPast):
  72. pass
  73. class Gemma4CausalLMOutputWithPast(Gemma3nCausalLMOutputWithPast):
  74. pass
  75. @dataclass
  76. @auto_docstring
  77. class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
  78. r"""
  79. attention_mask (`torch.BoolTensor`, *optional*):
  80. A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
  81. """
  82. attention_mask: torch.BoolTensor | None = None
  83. class Gemma4ClippableLinear(nn.Module):
  84. def __init__(
  85. self,
  86. config: Gemma4VisionConfig | Gemma4AudioConfig,
  87. in_features: int,
  88. out_features: int,
  89. ) -> None:
  90. super().__init__()
  91. self.use_clipped_linears = config.use_clipped_linears
  92. self.linear = nn.Linear(in_features, out_features, bias=False)
  93. if self.use_clipped_linears:
  94. self.register_buffer("input_min", torch.tensor(-float("inf")))
  95. self.register_buffer("input_max", torch.tensor(float("inf")))
  96. self.register_buffer("output_min", torch.tensor(-float("inf")))
  97. self.register_buffer("output_max", torch.tensor(float("inf")))
  98. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  99. if self.use_clipped_linears:
  100. hidden_states = torch.clamp(hidden_states, self.input_min, self.input_max)
  101. hidden_states = self.linear(hidden_states)
  102. if self.use_clipped_linears:
  103. hidden_states = torch.clamp(hidden_states, self.output_min, self.output_max)
  104. return hidden_states
  105. class Gemma4RMSNorm(Gemma3nRMSNorm):
  106. pass
  107. class Gemma4AudioRelPositionalEncoding(nn.Module):
  108. """Sinusoidal relative positional encoding for the audio encoder.
  109. Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with
  110. concatenated [sin..., cos...] layout matching the original Gemma4 convention.
  111. """
  112. inv_timescales: torch.Tensor
  113. def __init__(self, config: Gemma4AudioConfig):
  114. super().__init__()
  115. self.hidden_size = config.hidden_size
  116. self.context_size = (
  117. config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
  118. )
  119. min_timescale = 1.0
  120. max_timescale = 10000.0
  121. num_timescales = self.hidden_size // 2
  122. log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
  123. inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
  124. self.register_buffer("inv_timescales", inv_timescales.unsqueeze(0).unsqueeze(0), persistent=False)
  125. @torch.no_grad()
  126. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  127. position_ids = torch.arange(12, -1, -1, device=hidden_states.device)
  128. position_ids = position_ids[..., None]
  129. scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device)
  130. pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
  131. return pos_embed.to(dtype=hidden_states.dtype)
  132. class Gemma4AudioAttention(nn.Module):
  133. """Chunked local attention with relative position bias"""
  134. def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
  135. super().__init__()
  136. self.config = config
  137. self.layer_idx = layer_idx
  138. self.attention_logits_soft_cap = config.attention_logit_cap
  139. self.head_dim = config.hidden_size // config.num_attention_heads
  140. self.num_heads = config.num_attention_heads
  141. self.q_scale = (self.head_dim**-0.5) / math.log(2)
  142. self.k_scale = math.log(1 + math.e) / math.log(2)
  143. self.chunk_size = config.attention_chunk_size
  144. self.max_past_horizon = config.attention_context_left - 1
  145. self.max_future_horizon = config.attention_context_right
  146. self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
  147. self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
  148. self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
  149. self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
  150. self.post = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
  151. self.relative_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
  152. self.per_dim_scale = nn.Parameter(torch.zeros(self.head_dim))
  153. self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap), persistent=False)
  154. def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
  155. """Splits a `(batch_size, seq_len, num_heads, head_dim)` tensor into non-overlapping blocks of `chunk_size` along the sequence dim."""
  156. batch_size, seq_len, num_heads, head_dim = hidden_states.shape
  157. num_blocks = (seq_len + self.chunk_size - 1) // self.chunk_size
  158. pad = num_blocks * self.chunk_size - seq_len
  159. hidden_states = F.pad(hidden_states, (0, 0, 0, 0, 0, pad))
  160. return hidden_states.reshape(batch_size, num_blocks, self.chunk_size, num_heads, head_dim).contiguous()
  161. def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
  162. """Extracts overlapping context windows of `context_size` for every block, strided by `chunk_size`."""
  163. batch_size, seq_len, num_heads, head_dim = hidden_states.shape
  164. hidden_states = F.pad(
  165. hidden_states, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)
  166. )
  167. hidden_states = hidden_states.unfold(1, self.context_size, self.chunk_size)
  168. hidden_states = torch.movedim(hidden_states, -1, 2)
  169. return hidden_states.contiguous()
  170. def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
  171. """Relative position shift for blocked attention. See appendix B of https://huggingface.co/papers/1901.02860."""
  172. batch_size, num_heads, num_blocks, block_size, position_length = x.shape
  173. context_size = self.context_size
  174. x = F.pad(x, (0, context_size + 1 - position_length))
  175. x = x.view(batch_size, num_heads, num_blocks, block_size * (context_size + 1))
  176. x = x[..., : block_size * context_size]
  177. return x.view(batch_size, num_heads, num_blocks, block_size, context_size)
  178. def forward(
  179. self,
  180. hidden_states: torch.Tensor,
  181. position_embeddings: torch.Tensor,
  182. attention_mask: torch.BoolTensor | None = None,
  183. ) -> tuple[torch.Tensor, None]:
  184. batch_size, seq_length, _ = hidden_states.shape
  185. hidden_shape = (batch_size, seq_length, self.num_heads, self.head_dim)
  186. query_states = self.q_proj(hidden_states).float().view(hidden_shape)
  187. key_states = self.k_proj(hidden_states).float().view(hidden_shape)
  188. value_states = self.v_proj(hidden_states).float().view(hidden_shape)
  189. query_states = query_states * self.q_scale * F.softplus(self.per_dim_scale)
  190. key_states = key_states * self.k_scale
  191. query_states = self._convert_to_block(query_states)
  192. key_states = self._extract_block_context(key_states)
  193. value_states = self._extract_block_context(value_states)
  194. num_blocks = query_states.shape[1]
  195. relative_key_states = self.relative_k_proj(position_embeddings)
  196. relative_key_states = relative_key_states.view(-1, self.num_heads, self.head_dim)
  197. relative_key_states = relative_key_states.to(dtype=query_states.dtype)
  198. queries = query_states.permute(0, 3, 1, 2, 4)
  199. matrix_ac = queries @ key_states.permute(0, 3, 1, 4, 2)
  200. queries_flat = queries.reshape(batch_size, self.num_heads, -1, self.head_dim)
  201. matrix_bd = queries_flat @ relative_key_states.permute(1, 2, 0)
  202. matrix_bd = matrix_bd.reshape(batch_size, self.num_heads, num_blocks, self.chunk_size, -1)
  203. matrix_bd = self._rel_shift(matrix_bd)
  204. attn_weights = matrix_ac + matrix_bd
  205. attn_weights = attn_weights / self.softcap
  206. attn_weights = torch.tanh(attn_weights)
  207. attn_weights = attn_weights * self.softcap
  208. if attention_mask is not None:
  209. attn_weights = attn_weights.masked_fill(
  210. attention_mask.logical_not(), self.config.attention_invalid_logits_value
  211. )
  212. attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
  213. attn_output = attn_weights @ value_states.permute(0, 3, 1, 2, 4)
  214. attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, num_blocks * self.chunk_size, -1)
  215. attn_output = attn_output[:, :seq_length].contiguous()
  216. attn_output = self.post(attn_output.to(dtype=self.post.linear.weight.dtype))
  217. return attn_output, attn_weights
  218. class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
  219. def __init__(self, in_channels, out_channels, norm_eps):
  220. super().__init__()
  221. self.conv = nn.Conv2d(
  222. in_channels=in_channels,
  223. out_channels=out_channels,
  224. kernel_size=(3, 3),
  225. stride=(2, 2),
  226. padding=1,
  227. bias=False,
  228. )
  229. self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
  230. self.act = nn.ReLU()
  231. def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None):
  232. if mask is not None:
  233. mask = mask.to(device=hidden_states.device)
  234. hidden_states = hidden_states * mask[:, None, :, None]
  235. hidden_states = self.conv(hidden_states.to(self.conv.weight.dtype))
  236. hidden_states = self.act(self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous())
  237. if mask is not None:
  238. mask = mask[:, ::2]
  239. return hidden_states, mask
  240. class Gemma4AudioSubSampleConvProjection(nn.Module):
  241. def __init__(self, config: Gemma4AudioConfig):
  242. super().__init__()
  243. self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
  244. in_channels=1,
  245. out_channels=config.subsampling_conv_channels[0],
  246. norm_eps=config.rms_norm_eps,
  247. )
  248. self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
  249. in_channels=config.subsampling_conv_channels[0],
  250. out_channels=config.subsampling_conv_channels[1],
  251. norm_eps=config.rms_norm_eps,
  252. )
  253. proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1]
  254. self.input_proj_linear = nn.Linear(proj_input_dim, config.hidden_size, bias=False)
  255. def forward(
  256. self,
  257. input_features: torch.Tensor,
  258. input_features_mask: torch.Tensor | None = None,
  259. ) -> tuple[torch.Tensor, torch.Tensor]:
  260. hidden_states = input_features.unsqueeze(1)
  261. hidden_states, mask = self.layer0(hidden_states, input_features_mask)
  262. hidden_states, mask = self.layer1(hidden_states, mask)
  263. batch_size, _, seq_len, _ = hidden_states.shape
  264. hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1)
  265. return self.input_proj_linear(hidden_states), mask
  266. class Gemma4AudioFeedForward(nn.Module):
  267. def __init__(self, config: Gemma4AudioConfig):
  268. super().__init__()
  269. self.config = config
  270. self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4)
  271. self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size)
  272. self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size)
  273. self.post_layer_norm = Gemma4RMSNorm(config.hidden_size)
  274. self.act_fn = ACT2FN[config.hidden_act]
  275. self.gradient_clipping = config.gradient_clipping
  276. self.post_layer_scale = config.residual_weight
  277. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  278. # This is needed to avoid any underflow/overflow issues when clipping
  279. gradient_clipping = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max)
  280. residual = hidden_states
  281. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  282. hidden_states = self.pre_layer_norm(hidden_states)
  283. hidden_states = self.ffw_layer_1(hidden_states)
  284. hidden_states = self.act_fn(hidden_states)
  285. hidden_states = self.ffw_layer_2(hidden_states)
  286. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  287. hidden_states = self.post_layer_norm(hidden_states)
  288. hidden_states *= self.post_layer_scale
  289. hidden_states += residual
  290. return hidden_states
  291. # TODO: this could be imported from Voxtral realtime
  292. class Gemma4AudioCausalConv1d(nn.Conv1d):
  293. # def __init__(
  294. # self,
  295. # in_channels: int,
  296. # out_channels: int,
  297. # kernel_size: int,
  298. # # cache_key: str,
  299. # stride: int = 1,
  300. # dilation: int = 1,
  301. # bias: bool = True,
  302. # ):
  303. # super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias)
  304. # self.cache_key = cache_key
  305. @cached_property
  306. def left_pad(self):
  307. effective_kernel_size = (self.kernel_size[0] - 1) * self.dilation[0] + 1
  308. return effective_kernel_size - self.stride[0]
  309. def forward(
  310. self,
  311. x: torch.Tensor,
  312. # padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, # TODO: we might want to add a cache?
  313. ) -> torch.Tensor:
  314. # if padding_cache is not None:
  315. # x = padding_cache.update(x, self.cache_key, self)
  316. # else:
  317. # x = nn.functional.pad(x, (self.left_pad, 0))
  318. x = nn.functional.pad(x, (self.left_pad, 0))
  319. return super().forward(x)
  320. class Gemma4AudioLightConv1d(nn.Module):
  321. def __init__(self, config: Gemma4AudioConfig):
  322. super().__init__()
  323. self.config = config
  324. self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2)
  325. self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
  326. self.depthwise_conv1d = Gemma4AudioCausalConv1d(
  327. in_channels=config.hidden_size,
  328. out_channels=config.hidden_size,
  329. kernel_size=config.conv_kernel_size,
  330. groups=config.hidden_size,
  331. bias=False,
  332. )
  333. self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True)
  334. self.conv_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True)
  335. self.act_fn = ACT2FN[config.hidden_act]
  336. self.gradient_clipping = config.gradient_clipping
  337. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  338. residual = hidden_states
  339. hidden_states = self.pre_layer_norm(hidden_states)
  340. hidden_states = self.linear_start(hidden_states)
  341. hidden_states = nn.functional.glu(hidden_states, dim=-1)
  342. hidden_states = self.depthwise_conv1d(hidden_states.transpose(1, 2)).transpose(1, 2)
  343. # This is needed to avoid any underflow/overflow issues when clipping
  344. gradient_clipping = min(self.gradient_clipping, torch.finfo(self.linear_start.linear.weight.dtype).max)
  345. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  346. hidden_states = self.conv_norm(hidden_states)
  347. hidden_states = self.act_fn(hidden_states)
  348. hidden_states = self.linear_end(hidden_states)
  349. hidden_states += residual
  350. return hidden_states
  351. class Gemma4AudioLayer(nn.Module):
  352. def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
  353. super().__init__()
  354. self.config = config
  355. self.feed_forward1 = Gemma4AudioFeedForward(config)
  356. self.feed_forward2 = Gemma4AudioFeedForward(config)
  357. self.self_attn = Gemma4AudioAttention(config, layer_idx)
  358. self.lconv1d = Gemma4AudioLightConv1d(config)
  359. self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
  360. self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
  361. self.norm_out = Gemma4RMSNorm(config.hidden_size)
  362. self.gradient_clipping = config.gradient_clipping
  363. def forward(
  364. self,
  365. hidden_states: torch.Tensor,
  366. attention_mask: torch.BoolTensor | None,
  367. position_embeddings: torch.Tensor,
  368. **kwargs: Unpack[TransformersKwargs],
  369. ) -> torch.Tensor:
  370. # This is needed to avoid any underflow/overflow issues when clipping
  371. gradient_clipping = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max)
  372. hidden_states = self.feed_forward1(hidden_states)
  373. residual = hidden_states
  374. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  375. hidden_states = self.norm_pre_attn(hidden_states)
  376. hidden_states, _ = self.self_attn(
  377. hidden_states=hidden_states,
  378. position_embeddings=position_embeddings,
  379. attention_mask=attention_mask,
  380. )
  381. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  382. hidden_states = self.norm_post_attn(hidden_states)
  383. hidden_states += residual
  384. hidden_states = self.lconv1d(hidden_states)
  385. hidden_states = self.feed_forward2(hidden_states)
  386. hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping)
  387. hidden_states = self.norm_out(hidden_states)
  388. return hidden_states
  389. # ---- Vision Encoder Layers ----
  390. class Gemma4VisionPatchEmbedder(nn.Module):
  391. def __init__(self, config: Gemma4VisionConfig):
  392. super().__init__()
  393. self.config = config
  394. self.hidden_size = config.hidden_size
  395. self.patch_size = config.patch_size
  396. self.position_embedding_size = config.position_embedding_size
  397. self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
  398. self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
  399. def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
  400. """Prepare patch positions map for matmul with positon embedding table."""
  401. # Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
  402. clamped_positions = pixel_position_ids.clamp(min=0)
  403. one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
  404. one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
  405. # Compute positional embeddings and sum across x and y.
  406. position_embeddings = one_hot @ self.position_embedding_table
  407. position_embeddings = position_embeddings.sum(dim=1)
  408. # Zero out embeddings for any padding patches.
  409. position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
  410. return position_embeddings
  411. def forward(
  412. self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor
  413. ) -> torch.Tensor:
  414. # Gemma4 applies no normalization and instead scales in model code
  415. pixel_values = 2 * (pixel_values - 0.5)
  416. hidden_states = self.input_proj(pixel_values.to(self.input_proj.weight.dtype))
  417. position_embeddings = self._position_embeddings(pixel_position_ids, padding_positions)
  418. return hidden_states + position_embeddings
  419. class Gemma4VisionPooler(nn.Module):
  420. """Scaling and optional spatial pooling for vision encodings"""
  421. def __init__(self, config: Gemma4VisionConfig):
  422. super().__init__()
  423. self.hidden_size = config.hidden_size
  424. self.root_hidden_size = self.hidden_size**0.5
  425. def _avg_pool_by_positions(
  426. self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
  427. ) -> tuple[torch.Tensor, torch.Tensor]:
  428. """
  429. 2D spatial pooling according to patch positions.
  430. Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
  431. input and output lengths
  432. """
  433. input_seq_len = hidden_states.shape[1]
  434. k = int((input_seq_len // length) ** 0.5)
  435. k_squared = k**2
  436. if k_squared * length != input_seq_len:
  437. raise ValueError(
  438. f"Cannot pool {hidden_states.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
  439. )
  440. # Clamp padding positions (which are -1) to 0 so they don't break one_hot.
  441. # Padding patches have zero hidden states so they contribute nothing to the average.
  442. clamped_positions = pixel_position_ids.clamp(min=0)
  443. max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
  444. kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
  445. kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
  446. weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
  447. output = weights.transpose(1, 2) @ hidden_states.float()
  448. mask = torch.logical_not((weights == 0).all(dim=1))
  449. return output.to(hidden_states.dtype), mask
  450. def forward(
  451. self,
  452. hidden_states: torch.Tensor,
  453. pixel_position_ids: torch.Tensor,
  454. padding_positions: torch.Tensor,
  455. output_length: int | None = None,
  456. ) -> tuple[torch.Tensor, torch.Tensor]:
  457. if output_length > hidden_states.shape[1]:
  458. raise ValueError(
  459. f"Cannot output more soft tokens (requested {output_length}) than there are patches"
  460. f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing."
  461. )
  462. hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0)
  463. if hidden_states.shape[1] != output_length:
  464. hidden_states, padding_positions = self._avg_pool_by_positions(
  465. hidden_states, pixel_position_ids, output_length
  466. )
  467. hidden_states *= self.root_hidden_size
  468. return hidden_states, padding_positions
  469. class Gemma4VisionMLP(Gemma3MLP):
  470. def __init__(self, config: Gemma4VisionConfig):
  471. super().__init__(self, config)
  472. self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
  473. self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
  474. self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
  475. def apply_multidimensional_rope(
  476. x: torch.Tensor,
  477. cos: torch.Tensor,
  478. sin: torch.Tensor,
  479. position_ids: torch.Tensor,
  480. unsqueeze_dim: int = 2,
  481. ) -> torch.Tensor:
  482. """Applies multidimensional RoPE to inputs.
  483. Args:
  484. x (`torch.Tensor`): The tensor to embed.
  485. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  486. sin (`torch.Tensor`): The sine part of the rotary embedding.
  487. position_ids (`torch.Tensor`, *optional*):
  488. If position_ids.ndim + 2 == x.ndim, then this function passes through to `apply_rotary_pos_emb()`.
  489. Otherwise, position_ids is used to split the inputs, x, into multiple pieces, where each piece is fed to
  490. `apply_rotary_pos_emb()`, and then concatenated back together.
  491. unsqueeze_dim (`int`, *optional*, defaults to 1):
  492. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  493. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  494. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  495. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  496. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  497. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  498. Returns:
  499. Tensor of shape [B, L, N, H] with RoPE applied.
  500. """
  501. ndim = position_ids.shape[-1]
  502. num_input_channels = x.shape[-1]
  503. num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim))
  504. if num_rotated_channels_per_dim <= 0:
  505. raise ValueError(
  506. "Invalid configuration: num_rotated_channels_per_dim must be > 0, got"
  507. f" {num_rotated_channels_per_dim} (num_input_channels={num_input_channels},"
  508. f" ndim={ndim})"
  509. )
  510. # Correctly split the input tensor into ndim parts
  511. split_sizes = [num_rotated_channels_per_dim] * ndim
  512. x_parts = torch.split(x, split_sizes, dim=-1)
  513. cos_parts = torch.split(cos, split_sizes, dim=-1)
  514. sin_parts = torch.split(sin, split_sizes, dim=-1)
  515. y_parts = [
  516. apply_rotary_pos_emb(
  517. x=x_parts[k],
  518. cos=cos_parts[k],
  519. sin=sin_parts[k],
  520. unsqueeze_dim=unsqueeze_dim,
  521. )
  522. for k in range(ndim)
  523. ]
  524. return torch.cat(y_parts, dim=-1)
  525. class Gemma4VisionRotaryEmbedding(LlamaRotaryEmbedding):
  526. @staticmethod
  527. def compute_default_rope_parameters(
  528. config: Gemma4VisionConfig | None = None,
  529. device: torch.device | None = None,
  530. seq_len: int | None = None,
  531. ) -> tuple["torch.Tensor", float]:
  532. """
  533. Computes the inverse frequencies according to the original RoPE implementation
  534. Args:
  535. config ([`~transformers.PreTrainedConfig`]):
  536. The model configuration.
  537. device (`torch.device`):
  538. The device to use for initialization of the inverse frequencies.
  539. seq_len (`int`, *optional*):
  540. The current sequence length. Unused for this type of RoPE.
  541. Returns:
  542. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  543. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  544. """
  545. base = config.rope_parameters["rope_theta"]
  546. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  547. # The reference implementation computes RoPE frequencies INDEPENDENTLY
  548. # for each spatial dimension using the partitioned head_dim (head_dim // ndim),
  549. # so both x and y dimensions get identical frequency ranges.
  550. # This is different from splitting the global inv_freq between dimensions.
  551. spatial_dim = dim // 2
  552. attention_factor = 1.0 # Unused in this type of RoPE
  553. inv_freq = 1.0 / (
  554. base
  555. ** (torch.arange(0, spatial_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / spatial_dim)
  556. )
  557. return inv_freq, attention_factor
  558. @torch.no_grad()
  559. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  560. def forward(self, x, position_ids):
  561. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  562. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  563. # Multidimensional positions: [batch, num_patches, ndim]. Apply rotations to each spatial dim separately
  564. all_cos, all_sin = [], []
  565. for i in range(2):
  566. dim_position_ids = position_ids[:, :, i]
  567. dim_position_ids_expanded = dim_position_ids[:, None, :].float()
  568. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  569. freqs = (inv_freq_expanded.float() @ dim_position_ids_expanded.float()).transpose(1, 2)
  570. emb = torch.cat((freqs, freqs), dim=-1)
  571. cos = emb.cos() * self.attention_scaling
  572. sin = emb.sin() * self.attention_scaling
  573. all_cos.append(cos)
  574. all_sin.append(sin)
  575. cos = torch.cat(all_cos, dim=-1).to(dtype=x.dtype)
  576. sin = torch.cat(all_sin, dim=-1).to(dtype=x.dtype)
  577. return cos, sin
  578. class Gemma4VisionAttention(Gemma3Attention):
  579. def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
  580. super().__init__(self, config, layer_idx)
  581. del self.attn_logit_softcapping
  582. del self.sliding_window
  583. del self.is_sliding
  584. self.scaling = 1.0
  585. self.is_causal = False
  586. self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
  587. self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
  588. self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
  589. self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
  590. self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
  591. def forward(
  592. self,
  593. hidden_states: torch.Tensor,
  594. position_embeddings: torch.Tensor = None,
  595. attention_mask: torch.Tensor | None = None,
  596. position_ids: torch.LongTensor | None = None,
  597. **kwargs: Unpack[TransformersKwargs],
  598. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  599. input_shape = hidden_states.shape[:-1]
  600. hidden_shape = (*input_shape, -1, self.head_dim)
  601. cos, sin = position_embeddings
  602. query_states = self.q_proj(hidden_states).view(hidden_shape)
  603. query_states = self.q_norm(query_states)
  604. query_states = apply_multidimensional_rope(query_states, cos, sin, position_ids)
  605. query_states = query_states.transpose(1, 2)
  606. key_states = self.k_proj(hidden_states).view(hidden_shape)
  607. key_states = self.k_norm(key_states)
  608. key_states = apply_multidimensional_rope(key_states, cos, sin, position_ids)
  609. key_states = key_states.transpose(1, 2)
  610. value_states = self.v_proj(hidden_states).view(hidden_shape)
  611. value_states = self.v_norm(value_states)
  612. value_states = value_states.transpose(1, 2)
  613. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  614. self.config._attn_implementation, eager_attention_forward
  615. )
  616. attn_output, attn_weights = attention_interface(
  617. self,
  618. query_states,
  619. key_states,
  620. value_states,
  621. attention_mask,
  622. dropout=self.attention_dropout if self.training else 0.0,
  623. scaling=self.scaling,
  624. **kwargs,
  625. )
  626. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  627. attn_output = self.o_proj(attn_output)
  628. return attn_output, attn_weights
  629. # Same forward as Gemma3 but no cache
  630. class Gemma4VisionEncoderLayer(Gemma3DecoderLayer):
  631. def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
  632. super().__init__(self, config, layer_idx)
  633. self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
  634. self.mlp = Gemma4VisionMLP(config)
  635. def forward(
  636. self,
  637. hidden_states: torch.Tensor,
  638. position_embeddings: torch.Tensor = None,
  639. attention_mask: torch.Tensor | None = None,
  640. position_ids: torch.LongTensor | None = None,
  641. **kwargs: Unpack[TransformersKwargs],
  642. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  643. residual = hidden_states
  644. hidden_states = self.input_layernorm(hidden_states)
  645. hidden_states, _ = self.self_attn(
  646. hidden_states=hidden_states,
  647. position_embeddings=position_embeddings,
  648. attention_mask=attention_mask,
  649. position_ids=position_ids,
  650. **kwargs,
  651. )
  652. hidden_states = self.post_attention_layernorm(hidden_states)
  653. hidden_states = residual + hidden_states
  654. residual = hidden_states
  655. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  656. hidden_states = self.mlp(hidden_states)
  657. hidden_states = self.post_feedforward_layernorm(hidden_states)
  658. hidden_states = residual + hidden_states
  659. return hidden_states
  660. class Gemma4VisionEncoder(nn.Module):
  661. def __init__(self, config: Gemma4VisionConfig):
  662. super().__init__()
  663. self.config = config
  664. self.num_layers = config.num_hidden_layers
  665. self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
  666. self.layers = nn.ModuleList(
  667. [Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
  668. )
  669. def forward(
  670. self,
  671. inputs_embeds: torch.Tensor,
  672. attention_mask: torch.Tensor,
  673. pixel_position_ids: torch.LongTensor | None = None,
  674. **kwargs: Unpack[TransformersKwargs],
  675. ) -> BaseModelOutputWithPast:
  676. r"""
  677. pixel_position_ids (torch.Tensor):
  678. Patch positions as (x, y) coordinates in the image as [batch, num_patches, 2].
  679. """
  680. attention_mask = create_bidirectional_mask(
  681. config=self.config,
  682. inputs_embeds=inputs_embeds,
  683. attention_mask=attention_mask,
  684. )
  685. # embed positions
  686. hidden_states = inputs_embeds
  687. position_embeddings = self.rotary_emb(hidden_states, pixel_position_ids)
  688. # decoder layers
  689. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  690. hidden_states = decoder_layer(
  691. hidden_states,
  692. attention_mask=attention_mask,
  693. position_embeddings=position_embeddings,
  694. position_ids=pixel_position_ids,
  695. **kwargs,
  696. )
  697. return BaseModelOutputWithPast(last_hidden_state=hidden_states)
  698. # ---- Text model ----
  699. class Gemma4TextMLP(Gemma3MLP):
  700. def __init__(self, config: Gemma4TextConfig, layer_idx: int):
  701. first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
  702. is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
  703. use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
  704. super().__init__()
  705. self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
  706. class Gemma4TextRotaryEmbedding(Gemma3RotaryEmbedding):
  707. def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
  708. nn.Module.__init__(self)
  709. self.max_seq_len_cached = config.max_position_embeddings
  710. self.original_max_seq_len = config.max_position_embeddings
  711. self.config = config
  712. self.layer_types = set(config.layer_types)
  713. self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
  714. self.rope_type: dict[str, str] = {}
  715. for layer_type in self.layer_types:
  716. rope_params = self.config.rope_parameters[layer_type]
  717. if rope_params is None:
  718. continue
  719. if (rope_type := rope_params["rope_type"]) != "default":
  720. rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
  721. else:
  722. rope_init_fn = self.compute_default_rope_parameters
  723. self.rope_init_fns[layer_type] = rope_init_fn
  724. self.rope_type[layer_type] = rope_type
  725. rope_init_fn_kwargs = {"device": device, "layer_type": layer_type}
  726. if layer_type == "full_attention" and rope_type == "proportional":
  727. rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"
  728. curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, **rope_init_fn_kwargs)
  729. self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
  730. self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
  731. setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4TextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with Gemma 4 additions.

    - Sliding-window or full attention, chosen per layer from `config.layer_types`.
    - Optional K=V attention on full-attention layers: there is no `v_proj`, and values come from the key projection.
    - Cross-layer KV sharing: the last `num_kv_shared_layers` layers reuse the keys/values of an earlier layer.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None
        # Non-sliding layers use `global_head_dim` when it is set (truthy); sliding layers use `head_dim`.
        self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
        # K=V attention only applies to non-sliding layers; it switches to the global KV-head count and drops v_proj.
        self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
        num_key_value_heads = (
            config.num_global_key_value_heads if self.use_alternative_attention else config.num_key_value_heads
        )
        self.num_key_value_groups = config.num_attention_heads // num_key_value_heads
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout
        # Causal unless bidirectional attention is requested for "all" tokens.
        self.is_causal = config.use_bidirectional_attention != "all"

        # Shared kv cache
        # Layers with index >= first_kv_shared_layer_idx reuse K/V; `> 0` disables sharing when nothing is shared.
        first_kv_shared_layer_idx = self.config.num_hidden_layers - getattr(self.config, "num_kv_shared_layers", 0)
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        prev_layers = config.layer_types[:first_kv_shared_layer_idx]
        if self.is_kv_shared_layer:
            # For shared layers, find the last non-shared layer of the same type before sharing starts
            # (`list.index` raises ValueError if no such layer exists).
            self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
            self.store_full_length_kv = False
        else:
            self.kv_shared_layer_index = None
            # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
            self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
                config.layer_types[layer_idx]
            )

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

        # Layers sharing kv states don't need any weight matrices
        if not self.is_kv_shared_layer:
            self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
            self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
            self.k_proj = nn.Linear(
                config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias
            )
            self.v_proj = (
                nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
                if not self.use_alternative_attention
                else None
            )

        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states: Input states; every dim but the last is kept as the token shape.
            position_embeddings: `(cos, sin)` pair, unpacked below and applied with `unsqueeze_dim=2`.
            attention_mask: Mask forwarded to the attention backend.
            shared_kv_states: Mutable dict keyed by layer index. Shared layers read their K/V from it, and the last
                non-shared layer of each type writes its (post-cache-update) K/V into it.
            past_key_values: Optional cache; only non-shared layers update it.

        Returns:
            `(attn_output, attn_weights)` from the attention backend, after `o_proj`.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
        # We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
        # once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
        if self.is_kv_shared_layer:
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            # Device of past layer may be different from current one
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            # Without v_proj (K=V attention), values start from the raw key projection, before k_norm and RoPE.
            value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states

            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
            key_states = key_states.transpose(1, 2)

            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)

        if past_key_values is not None and not self.is_kv_shared_layer:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
        # Publish the (possibly cache-extended) K/V for the shared layers that point at this one.
        if self.store_full_length_kv:
            shared_kv_states[self.layer_idx] = key_states, value_states

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
  837. class Gemma4TextExperts(MixtralExperts):
  838. def __init__(self, config: Gemma4TextConfig):
  839. super().__init__()
  840. self.num_experts = config.num_experts
  841. self.intermediate_dim = config.moe_intermediate_size
  842. self.act_fn = ACT2FN[config.hidden_activation]
  843. class Gemma4TextRouter(nn.Module):
  844. def __init__(self, config: Gemma4TextConfig):
  845. super().__init__()
  846. self.config = config
  847. self.hidden_size = config.hidden_size
  848. self.scalar_root_size = self.hidden_size**-0.5
  849. self.eps = config.rms_norm_eps
  850. self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False)
  851. self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
  852. self.scale = nn.Parameter(torch.ones(self.hidden_size))
  853. self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
  854. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  855. hidden_states = self.norm(hidden_states)
  856. hidden_states = hidden_states * self.scale * self.scalar_root_size
  857. expert_scores = self.proj(hidden_states) # [B*S, E]
  858. router_probabilities = nn.functional.softmax(expert_scores, dim=-1)
  859. # topk returns both values (probabilities) and indices directly
  860. top_k_weights, top_k_index = torch.topk(
  861. router_probabilities,
  862. k=self.config.top_k_experts,
  863. dim=-1,
  864. ) # both [B*S, K]
  865. # Normalize the top-k weights so they sum to 1 per token
  866. top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
  867. # Apply per-expert scale directly to the weights
  868. top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]
  869. return router_probabilities, top_k_weights, top_k_index
class Gemma4TextDecoderLayer(Gemma3DecoderLayer):
    """Gemma 4 text decoder block.

    Self-attention, then a feed-forward stage that is either the dense MLP alone or the dense MLP plus a
    mixture-of-experts branch (when `config.enable_moe_block`), then an optional gated per-layer-input residual.
    Finally the whole output is multiplied by the `layer_scalar` buffer.
    """

    def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma4TextMLP(config, layer_idx)
        # Persistent buffer (initialized to 1) that multiplies the layer output at the end of `forward`.
        self.register_buffer("layer_scalar", torch.ones(1))

        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        if self.hidden_size_per_layer_input:
            self.act_fn = ACT2FN[config.hidden_activation]
            self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
            self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
            self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

        self.enable_moe_block = config.enable_moe_block
        if self.enable_moe_block:
            self.router = Gemma4TextRouter(config)
            self.experts = Gemma4TextExperts(config)
            self.post_feedforward_layernorm_1 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
            self.post_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
            self.pre_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        per_layer_input: torch.Tensor = None,
        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
        position_embeddings: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Input token states.
            per_layer_input: This layer's slice of the per-layer inputs. It is multiplied into the gated activation,
                so it is required when `hidden_size_per_layer_input` is set.
            shared_kv_states: Dict passed through to the attention for cross-layer KV sharing.
            position_embeddings: `(cos, sin)` pair passed to the attention.
            attention_mask: Mask passed to the attention.
            position_ids: Forwarded to the attention through its `**kwargs`.
            past_key_values: Optional cache passed to the attention.

        Returns:
            The updated hidden states.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            shared_kv_states=shared_kv_states,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        if self.enable_moe_block:
            hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)

            # Take hidden states before MLP here
            # (the expert branch runs in parallel with the dense MLP, starting from the same residual).
            hidden_states_flat = residual.reshape(-1, residual.shape[-1])
            _, top_k_weights, top_k_index = self.router(hidden_states_flat)
            hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states_flat)
            hidden_states_2 = self.experts(hidden_states_2, top_k_index, top_k_weights)
            hidden_states_2 = hidden_states_2.reshape(residual.shape)
            hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)

            # Combine mlp and moe outputs
            hidden_states = hidden_states_1 + hidden_states_2

        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        if self.hidden_size_per_layer_input:
            # Gated per-layer-input residual: gate(h) -> activation -> * per_layer_input -> project -> norm.
            residual = hidden_states
            hidden_states = self.per_layer_input_gate(hidden_states)
            hidden_states = self.act_fn(hidden_states)
            hidden_states = hidden_states * per_layer_input
            hidden_states = self.per_layer_projection(hidden_states)
            hidden_states = self.post_per_layer_input_norm(hidden_states)
            hidden_states = residual + hidden_states

        # Scales the full layer output, residual included.
        hidden_states *= self.layer_scalar
        return hidden_states
  939. class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
  940. pass
  941. # ---- Model Classes ----
  942. class Gemma4PreTrainedModel(PreTrainedModel):
  943. config: Gemma4Config
  944. supports_gradient_checkpointing = True
  945. _supports_flash_attn = True
  946. _supports_sdpa = True
  947. _supports_flex_attn = True
  948. _can_compile_fullgraph = True
  949. _supports_attention_backend = True
  950. _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
  951. _skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
  952. input_modalities = ("image", "text", "video", "audio")
  953. @torch.no_grad()
  954. def _init_weights(self, module):
  955. super()._init_weights(module)
  956. if isinstance(module, Gemma4VisionPatchEmbedder):
  957. init.ones_(module.position_embedding_table)
  958. elif isinstance(module, Gemma4AudioRelPositionalEncoding):
  959. min_timescale = 1.0
  960. max_timescale = 10000.0
  961. num_timescales = module.hidden_size // 2
  962. log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
  963. inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
  964. init.copy_(module.inv_timescales, inv_timescales.unsqueeze(0).unsqueeze(0))
  965. elif isinstance(module, Gemma4AudioAttention):
  966. init.constant_(module.softcap, module.attention_logits_soft_cap)
  967. init.zeros_(module.per_dim_scale)
  968. elif isinstance(module, Gemma4TextRotaryEmbedding):
  969. for layer_type, rope_init_fn in module.rope_init_fns.items():
  970. rope_init_fn_kwargs = {"layer_type": layer_type}
  971. if layer_type == "full_attention" and module.rope_type[layer_type] == "proportional":
  972. rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"
  973. curr_inv_freq, _ = rope_init_fn(module.config, **rope_init_fn_kwargs)
  974. init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
  975. init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
  976. elif isinstance(module, Gemma4VisionRotaryEmbedding):
  977. rope_fn = (
  978. ROPE_INIT_FUNCTIONS[module.rope_type]
  979. if module.rope_type != "default"
  980. else module.compute_default_rope_parameters
  981. )
  982. buffer_value, _ = rope_fn(module.config)
  983. init.copy_(module.inv_freq, buffer_value)
  984. init.copy_(module.original_inv_freq, buffer_value)
  985. elif isinstance(module, Gemma4TextScaledWordEmbedding):
  986. init.constant_(module.embed_scale, module.scalar_embed_scale)
  987. elif isinstance(module, Gemma4TextRouter):
  988. init.ones_(module.scale)
  989. init.ones_(module.per_expert_scale)
  990. elif isinstance(module, Gemma4TextExperts):
  991. std = self.config.initializer_range
  992. init.normal_(module.gate_up_proj, mean=0.0, std=std)
  993. init.normal_(module.down_proj, mean=0.0, std=std)
  994. elif isinstance(module, Gemma4TextDecoderLayer):
  995. init.ones_(module.layer_scalar)
  996. elif isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears:
  997. init.constant_(module.input_min, -float("inf"))
  998. init.constant_(module.input_max, float("inf"))
  999. init.constant_(module.output_min, -float("inf"))
  1000. init.constant_(module.output_max, float("inf"))
  1001. elif isinstance(module, Gemma4VisionModel) and module.config.standardize:
  1002. init.zeros_(module.std_bias)
  1003. init.ones_(module.std_scale)
  1004. @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
  1005. class Gemma4TextModel(Gemma3TextModel):
  1006. config: Gemma4TextConfig
  1007. _can_record_outputs = {
  1008. "router_logits": OutputRecorder(Gemma4TextRouter, index=0),
  1009. "hidden_states": Gemma4TextDecoderLayer,
  1010. "attentions": Gemma4TextAttention,
  1011. }
    def __init__(self, config: Gemma4TextConfig):
        """Build the Gemma 4 text stack on top of the `Gemma3TextModel` constructor.

        Replaces the decoder layers and rotary embedding. When `config.hidden_size_per_layer_input` is set, it also
        adds the per-layer-input embedding, projection and norm. Finally it marks the K/V weights of KV-shared
        layers as ignorable when loading.
        """
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_emb = Gemma4TextRotaryEmbedding(config)
        self.unique_layer_types = set(self.config.layer_types)

        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        if self.hidden_size_per_layer_input:
            # One embedding row per token holding the per-layer inputs of every layer, concatenated.
            self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding(
                config.vocab_size_per_layer_input,
                config.num_hidden_layers * config.hidden_size_per_layer_input,
                self.padding_idx,
                embed_scale=config.hidden_size_per_layer_input**0.5,
            )
            # Factor applied to the sum of the projected and embedded per-layer inputs.
            self.per_layer_input_scale = 2.0**-0.5
            self.per_layer_model_projection = nn.Linear(
                config.hidden_size,
                config.num_hidden_layers * config.hidden_size_per_layer_input,
                bias=False,
            )
            self.per_layer_model_projection_scale = config.hidden_size**-0.5
            self.per_layer_projection_norm = Gemma4RMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)

        # Update `_keys_to_ignore_on_load_unexpected` to drop all k/v proj and norms for the shared layers
        # (KV-shared attention layers do not create these modules).
        self._keys_to_ignore_on_load_unexpected = []
        for i, layer in enumerate(self.layers):
            if layer.self_attn.is_kv_shared_layer:
                self._keys_to_ignore_on_load_unexpected.extend(
                    [f"layers.{i}.self_attn.{name}" for name in ("k_proj", "v_proj", "k_norm", "v_norm")]
                )
  1042. def get_per_layer_inputs(self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None) -> torch.Tensor:
  1043. if not self.hidden_size_per_layer_input:
  1044. raise RuntimeError(
  1045. "Attempting to call get_per_layer_inputs() from a model initialized with a config that does not support"
  1046. f" per-layer embeddings. {self.config}"
  1047. )
  1048. # If only inputs_embeds are provided, reverse main embedding to find the input_ids - this allows to `generate`
  1049. # from `inputs_embeds` only as other models (otherwise it would need the value from both embeddings)
  1050. if input_ids is None:
  1051. with torch.no_grad():
  1052. input_ids = (
  1053. (
  1054. inputs_embeds[:, :, None, :]
  1055. == self.embed_tokens.weight[None, None, :, :] * self.config.hidden_size**0.5
  1056. )
  1057. .all(dim=3)
  1058. .nonzero()[:, 2]
  1059. )
  1060. try:
  1061. input_ids = input_ids.view(inputs_embeds.shape[:2])
  1062. except RuntimeError:
  1063. raise RuntimeError(
  1064. "It seems like you tried to call `forward` from `inputs_embeds` without providing `input_ids`, and that "
  1065. "the `inputs_embeds` you provided do not exactly match the embedding weights. Since Gemma4 needs to reverse "
  1066. "the embedding to compute another embedding, make sure you provide exact `inputs_embeds`"
  1067. )
  1068. return self.embed_tokens_per_layer(input_ids).reshape(
  1069. *input_ids.shape,
  1070. self.config.num_hidden_layers,
  1071. self.hidden_size_per_layer_input,
  1072. )
  1073. def project_per_layer_inputs(
  1074. self,
  1075. inputs_embeds: torch.Tensor,
  1076. per_layer_inputs: torch.Tensor | None = None,
  1077. ) -> torch.Tensor:
  1078. if not self.hidden_size_per_layer_input:
  1079. raise RuntimeError(
  1080. "Attempting to call project_per_layer_inputs() from a model initialized with a config that does not"
  1081. f" support per-layer embeddings. {self.config}"
  1082. )
  1083. per_layer_projection = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale
  1084. per_layer_projection = per_layer_projection.reshape(
  1085. *inputs_embeds.shape[:-1],
  1086. self.config.num_hidden_layers,
  1087. self.hidden_size_per_layer_input,
  1088. )
  1089. per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
  1090. if per_layer_inputs is None:
  1091. return per_layer_projection
  1092. return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
  1093. @merge_with_config_defaults
  1094. @capture_outputs
  1095. @auto_docstring
  1096. def forward(
  1097. self,
  1098. input_ids: torch.LongTensor | None = None,
  1099. attention_mask: torch.Tensor | None = None,
  1100. position_ids: torch.LongTensor | None = None,
  1101. past_key_values: Cache | None = None,
  1102. inputs_embeds: torch.FloatTensor | None = None,
  1103. per_layer_inputs: torch.Tensor | None = None,
  1104. use_cache: bool | None = None,
  1105. **kwargs: Unpack[TransformersKwargs],
  1106. ) -> BaseModelOutputWithPast:
  1107. r"""
  1108. per_layer_inputs (`torch.Tensor` of shape `(batch_size, sequence_length, num_hidden_layers, hidden_size_per_layer_input)`, *optional*):
  1109. Pre-computed per-layer input embeddings. When provided, these are used directly instead of being
  1110. computed from `input_ids` via `get_per_layer_inputs()`. This is primarily used by the multimodal
  1111. model (`Gemma4Model`) which pre-computes per-layer inputs from the original `input_ids` *before*
  1112. merging multimodal soft tokens into `inputs_embeds` — at which point the original token ids are
  1113. no longer recoverable.
  1114. """
  1115. if (input_ids is None) ^ (inputs_embeds is not None):
  1116. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1117. if input_ids is not None:
  1118. inputs_embeds = self.embed_tokens(input_ids)
  1119. if self.hidden_size_per_layer_input:
  1120. if per_layer_inputs is None:
  1121. per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds)
  1122. per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
  1123. if use_cache and past_key_values is None:
  1124. past_key_values = DynamicCache(config=self.config)
  1125. if position_ids is None:
  1126. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1127. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  1128. position_ids = position_ids.unsqueeze(0)
  1129. # It may already have been prepared by e.g. `generate`
  1130. if not isinstance(causal_mask_mapping := attention_mask, dict):
  1131. # Prepare mask arguments
  1132. mask_kwargs = {
  1133. "config": self.config,
  1134. "inputs_embeds": inputs_embeds,
  1135. "attention_mask": attention_mask,
  1136. "past_key_values": past_key_values,
  1137. "position_ids": position_ids,
  1138. }
  1139. # Create the masks
  1140. causal_mask_mapping = {
  1141. "full_attention": create_causal_mask(**mask_kwargs),
  1142. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  1143. }
  1144. # embed positions
  1145. hidden_states = inputs_embeds
  1146. position_embeddings = {}
  1147. for layer_type in self.unique_layer_types:
  1148. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  1149. # Initialize as empty dict - it will be filled in the right layers
  1150. shared_kv_states = {}
  1151. # decoder layers
  1152. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  1153. per_layer_input = per_layer_inputs[:, :, i, :] if per_layer_inputs is not None else None
  1154. hidden_states = decoder_layer(
  1155. hidden_states,
  1156. per_layer_input,
  1157. shared_kv_states=shared_kv_states,
  1158. position_embeddings=position_embeddings[self.config.layer_types[i]],
  1159. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  1160. position_ids=position_ids,
  1161. past_key_values=past_key_values,
  1162. **kwargs,
  1163. )
  1164. hidden_states = self.norm(hidden_states)
  1165. return BaseModelOutputWithPast(
  1166. last_hidden_state=hidden_states,
  1167. past_key_values=past_key_values,
  1168. )
  1169. @auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.")
  1170. class Gemma4ForCausalLM(Gemma3ForCausalLM):
  1171. base_model_prefix = "model"
  1172. def __init__(self, config: Gemma4TextConfig):
  1173. super().__init__(config)
  1174. # Grab the ones from the child
  1175. self._keys_to_ignore_on_load_unexpected = [
  1176. f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
  1177. ]
  1178. class Gemma4AudioModel(Gemma4PreTrainedModel):
  1179. """An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture."""
  1180. config: Gemma4AudioConfig
  1181. main_input_name = "input_features"
  1182. base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained()
  1183. _can_record_outputs = {
  1184. "hidden_states": Gemma4AudioLayer,
  1185. "attentions": Gemma4AudioAttention,
  1186. }
  1187. def __init__(self, config: Gemma4AudioConfig):
  1188. super().__init__(config)
  1189. self.config = config
  1190. self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(config)
  1191. self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config)
  1192. self.layers = nn.ModuleList(
  1193. [Gemma4AudioLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  1194. )
  1195. self.output_proj = nn.Linear(config.hidden_size, config.output_proj_dims, bias=True)
  1196. self.post_init()
  1197. def _convert_4d_mask_to_blocked_5d(self, mask_4d: torch.Tensor) -> torch.Tensor:
  1198. """
  1199. Convert a standard 4D attention mask `[batch_size, 1, seq_len, seq_len]` to the 5D blocked format
  1200. `[batch_size, 1, num_blocks, chunk_size, context_size]` expected by the chunked local attention,
  1201. """
  1202. batch_size, _, seq_len, _ = mask_4d.shape
  1203. device = mask_4d.device
  1204. chunk_size = self.config.attention_chunk_size
  1205. max_past_horizon = self.config.attention_context_left - 1
  1206. max_future_horizon = self.config.attention_context_right
  1207. num_blocks = (seq_len + chunk_size - 1) // chunk_size
  1208. padded_seq_len = num_blocks * chunk_size
  1209. pad_amount = padded_seq_len - seq_len
  1210. mask_4d = F.pad(mask_4d, (0, pad_amount, 0, pad_amount), value=False)
  1211. mask_5d = mask_4d.reshape(batch_size, 1, num_blocks, chunk_size, padded_seq_len)
  1212. mask_5d = F.pad(mask_5d, (max_past_horizon, max_future_horizon), value=False)
  1213. block_starts = torch.arange(num_blocks, device=device) * chunk_size
  1214. offsets = torch.arange(chunk_size + max_past_horizon + max_future_horizon, device=device)
  1215. kv_indices = block_starts[:, None] + offsets[None, :]
  1216. kv_indices = kv_indices[None, None, :, None, :].expand(batch_size, 1, -1, chunk_size, -1)
  1217. return mask_5d.gather(-1, kv_indices)
  1218. @merge_with_config_defaults
  1219. @capture_outputs
  1220. @auto_docstring(custom_intro="Encodes audio features to soft tokens.")
  1221. def forward(
  1222. self,
  1223. input_features: torch.Tensor,
  1224. attention_mask: torch.Tensor | None = None,
  1225. **kwargs: Unpack[TransformersKwargs],
  1226. ) -> tuple[torch.Tensor, torch.BoolTensor]:
  1227. hidden_states, output_mask = self.subsample_conv_projection(input_features, attention_mask)
  1228. position_embeddings = self.rel_pos_enc(hidden_states)
  1229. attention_mask = create_bidirectional_mask(
  1230. config=self.config,
  1231. inputs_embeds=hidden_states,
  1232. attention_mask=output_mask,
  1233. and_mask_function=sliding_window_mask_function(
  1234. (self.config.attention_context_left - 1, self.config.attention_context_right)
  1235. ),
  1236. )
  1237. attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask)
  1238. for encoder_layer in self.layers[: self.config.num_hidden_layers]:
  1239. hidden_states = encoder_layer(
  1240. hidden_states,
  1241. attention_mask=attention_mask,
  1242. position_embeddings=position_embeddings,
  1243. **kwargs,
  1244. )
  1245. hidden_states = self.output_proj(hidden_states)
  1246. return Gemma4AudioModelOutput(last_hidden_state=hidden_states, attention_mask=output_mask)
  1247. class Gemma4VisionModel(Gemma4PreTrainedModel):
  1248. """The Gemma 4 Vision Encoder."""
  1249. config = Gemma4VisionConfig
  1250. _can_record_outputs = {
  1251. "hidden_states": Gemma4VisionEncoderLayer,
  1252. "attentions": Gemma4VisionAttention,
  1253. }
  1254. def __init__(self, config: Gemma4VisionConfig):
  1255. super().__init__(config)
  1256. self.patch_embedder = Gemma4VisionPatchEmbedder(config)
  1257. self.encoder = Gemma4VisionEncoder(config)
  1258. self.pooler = Gemma4VisionPooler(config)
  1259. if self.config.standardize:
  1260. self.register_buffer("std_bias", torch.empty(self.config.hidden_size))
  1261. self.register_buffer("std_scale", torch.empty(self.config.hidden_size))
  1262. self.post_init()
  1263. @merge_with_config_defaults
  1264. @capture_outputs
  1265. @auto_docstring(custom_intro="Encodes image pixels to soft tokens from patches.")
  1266. def forward(
  1267. self,
  1268. pixel_values: torch.FloatTensor,
  1269. pixel_position_ids: torch.LongTensor,
  1270. **kwargs: Unpack[TransformersKwargs],
  1271. ) -> BaseModelOutputWithPast:
  1272. r"""
  1273. pixel_values (`torch.FloatTensor` or `list[torch.FloatTensor]`):
  1274. The images to encode. Either a single `[batch, channels, height, width]` tensor
  1275. (all images same size) or a list of `[1, channels, height, width]` tensors (different sizes).
  1276. pixel_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`):
  1277. The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
  1278. """
  1279. pooling_kernel_size = self.config.pooling_kernel_size
  1280. output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size)
  1281. padding_positions = (pixel_position_ids == -1).all(dim=-1)
  1282. inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions)
  1283. output = self.encoder(
  1284. inputs_embeds=inputs_embeds,
  1285. attention_mask=~padding_positions, # encoder expects True=valid, padding_positions is True=padding
  1286. pixel_position_ids=pixel_position_ids,
  1287. **kwargs,
  1288. )
  1289. hidden_states, pooler_mask = self.pooler(
  1290. hidden_states=output.last_hidden_state,
  1291. pixel_position_ids=pixel_position_ids,
  1292. padding_positions=padding_positions,
  1293. output_length=output_length,
  1294. )
  1295. # Strip padding tokens. pooler_mask is True = valid, False = padding.
  1296. hidden_states = hidden_states[pooler_mask]
  1297. if self.config.standardize:
  1298. hidden_states = (hidden_states - self.std_bias) * self.std_scale
  1299. return BaseModelOutputWithPast(last_hidden_state=hidden_states)
  1300. class Gemma4MultimodalEmbedder(Gemma3nMultimodalEmbedder):
  1301. def __init__(
  1302. self,
  1303. multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig,
  1304. text_config: Gemma4TextConfig,
  1305. ):
  1306. # Audio tower may use a different output dimension (output_proj_dims) than the
  1307. # internal hidden_size. Use the tower-specific dimension if specified.
  1308. super().__init__(multimodal_config, text_config)
  1309. del self.embedding
  1310. del self.hard_embedding_norm
  1311. del self.soft_embedding_norm
  1312. del self.vocab_offset
  1313. del self.vocab_size
  1314. del self.embedding_post_projection_norm
  1315. self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size)
  1316. self.embedding_pre_projection_norm = Gemma4RMSNorm(self.multimodal_hidden_size, eps=self.eps, with_scale=False)
  1317. def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
  1318. """Embeds token ids or soft tokens for multimodal content into language model space.
  1319. Args:
  1320. inputs_embeds: A torch.Tensor containing the soft tokens to embed.
  1321. Returns:
  1322. A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
  1323. """
  1324. embs_normed = self.embedding_pre_projection_norm(inputs_embeds)
  1325. return self.embedding_projection(embs_normed)
  1326. # Identical as Gemma3 but modular can't resolve if we simply import. FIXME: @cyril
  1327. def token_type_ids_mask_function(
  1328. token_type_ids: torch.Tensor | None,
  1329. image_group_ids: torch.Tensor | None,
  1330. ) -> Callable | None:
  1331. """
  1332. This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
  1333. not start and end indices.
  1334. """
  1335. # Do not return an additional mask in this case
  1336. if token_type_ids is None:
  1337. return None
  1338. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  1339. seq_length = image_group_ids.shape[-1]
  1340. # clamp indices because with static cache they can go beyond `image_group_ids.shape[-1]`
  1341. q_idx_clamped = q_idx.clamp(max=seq_length - 1)
  1342. kv_idx_clamped = kv_idx.clamp(max=seq_length - 1)
  1343. # Unmask if the q and kv come from same group which is not -1 (i.e. non-text)
  1344. q_group = image_group_ids[batch_idx, q_idx_clamped]
  1345. kv_group = image_group_ids[batch_idx, kv_idx_clamped]
  1346. q_group = torch.where(q_idx < seq_length, q_group, -1)
  1347. kv_group = torch.where(kv_idx < seq_length, kv_group, -1)
  1348. return (q_group == kv_group) & (q_group >= 0)
  1349. return inner_mask
  1350. # Similar to Gemma3 but `sliding_mask_kwargs` and `mask_kwargs` are different and `token_type_ids->mm_token_type_ids`
  1351. def create_causal_mask_mapping(
  1352. config: PreTrainedConfig,
  1353. inputs_embeds: torch.Tensor,
  1354. attention_mask: torch.Tensor | None,
  1355. past_key_values: Cache | None,
  1356. position_ids: torch.Tensor | None,
  1357. mm_token_type_ids: torch.Tensor | None = None,
  1358. pixel_values: torch.FloatTensor | None = None,
  1359. is_training: bool = False,
  1360. is_first_iteration: bool | None = None,
  1361. **kwargs,
  1362. ) -> dict:
  1363. """
  1364. Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
  1365. for all kinds of forward passes. Gemma4 uses a bidirectional mask for images.
  1366. Uses `pixel_values` as an optional input to disambiguate edge cases.
  1367. """
  1368. if is_training and mm_token_type_ids is None:
  1369. raise ValueError("`mm_token_type_ids` is required as a model input when training")
  1370. mask_kwargs = {
  1371. "config": config.get_text_config(),
  1372. "inputs_embeds": inputs_embeds,
  1373. "attention_mask": attention_mask,
  1374. "past_key_values": past_key_values,
  1375. "position_ids": position_ids,
  1376. }
  1377. sliding_mask_kwargs = mask_kwargs.copy()
  1378. # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
  1379. # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
  1380. # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
  1381. is_first_iteration = (
  1382. is_first_iteration
  1383. if is_first_iteration is not None
  1384. else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
  1385. )
  1386. if mm_token_type_ids is not None and is_first_iteration:
  1387. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
  1388. # undo the causal masking)
  1389. # First find where a new vision block starts. Vision tokens cannot attend to
  1390. # future vision tokens, but can attend to all prev tokens and to itself bidirectionally
  1391. is_vision = (mm_token_type_ids == 1) | (mm_token_type_ids == 2)
  1392. is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1)
  1393. is_prev_vision[..., 0] = False
  1394. new_vision_starts = is_vision & ~is_prev_vision
  1395. vision_group_ids = torch.cumsum(new_vision_starts.int(), dim=1) - 1
  1396. vision_group_ids = torch.where(is_vision, vision_group_ids, -1)
  1397. sliding_mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
  1398. mm_token_type_ids.to(inputs_embeds.device), vision_group_ids
  1399. )
  1400. return {
  1401. "full_attention": create_causal_mask(**mask_kwargs),
  1402. "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
  1403. }
  1404. @auto_docstring(
  1405. custom_intro="""
  1406. The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a
  1407. language modeling head.
  1408. """
  1409. )
  1410. class Gemma4Model(Gemma3nModel):
  1411. def __init__(self, config: Gemma4Config):
  1412. super().__init__(config)
  1413. del self.vision_tower
  1414. del self.embed_vision
  1415. self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None
  1416. self.embed_vision = (
  1417. Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
  1418. if config.vision_config is not None
  1419. else None
  1420. )
  1421. del self.audio_tower
  1422. del self.embed_audio
  1423. self.audio_tower = AutoModel.from_config(config.audio_config) if config.audio_config is not None else None
  1424. self.embed_audio = (
  1425. Gemma4MultimodalEmbedder(config.audio_config, config.text_config)
  1426. if config.audio_config is not None
  1427. else None
  1428. )
  1429. # Grab the ones from the child
  1430. self._keys_to_ignore_on_load_unexpected = [
  1431. f"language_model.{name}" for name in self.language_model._keys_to_ignore_on_load_unexpected
  1432. ]
  1433. @can_return_tuple
  1434. @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
  1435. def get_image_features(
  1436. self,
  1437. pixel_values: torch.FloatTensor,
  1438. image_position_ids: torch.LongTensor | None = None,
  1439. **kwargs: Unpack[TransformersKwargs],
  1440. ) -> BaseModelOutputWithPooling:
  1441. r"""
  1442. image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
  1443. The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1).
  1444. """
  1445. vision_outputs = self.vision_tower(
  1446. pixel_values=pixel_values,
  1447. pixel_position_ids=image_position_ids,
  1448. **kwargs,
  1449. )
  1450. last_hidden_state = vision_outputs.last_hidden_state
  1451. vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
  1452. return vision_outputs
  1453. @can_return_tuple
  1454. @auto_docstring(custom_intro="Projects the last hidden state from the vision encoder into language model space.")
  1455. def get_video_features(
  1456. self,
  1457. pixel_values_videos: torch.FloatTensor,
  1458. video_position_ids: torch.LongTensor | None = None,
  1459. **kwargs: Unpack[TransformersKwargs],
  1460. ) -> BaseModelOutputWithPooling:
  1461. r"""
  1462. video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
  1463. 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
  1464. Passed through to the vision encoder for positional embedding computation.
  1465. """
  1466. pixel_values_videos = pixel_values_videos.flatten(0, 1)
  1467. video_position_ids = video_position_ids.flatten(0, 1)
  1468. vision_outputs = self.vision_tower(
  1469. pixel_values=pixel_values_videos,
  1470. pixel_position_ids=video_position_ids,
  1471. **kwargs,
  1472. )
  1473. last_hidden_state = vision_outputs.last_hidden_state
  1474. vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
  1475. return vision_outputs
  1476. def get_placeholder_mask(
  1477. self,
  1478. input_ids: torch.LongTensor | None = None,
  1479. inputs_embeds: torch.FloatTensor | None = None,
  1480. ) -> tuple[torch.BoolTensor, torch.BoolTensor, torch.BoolTensor]:
  1481. """
  1482. Obtains mask for multimodal placeholders (replaced by soft tokens) and hard text tokens.
  1483. Masks will be obtained from `mm_token_type_ids`, `input_ids`, or `inputs_embeds` as available and in that
  1484. precedence order. If passing `input_ids` or `inputs_embeds`, the image mask will be derived using
  1485. `config.image_token_id`. Same goes for audio and video masks
  1486. Args:
  1487. input_ids: A tensor containing the hard token IDs from the text tokenizer.
  1488. inputs_embeds: A tensor containing the embeddings for all hard text tokens.
  1489. Returns:
  1490. image_mask, video_mask, audio_mask
  1491. """
  1492. if input_ids is not None:
  1493. special_image_mask = input_ids == self.config.image_token_id
  1494. special_video_mask = input_ids == self.config.video_token_id
  1495. special_audio_mask = input_ids == self.config.audio_token_id
  1496. else:
  1497. special_image_mask = (
  1498. inputs_embeds
  1499. == self.get_input_embeddings()(
  1500. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  1501. )
  1502. ).all(-1)
  1503. special_video_mask = (
  1504. inputs_embeds
  1505. == self.get_input_embeddings()(
  1506. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  1507. )
  1508. ).all(-1)
  1509. special_audio_mask = (
  1510. inputs_embeds
  1511. == self.get_input_embeddings()(
  1512. torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
  1513. )
  1514. ).all(-1)
  1515. return special_image_mask, special_video_mask, special_audio_mask
  1516. @merge_with_config_defaults
  1517. @can_return_tuple
  1518. @auto_docstring
  1519. def forward(
  1520. self,
  1521. input_ids: torch.LongTensor | None = None,
  1522. pixel_values: torch.FloatTensor | None = None,
  1523. pixel_values_videos: torch.FloatTensor | None = None,
  1524. input_features: torch.FloatTensor | None = None,
  1525. attention_mask: torch.Tensor | None = None,
  1526. input_features_mask: torch.Tensor | None = None,
  1527. position_ids: torch.LongTensor | None = None,
  1528. past_key_values: Cache | None = None,
  1529. mm_token_type_ids: torch.LongTensor | None = None,
  1530. inputs_embeds: torch.FloatTensor | None = None,
  1531. use_cache: bool | None = None,
  1532. image_position_ids: torch.LongTensor | None = None,
  1533. video_position_ids: torch.LongTensor | None = None,
  1534. **kwargs: Unpack[TransformersKwargs],
  1535. ) -> Gemma4ModelOutputWithPast:
  1536. r"""
  1537. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  1538. The attention mask for the input audio.
  1539. image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
  1540. 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
  1541. Passed through to the vision encoder for positional embedding computation.
  1542. video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
  1543. 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
  1544. Passed through to the vision encoder for positional embedding computation.
  1545. """
  1546. if (input_ids is None) ^ (inputs_embeds is not None):
  1547. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1548. image_mask, video_mask, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds)
  1549. multimodal_mask = image_mask | video_mask | audio_mask
  1550. # Replace image id with PAD if the image token if OOV, to avoid index-errors
  1551. llm_input_ids = None
  1552. if inputs_embeds is None:
  1553. llm_input_ids = input_ids.clone()
  1554. llm_input_ids[multimodal_mask] = self.config.text_config.pad_token_id
  1555. inputs_embeds = self.get_input_embeddings()(llm_input_ids)
  1556. if self.config.get_text_config().hidden_size_per_layer_input:
  1557. pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
  1558. llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds)
  1559. per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds)
  1560. else:
  1561. per_layer_inputs = None
  1562. # Merge text and images
  1563. if pixel_values is not None:
  1564. image_features = self.get_image_features(pixel_values, image_position_ids, return_dict=True).pooler_output
  1565. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1566. # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
  1567. n_image_tokens = image_mask.sum()
  1568. image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1569. torch_compilable_check(
  1570. inputs_embeds[image_mask].numel() == image_features.numel(),
  1571. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features:"
  1572. f" {image_features.shape[0]}",
  1573. )
  1574. inputs_embeds = inputs_embeds.masked_scatter(
  1575. image_mask.to(inputs_embeds.device), image_features.to(inputs_embeds.device)
  1576. )
  1577. if pixel_values_videos is not None:
  1578. video_features = self.get_video_features(
  1579. pixel_values_videos, video_position_ids, return_dict=True
  1580. ).pooler_output
  1581. video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1582. # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
  1583. n_video_tokens = video_mask.sum()
  1584. video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1585. torch_compilable_check(
  1586. inputs_embeds[video_mask].numel() == video_features.numel(),
  1587. f"Video features and video tokens do not match, tokens: {n_video_tokens}, features:"
  1588. f" {video_features.shape[0]}",
  1589. )
  1590. inputs_embeds = inputs_embeds.masked_scatter(
  1591. video_mask.to(inputs_embeds.device), video_features.to(inputs_embeds.device)
  1592. )
  1593. # Merge text and audio
  1594. if input_features is not None and input_features_mask is not None:
  1595. audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
  1596. audio_features = audio_output.pooler_output
  1597. audio_mask_from_encoder = audio_output.attention_mask # True = valid
  1598. # Strip padding tokens: only keep real (non-padding) audio soft tokens.
  1599. # audio_mask_from_encoder is True for valid positions, False for padding tokens.
  1600. # This mirrors the vision encoder's padding stripping (see Gemma4VisionEncoder.forward).
  1601. audio_features = audio_features[audio_mask_from_encoder]
  1602. n_audio_tokens = audio_mask.sum()
  1603. audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1604. torch_compilable_check(
  1605. inputs_embeds[audio_mask].numel() == audio_features.numel(),
  1606. f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features:"
  1607. f" {audio_features.shape[0] * audio_features.shape[1]}",
  1608. )
  1609. inputs_embeds = inputs_embeds.masked_scatter(
  1610. audio_mask.to(inputs_embeds.device), audio_features.to(inputs_embeds.device)
  1611. )
  1612. # It may already have been prepared by, e.g., `generate`
  1613. if position_ids is None:
  1614. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1615. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  1616. position_ids = position_ids.unsqueeze(0)
  1617. if not isinstance(causal_mask_mapping := attention_mask, dict):
  1618. if self.config.get_text_config().use_bidirectional_attention == "vision":
  1619. # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
  1620. causal_mask_mapping = create_causal_mask_mapping(
  1621. self.config,
  1622. inputs_embeds,
  1623. attention_mask,
  1624. past_key_values,
  1625. position_ids,
  1626. mm_token_type_ids,
  1627. pixel_values,
  1628. is_training=self.training,
  1629. )
  1630. else:
  1631. # Smaller Gemma models use a conventional casual attention mask
  1632. causal_mask_mapping = create_masks_for_generate(
  1633. self.config,
  1634. inputs_embeds,
  1635. attention_mask,
  1636. past_key_values,
  1637. position_ids,
  1638. )
  1639. outputs = self.language_model(
  1640. per_layer_inputs=per_layer_inputs,
  1641. attention_mask=causal_mask_mapping,
  1642. position_ids=position_ids,
  1643. past_key_values=past_key_values,
  1644. inputs_embeds=inputs_embeds,
  1645. use_cache=use_cache,
  1646. return_dict=True,
  1647. **kwargs,
  1648. )
  1649. return Gemma4ModelOutputWithPast(
  1650. last_hidden_state=outputs.last_hidden_state,
  1651. past_key_values=outputs.past_key_values,
  1652. hidden_states=outputs.hidden_states,
  1653. attentions=outputs.attentions,
  1654. image_hidden_states=image_features if pixel_values is not None else None,
  1655. audio_hidden_states=audio_features if input_features is not None else None,
  1656. )
  1657. @can_return_tuple
  1658. @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
  1659. def get_audio_features(
  1660. self,
  1661. input_features: torch.Tensor,
  1662. input_features_mask: torch.Tensor,
  1663. **kwargs: Unpack[TransformersKwargs],
  1664. ) -> tuple | Gemma4AudioModelOutput:
  1665. r"""
  1666. input_features (`torch.FloatTensor]` of shape `(num_images, seq_length, num_features)`):
  1667. The tensors corresponding to the input audio.
  1668. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  1669. The attention mask for the input audio.
  1670. """
  1671. if self.audio_tower is None:
  1672. raise ValueError(
  1673. "Audio features were requested, but the model was initialized without an audio_config. "
  1674. "Cannot process audio without an audio tower and audio embedder."
  1675. )
  1676. audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs)
  1677. audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
  1678. return audio_outputs
  1679. @auto_docstring(
  1680. custom_intro="""
  1681. The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling
  1682. head.
  1683. """
  1684. )
  1685. class Gemma4ForConditionalGeneration(Gemma3nForConditionalGeneration):
  1686. base_model_prefix = "model"
  1687. def __init__(self, config: Gemma4Config):
  1688. super().__init__(config)
  1689. # Grab the ones from the child
  1690. self._keys_to_ignore_on_load_unexpected = [
  1691. f"model.{name}" for name in self.model._keys_to_ignore_on_load_unexpected
  1692. ]
  1693. def forward(
  1694. self,
  1695. input_ids: torch.LongTensor | None = None,
  1696. pixel_values: torch.FloatTensor | None = None,
  1697. pixel_values_videos: torch.FloatTensor | None = None,
  1698. input_features: torch.FloatTensor | None = None,
  1699. attention_mask: torch.Tensor | None = None,
  1700. input_features_mask: torch.Tensor | None = None,
  1701. position_ids: torch.LongTensor | None = None,
  1702. image_position_ids: torch.LongTensor | None = None,
  1703. video_position_ids: torch.LongTensor | None = None,
  1704. past_key_values: Cache | None = None,
  1705. mm_token_type_ids: torch.LongTensor | None = None,
  1706. inputs_embeds: torch.FloatTensor | None = None,
  1707. labels: torch.LongTensor | None = None,
  1708. use_cache: bool | None = None,
  1709. logits_to_keep: int | torch.Tensor = 0,
  1710. **kwargs: Unpack[TransformersKwargs],
  1711. ) -> Gemma4CausalLMOutputWithPast:
  1712. r"""
  1713. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  1714. The attention mask for the input audio.
  1715. image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
  1716. 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
  1717. Passed through to the vision encoder for positional embedding computation.
  1718. video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*):
  1719. 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding.
  1720. Passed through to the vision encoder for positional embedding computation.
  1721. """
  1722. outputs = self.model(
  1723. input_ids=input_ids,
  1724. pixel_values=pixel_values,
  1725. pixel_values_videos=pixel_values_videos,
  1726. input_features=input_features,
  1727. attention_mask=attention_mask,
  1728. input_features_mask=input_features_mask,
  1729. position_ids=position_ids,
  1730. past_key_values=past_key_values,
  1731. mm_token_type_ids=mm_token_type_ids,
  1732. inputs_embeds=inputs_embeds,
  1733. labels=labels,
  1734. use_cache=use_cache,
  1735. image_position_ids=image_position_ids,
  1736. video_position_ids=video_position_ids,
  1737. return_dict=True,
  1738. **kwargs,
  1739. )
  1740. hidden_states = outputs.last_hidden_state
  1741. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1742. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1743. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1744. if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
  1745. logits = logits / final_logit_softcapping
  1746. logits = torch.tanh(logits)
  1747. logits = logits * final_logit_softcapping
  1748. loss = None
  1749. if labels is not None:
  1750. # Upcast to float if we need to compute the loss to avoid potential precision issues
  1751. logits = logits.float()
  1752. shift_logits = logits[..., :-1, :]
  1753. shift_labels = labels[..., 1:]
  1754. if attention_mask is not None:
  1755. # we use the input attention mask to shift the logits and labels, because it is 2D.
  1756. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  1757. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  1758. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  1759. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  1760. else:
  1761. shift_logits = shift_logits.contiguous()
  1762. shift_labels = shift_labels.contiguous()
  1763. # Flatten the tokens
  1764. loss_fct = nn.CrossEntropyLoss()
  1765. flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size)
  1766. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  1767. loss = loss_fct(flat_logits, flat_labels)
  1768. return Gemma4CausalLMOutputWithPast(
  1769. loss=loss,
  1770. logits=logits,
  1771. past_key_values=outputs.past_key_values,
  1772. hidden_states=outputs.hidden_states,
  1773. attentions=outputs.attentions,
  1774. image_hidden_states=outputs.image_hidden_states,
  1775. audio_hidden_states=outputs.audio_hidden_states,
  1776. )
  1777. @auto_docstring
  1778. def get_image_features(
  1779. self,
  1780. pixel_values: torch.FloatTensor,
  1781. image_position_ids: torch.LongTensor | None = None,
  1782. **kwargs: Unpack[TransformersKwargs],
  1783. ):
  1784. r"""
  1785. image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*):
  1786. 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding.
  1787. Passed through to the vision encoder for positional embedding computation.
  1788. """
  1789. return self.model.get_image_features(pixel_values, image_position_ids, **kwargs)
  1790. @staticmethod
  1791. def create_masks_for_generate(
  1792. config: PreTrainedConfig,
  1793. inputs_embeds: torch.Tensor,
  1794. attention_mask: torch.Tensor | None,
  1795. past_key_values: Cache | None,
  1796. position_ids: torch.Tensor | None,
  1797. mm_token_type_ids: torch.Tensor | None = None,
  1798. is_first_iteration: bool | None = False,
  1799. **kwargs,
  1800. ) -> dict:
  1801. if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision":
  1802. # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
  1803. return create_causal_mask_mapping(
  1804. config,
  1805. inputs_embeds,
  1806. attention_mask,
  1807. past_key_values,
  1808. position_ids,
  1809. mm_token_type_ids,
  1810. is_first_iteration=is_first_iteration,
  1811. **{k: v for k, v in kwargs.items() if k != "pixel_values"},
  1812. )
  1813. else:
  1814. # Smaller Gemma models use a conventional casual attention mask
  1815. return create_masks_for_generate(
  1816. config, inputs_embeds, attention_mask, past_key_values, position_ids, **kwargs
  1817. )
  1818. def prepare_inputs_for_generation(
  1819. self,
  1820. input_ids,
  1821. past_key_values=None,
  1822. inputs_embeds=None,
  1823. position_ids=None,
  1824. pixel_values=None,
  1825. pixel_values_videos=None,
  1826. input_features=None,
  1827. attention_mask=None,
  1828. input_features_mask=None,
  1829. token_type_ids=None,
  1830. use_cache=True,
  1831. logits_to_keep=None,
  1832. labels=None,
  1833. is_first_iteration=False,
  1834. **kwargs,
  1835. ):
  1836. # Overwritten -- custom `position_ids` and `pixel_values` handling
  1837. model_inputs = super().prepare_inputs_for_generation(
  1838. input_ids,
  1839. past_key_values=past_key_values,
  1840. inputs_embeds=inputs_embeds,
  1841. attention_mask=attention_mask,
  1842. position_ids=position_ids,
  1843. use_cache=use_cache,
  1844. logits_to_keep=logits_to_keep,
  1845. token_type_ids=token_type_ids,
  1846. is_first_iteration=is_first_iteration,
  1847. **kwargs,
  1848. )
  1849. # If we're in cached decoding stage, multimodal inputs are already cached and can be dropped
  1850. if is_first_iteration or not use_cache:
  1851. model_inputs["pixel_values"] = pixel_values
  1852. model_inputs["pixel_values_videos"] = pixel_values_videos
  1853. model_inputs["input_features"] = input_features
  1854. model_inputs["input_features_mask"] = input_features_mask
  1855. return model_inputs
# Public API of this module: the names exported by `from <module> import *`.
# NOTE(review): kept as a plain, alphabetically sorted literal list; presumably read statically by tooling
# (e.g. the modular converter) — confirm before turning it into a computed expression.
__all__ = [
    "Gemma4AudioModel",
    "Gemma4ForCausalLM",
    "Gemma4ForConditionalGeneration",
    "Gemma4Model",
    "Gemma4PreTrainedModel",
    "Gemma4TextModel",
    "Gemma4VisionModel",
]