modular_gemma3n.py 112 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410
  1. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from collections.abc import Callable, Sequence
  17. from dataclasses import dataclass
  18. from typing import Any
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. from huggingface_hub.dataclasses import strict
  23. from ... import initialization as init
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...configuration_utils import PreTrainedConfig
  27. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  28. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
  29. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
  33. from ...utils.generic import merge_with_config_defaults
  34. from ...utils.output_capturing import capture_outputs
  35. from ..auto import AutoModel
  36. from ..gemma2.modeling_gemma2 import (
  37. Gemma2MLP,
  38. Gemma2PreTrainedModel,
  39. eager_attention_forward,
  40. rotate_half,
  41. )
  42. from ..gemma3.configuration_gemma3 import Gemma3TextConfig
  43. from ..gemma3.modeling_gemma3 import (
  44. Gemma3Attention,
  45. Gemma3DecoderLayer,
  46. Gemma3ForCausalLM,
  47. Gemma3RotaryEmbedding,
  48. Gemma3TextModel,
  49. Gemma3TextScaledWordEmbedding,
  50. )
  51. from ..paligemma.modeling_paligemma import (
  52. PaliGemmaCausalLMOutputWithPast,
  53. PaliGemmaForConditionalGeneration,
  54. PaliGemmaModel,
  55. PaligemmaModelOutputWithPast,
  56. )
  57. from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig
  58. logger = logging.get_logger(__name__)
  59. @auto_docstring(checkpoint="google/gemma-3n-E4B")
  60. @strict
  61. class Gemma3nTextConfig(Gemma3TextConfig):
  62. r"""
  63. vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
  64. Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
  65. hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
  66. Dimension of the hidden representations for per-layer emebeddings.
  67. altup_active_idx (`int`, *optional*, defaults to 0):
  68. The index of the prediction from which AltUp will compute additional predictions or correct
  69. altup_coef_clip (`float`, *optional*, defaults to 120.0):
  70. The maximum amplitude of an AltUp prediction or correction coefficient weight.
  71. altup_correct_scale (`bool`, *optional*, defaults to `True`):
  72. If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
  73. altup_num_inputs (`int`, *optional*, defaults to 4):
  74. The number of predictions that AltUp should be make given the input sequence.
  75. num_kv_shared_layers (`int`, *optional*, defaults to 15):
  76. The number of layer that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
  77. layers in the model "share" the KV values in that each local and global layer in this range uses the KV
  78. cache values computed for the last local or global layer, respectively, before entering this range. The
  79. value should be a multiple of the attention pattern size (see `layer_types` parameter).
  80. laurel_rank (int, *optional*, defaults to 64):
  81. The intermediate size for the linear projections in the Learned Augmented Residual Layer.
  82. activation_sparsity_pattern (Sequence[float], *optional*):
  83. The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
  84. explicitly provide a sparsity value for each layer in the model. By default, the first 10 layers are
  85. sparse with a sparsity factor of 0.95 and the rest are dense.
  86. ```python
  87. >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
  88. >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
  89. >>> configuration = Gemma3nTextConfig()
  90. >>> # Initializing a model from the gemma3n_text-E4B style configuration
  91. >>> model = Gemma3nTextModel(configuration)
  92. >>> # Accessing the model configuration
  93. >>> configuration = model.config
  94. ```
  95. """
  96. model_type = "gemma3n_text"
  97. base_model_tp_plan = {
  98. "layers.*.self_attn.q_proj": "colwise",
  99. "layers.*.self_attn.k_proj": "colwise",
  100. "layers.*.self_attn.v_proj": "colwise",
  101. "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
  102. "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
  103. "layers.*.self_attn.v_norm": "replicated_with_grad_allreduce",
  104. "layers.*.self_attn.o_proj": "rowwise",
  105. "layers.*.mlp.gate_proj": "colwise",
  106. "layers.*.mlp.up_proj": "colwise",
  107. "layers.*.mlp.down_proj": "rowwise",
  108. }
  109. default_theta = {"global": 1_000_000.0, "local": 10_000.0}
  110. vocab_size: int = 262_400
  111. vocab_size_per_layer_input: int = 262_144
  112. hidden_size: int = 2048
  113. hidden_size_per_layer_input: int = 256
  114. intermediate_size: int | list[int] = 16_384
  115. num_hidden_layers: int = 35
  116. num_key_value_heads: int = 2
  117. max_position_embeddings: int = 32_768
  118. sliding_window: int = 512
  119. layer_types: list[str] | None = None
  120. final_logit_softcapping: float = 30.0
  121. altup_active_idx: int = 0
  122. altup_coef_clip: float = 120.0
  123. altup_correct_scale: bool = True
  124. altup_num_inputs: int = 4
  125. num_kv_shared_layers: int = 15
  126. laurel_rank: int = 64
  127. activation_sparsity_pattern: float | list[float] | None = None
  128. attn_logit_softcapping = AttributeError()
  129. use_bidirectional_attention = AttributeError()
  130. query_pre_attn_scalar = AttributeError()
  131. def __post_init__(self, **kwargs):
  132. if (
  133. isinstance(self.intermediate_size, Sequence)
  134. and (intsize_len := len(self.intermediate_size)) != self.num_hidden_layers
  135. ):
  136. raise ValueError(
  137. "intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
  138. f"Expected {self.num_hidden_layers} values but got {intsize_len}."
  139. )
  140. elif not isinstance(self.intermediate_size, Sequence):
  141. self.intermediate_size = [self.intermediate_size] * self.num_hidden_layers
  142. if self.layer_types is None:
  143. self.layer_types = [
  144. "full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
  145. ]
  146. if self.activation_sparsity_pattern is None:
  147. num_sparse_layers = 10 if self.num_hidden_layers > 10 else 0
  148. self.activation_sparsity_pattern = [0.95] * num_sparse_layers + [0.0] * (
  149. self.num_hidden_layers - num_sparse_layers
  150. )
  151. if (len_asp := len(self.activation_sparsity_pattern)) != self.num_hidden_layers:
  152. raise ValueError(
  153. "activation_sparsity_pattern must have an explicit activation sparsity value for every layer."
  154. f"Expected {self.num_hidden_layers} values but got {len_asp}."
  155. )
  156. PreTrainedConfig.__post_init__(**kwargs)
  157. def convert_rope_params_to_dict(self, **kwargs):
  158. rope_scaling = kwargs.pop("rope_scaling", None)
  159. # Try to set `rope_scaling` if available, otherwise use `rope_parameters`. If we find `rope_parameters`
  160. # as arg in the inputs, we can safely assume that it is in the new format. New naming used -> new format
  161. default_rope_params = {
  162. "sliding_attention": {"rope_type": "default"},
  163. "full_attention": {"rope_type": "default"},
  164. }
  165. self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else default_rope_params
  166. if rope_scaling is not None:
  167. self.rope_parameters["full_attention"].update(rope_scaling)
  168. # Set default values if not present
  169. if self.rope_parameters.get("full_attention") is None:
  170. self.rope_parameters["full_attention"] = {"rope_type": "default"}
  171. self.rope_parameters["full_attention"].setdefault(
  172. "rope_theta", kwargs.pop("rope_theta", self.default_theta["global"])
  173. )
  174. if self.rope_parameters.get("sliding_attention") is None:
  175. self.rope_parameters["sliding_attention"] = {"rope_type": "default"}
  176. self.rope_parameters["sliding_attention"].setdefault(
  177. "rope_theta", kwargs.pop("rope_local_base_freq", self.default_theta["local"])
  178. )
  179. # Standardize and validate the correctness of rotary position embeddings parameters
  180. self.standardize_rope_params()
  181. return kwargs
# Configuration for the Gemma3n audio encoder. Fields prefixed `conf_` configure the Conformer section
# and fields prefixed `sscp_` configure the Sub-sample Convolution Projection section (see docstring).
@auto_docstring(checkpoint="google/gemma-3n-E4B")
@strict
class Gemma3nAudioConfig(PreTrainedConfig):
    r"""
    vocab_offset (`int`, *optional*, defaults to 262272):
        Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
        0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
    input_feat_size (`int`, *optional*, defaults to 128):
        The number of channels in each mel-spectrogram frame.
    gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
        Clipping value used to stabilize extremely large gradient values.
    conf_attention_chunk_size (`int`, *optional*, defaults to 12):
        The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
        Universal Speech Model.
    conf_attention_context_left (`int`, *optional*, defaults to 13):
        The left context size of the local attention inside the Conformer ("conf") section of the
        Universal Speech Model.
    conf_attention_context_right (`int`, *optional*, defaults to 0):
        The right context size of the local attention inside the Conformer ("conf") section of the
        Universal Speech Model.
    conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
        Logit cap applied during local attention inside the Conformer ("conf") section of the
        Universal Speech Model.
    conf_num_attention_heads (`int`, *optional*, defaults to 8):
        The number of attention heads in local attention inside the Conformer ("conf") section of the
        Universal Speech Model.
    conf_num_hidden_layers (`int`, *optional*, defaults to 12):
        The number of layers that use local attention inside the Conformer ("conf") section of the
        Universal Speech Model.
    conf_conv_kernel_size (`int`, *optional*, defaults to 5):
        Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
        Universal Speech Model.
    conf_reduction_factor (`int`, *optional*, defaults to 4):
        Reduction factor used in the conformer block inside the Conformer ("conf") section of the
        Universal Speech Model.
    conf_residual_weight (`float`, *optional*, defaults to 0.5):
        Residual connection weight inside the Conformer ("conf") section of the
        Universal Speech Model.
    sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
        The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
        ("sscp") section of the Universal Speech Model.
    sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
        Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
        Projection ("sscp") section of the Universal Speech Model.
    sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
        Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
        Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
        tuple of height and width for each layer, where the height corresponds to the time dimension and the width
        corresponds to the frequency dimension.
    sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
        Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
        Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
        tuple of height and width for each layer, where the height corresponds to the time dimension and the width
        corresponds to the frequency dimension.

    Example:

    ```python
    >>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder

    >>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
    >>> configuration = Gemma3nAudioConfig()

    >>> # Initializing a model from the gemma3n_audio-E4B style configuration
    >>> model = Gemma3nAudioEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "gemma3n_audio"

    # Embedding-table settings for audio soft tokens.
    vocab_size: int = 128
    vocab_offset: int = 262_144 + 128  # text vocab size + vision vocab size
    input_feat_size: int = 128
    hidden_size: int = 1536
    rms_norm_eps: float = 1e-6
    gradient_clipping: float = 10_000_000_000.0

    # Conformer ("conf") section settings.
    conf_attention_chunk_size: int = 12
    conf_attention_context_left: int = 13
    conf_attention_context_right: int = 0
    conf_attention_logit_cap: float = 50.0
    conf_num_attention_heads: int = 8
    conf_num_hidden_layers: int = 12
    conf_conv_kernel_size: int = 5
    conf_reduction_factor: int = 4
    conf_residual_weight: float = 0.5

    # Sub-sample Convolution Projection ("sscp") section settings; one entry per convolutional layer.
    sscp_conv_channel_size: list[int] | tuple[int, int] = (128, 32)
    sscp_conv_group_norm_eps: float = 1e-3
    sscp_conv_kernel_size: list | tuple[tuple[int, int], tuple[int, int]] = (
        (3, 3),
        (3, 3),
    )
    sscp_conv_stride_size: list | tuple[tuple[int, int], tuple[int, int]] = (
        (2, 2),
        (2, 2),
    )
# Configuration for the Gemma3n vision encoder, expressed as a `TimmWrapperConfig` subclass.
@auto_docstring(checkpoint="google/gemma-3n-E4B")
@strict
class Gemma3nVisionConfig(TimmWrapperConfig):
    r"""
    architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
        The timm architecture to load.
    do_pooling (`bool`, *optional*, defaults to `False`):
        Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
    model_args (`dict[str, Any]`, *optional*):
        Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}`
        for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`.
    vocab_offset (`int`, *optional*, defaults to 262144):
        Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
        0-indexed `Gemma3nMultimodalEmbedder.embedding` table.

    Example:

    ```python
    >>> from transformers import Gemma3nVisionConfig, TimmWrapper

    >>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
    >>> configuration = Gemma3nVisionConfig()

    >>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
    >>> model = TimmWrapper(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # NOTE(review): the example above imports `TimmWrapper`, while the `do_pooling` entry refers to
    # `TimmWrapperModel` -- confirm which name is actually exported.
    model_type = "gemma3n_vision"

    initializer_range: float = 0.02
    do_pooling: bool = False
    architecture: str = "mobilenetv5_300m_enc"
    hidden_size: int = 2048
    vocab_size: int = 128
    vocab_offset: int = 262_144
    rms_norm_eps: float = 1e-06
    model_args: dict | None = None
  307. @auto_docstring(checkpoint="google/gemma-3n-E4B")
  308. @strict
  309. class Gemma3nConfig(PreTrainedConfig):
  310. r"""
  311. audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
  312. The number of soft tokens per audio clip.
  313. vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
  314. The number of soft tokens per image.
  315. boi_token_id (`int`, *optional*, defaults to 255999):
  316. The begin-of-image token index to wrap the image prompt.
  317. eoi_token_id (`int`, *optional*, defaults to 262144):
  318. The end-of-image token index to wrap the image prompt.
  319. boa_token_id (`int`, *optional*, defaults to 256000):
  320. The begin-of-audio token index to wrap the audio prompt.
  321. eoa_token_id (`int`, *optional*, defaults to 262272):
  322. The end-of-audio token index to wrap the audio prompt.
  323. Example:
  324. ```python
  325. >>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig
  326. >>> # Initializing a MobileNet vision config, which is loaded from TIMM
  327. >>> vision_config = Gemma3nVisionConfig()
  328. >>> # Initializing a Gemma3n Audio config
  329. >>> audio_config = Gemma3nAudioConfig()
  330. >>> # Initializing a Gemma3n Text config
  331. >>> text_config = Gemma3nTextConfig()
  332. >>> # Initializing a Gemma3n gemma-3-4b style configuration
  333. >>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
  334. >>> # Initializing a model from the gemma-3-4b style configuration
  335. >>> model = Gemma3nTextConfig(configuration)
  336. >>> # Accessing the model configuration
  337. >>> configuration = model.config
  338. ```"""
  339. model_type = "gemma3n"
  340. sub_configs = {
  341. "text_config": Gemma3nTextConfig,
  342. "vision_config": Gemma3nVisionConfig,
  343. "audio_config": Gemma3nAudioConfig,
  344. }
  345. text_config: Gemma3nTextConfig | dict[str, Any] | None = None
  346. vision_config: Gemma3nVisionConfig | dict[str, Any] | None = None
  347. audio_config: Gemma3nAudioConfig | dict[str, Any] | None = None
  348. audio_soft_tokens_per_image: int | None = 188
  349. vision_soft_tokens_per_image: int | None = 256
  350. boi_token_id: int | None = 255_999
  351. eoi_token_id: int | None = 262_144
  352. image_token_id: int | None = 262_145
  353. boa_token_id: int | None = 256_000
  354. eoa_token_id: int | None = 262_272
  355. audio_token_id: int | None = 262_273
  356. initializer_range: float | None = 0.02
  357. tie_word_embeddings: bool | None = True
  358. def __post_init__(self, **kwargs):
  359. if self.text_config is None:
  360. self.text_config = Gemma3nTextConfig()
  361. logger.info("text_config is None, using default Gemma3nTextConfig text config.")
  362. elif isinstance(self.text_config, dict):
  363. self.text_config = Gemma3nTextConfig(**self.text_config)
  364. if isinstance(self.vision_config, dict):
  365. self.vision_config = Gemma3nVisionConfig(**self.vision_config)
  366. elif self.vision_config is None:
  367. self.vision_config = Gemma3nVisionConfig()
  368. logger.info("vision_config is None, using default Gemma3nVisionConfig vision config.")
  369. if isinstance(self.audio_config, dict):
  370. self.audio_config = Gemma3nAudioConfig(**self.audio_config)
  371. elif self.audio_config is None:
  372. self.audio_config = Gemma3nAudioConfig()
  373. logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
  374. super().__post_init__(**kwargs)
# Output container for the audio encoder: extends `BaseModelOutputWithPooling` with the mel-frame mask.
@dataclass
@auto_docstring
class Gemma3nAudioEncoderModelOutput(BaseModelOutputWithPooling):
    r"""
    audio_mel_mask (`torch.BoolTensor`, *optional*):
        A torch.BoolTensor of shape `(batch_size, num_frames)`
    """

    # Optional; defaults to None when no mask is supplied.
    audio_mel_mask: torch.BoolTensor | None = None
# Base-model output: extends `PaligemmaModelOutputWithPast` with an `audio_hidden_states` field.
# NOTE(review): unlike `Gemma3nAudioEncoderModelOutput`, no `@dataclass` decorator is applied here --
# presumably it is carried over from the parent class by the modular converter; confirm.
class Gemma3nModelOutputWithPast(PaligemmaModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # NOTE(review): the documented shape mentions `num_images`, mirroring `image_hidden_states` --
    # confirm the actual audio tensor shape.
    audio_hidden_states: torch.FloatTensor | None = None
# Causal-LM output: extends `PaliGemmaCausalLMOutputWithPast` with an `audio_hidden_states` field.
# NOTE(review): unlike `Gemma3nAudioEncoderModelOutput`, no `@dataclass` decorator is applied here --
# presumably it is carried over from the parent class by the modular converter; confirm.
class Gemma3nCausalLMOutputWithPast(PaliGemmaCausalLMOutputWithPast):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # NOTE(review): the documented shape mentions `num_images`, mirroring `image_hidden_states` --
    # confirm the actual audio tensor shape.
    audio_hidden_states: torch.FloatTensor | None = None
  415. class Gemma3nRMSNorm(nn.Module):
  416. def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
  417. super().__init__()
  418. self.eps = eps
  419. self.with_scale = with_scale
  420. if self.with_scale:
  421. self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)
  422. def _norm(self, hidden_states: torch.Tensor):
  423. mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
  424. # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX
  425. return hidden_states * torch.pow(mean_squared, -0.5)
  426. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  427. normed_output = self._norm(hidden_states.float())
  428. if self.with_scale:
  429. normed_output = normed_output * self.weight.float()
  430. return normed_output.type_as(hidden_states)
  431. # ==== Audio Encoder ====
  432. class Gemma3nAudioRelativePositionEmbedding(nn.Module):
  433. def __init__(self, config: Gemma3nAudioConfig):
  434. super().__init__()
  435. self.config = config
  436. self.num_heads = self.config.conf_num_attention_heads
  437. self.channels = self.config.hidden_size
  438. self.head_dim = self.channels // self.num_heads
  439. self.max_backward = max(0, self.config.conf_attention_context_left - 1)
  440. self.max_forward = self.config.conf_attention_context_right
  441. self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
  442. min_timescale = 1.0
  443. max_timescale = 1.0e4
  444. num_timescales = self.channels // 2
  445. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
  446. inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
  447. self.register_buffer(
  448. "inv_timescales",
  449. inv_timescales.float().unsqueeze(0).unsqueeze(0),
  450. persistent=False,
  451. )
  452. def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  453. position = position.float().unsqueeze(-1)
  454. scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
  455. timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
  456. return timing_signal.type(dtype)
  457. def _relative_shift(
  458. self,
  459. term_bd_before_shift: torch.Tensor,
  460. batch_size: int,
  461. num_heads: int,
  462. num_query_blocks: int,
  463. query_block_size: int,
  464. key_context_size: int,
  465. max_span_plus_1: int,
  466. ) -> torch.Tensor:
  467. """Performs the relative shift.
  468. Args:
  469. term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
  470. (B), num_heads (N), num_query_blocks (U), query_block_size (W),
  471. key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
  472. Returns:
  473. Tensor of shape [B, N, U, W, C].
  474. """
  475. # term_bd_before_shift shape: [B, N, U, W, F_span]
  476. # Target shape after shift: [B, N, U, W, C]
  477. # Padding amount for the last dimension (F_span) to become (C + 1)
  478. # C = key_context_size
  479. # F_span = max_span_plus_1
  480. pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
  481. # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
  482. # We only pad the last dimension on the right.
  483. padding_tuple = (0, pad_amount_last_dim)
  484. term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
  485. # Shape after pad: [B, N, U, W, C+1]
  486. # Reshape for slicing (emulating JAX's behavior)
  487. # [B, N, U, W * (C+1)]
  488. term_bd_reshaped = term_bd_padded.reshape(
  489. (
  490. batch_size,
  491. num_heads,
  492. num_query_blocks,
  493. query_block_size * (key_context_size + 1),
  494. )
  495. )
  496. # Slice to effective [B, N, U, W * C]
  497. term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
  498. # Reshape back to [B, N, U, W, C]
  499. term_bd_shifted = term_bd_sliced.reshape(
  500. (
  501. batch_size,
  502. num_heads,
  503. num_query_blocks,
  504. query_block_size,
  505. key_context_size,
  506. )
  507. )
  508. return term_bd_shifted
  509. def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
  510. # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
  511. # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
  512. # C = W + L + R (key_context_size)
  513. # F_span = L + R + 1 (max_span + 1)
  514. batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
  515. _, _, key_context_size, _, _ = keys.shape
  516. # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
  517. # Length is L+R+1 = self.max_span + 1
  518. pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
  519. 0
  520. ) # Shape [1, F_span]
  521. max_span_plus_1 = pos_indices.shape[1] # F_span
  522. sin_emb_timing_signal = self._get_timing_signal_1d_pos(
  523. pos_indices, dtype=queries.dtype
  524. ) # Shape [1, F_span, self.channels]
  525. # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
  526. projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
  527. # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
  528. sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
  529. 0
  530. ) # Shape [F, N, H]
  531. # term_ac: Query-Key content interaction
  532. # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
  533. # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
  534. queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
  535. keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
  536. term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
  537. # term_bd: Query-Position interaction
  538. # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
  539. # queries shape: [B, U, W, N, H]
  540. # sin_emb shape: [F, N, H]
  541. # Target output shape: [B, N, U, W, F]
  542. # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
  543. q_permuted = queries.permute(0, 3, 1, 2, 4)
  544. # Permute sin_emb to [N, H, F] to prepare for matmul
  545. # sin_emb original is [F, N, H]
  546. s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
  547. # Reshape queries for matmul: [B, N, U*W, H]
  548. q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
  549. # Perform matmul: [B, N, U*W, H] @ [N, H, F]
  550. # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
  551. # Result: [B, N, U*W, F]
  552. term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
  553. # Reshape to target [B, N, U, W, F]
  554. term_bd_unshifed = term_bd_unshifed_matmul.reshape(
  555. batch_size,
  556. num_heads,
  557. num_query_blocks,
  558. query_block_size,
  559. max_span_plus_1,
  560. )
  561. # Apply relative shift to term_bd_unshifed
  562. term_bd_shifted = self._relative_shift(
  563. term_bd_unshifed,
  564. batch_size,
  565. num_heads,
  566. num_query_blocks,
  567. query_block_size,
  568. key_context_size,
  569. max_span_plus_1,
  570. ) # Shape [B, N, U, W, C]
  571. return term_ac + term_bd_shifted
  572. class Gemma3nAudioAttention(nn.Module):
  573. def __init__(self, config: Gemma3nAudioConfig):
  574. super().__init__()
  575. self.config = config
  576. self.num_heads = self.config.conf_num_attention_heads
  577. self.hidden_size = self.config.hidden_size
  578. self.head_dim = self.hidden_size // self.num_heads
  579. self.chunk_size = self.config.conf_attention_chunk_size
  580. self.max_future_horizon = self.config.conf_attention_context_right
  581. self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
  582. self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
  583. self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
  584. self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
  585. self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
  586. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  587. self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  588. self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  589. q_scale = self.head_dim**-0.5
  590. r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
  591. self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
  592. local_causal_valid_mask = self.create_local_causal_valid_mask()
  593. self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
  594. self.register_buffer(
  595. "softcap",
  596. torch.tensor(self.attention_logits_soft_cap).float(),
  597. persistent=False,
  598. )
  599. def create_local_causal_valid_mask(self):
  600. lower_causal_mask = torch.tril(
  601. torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
  602. diagonal=0,
  603. ).T
  604. upper_causal_mask = torch.tril(
  605. torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
  606. diagonal=self.max_past_horizon + self.max_future_horizon,
  607. )
  608. local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
  609. local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
  610. return local_causal_valid_mask
  611. def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
  612. batch, _, *tail_shape = x.shape
  613. left = x.new_zeros((batch, pad_left, *tail_shape))
  614. right = x.new_zeros((batch, pad_right, *tail_shape))
  615. x = torch.cat([left, x, right], dim=1)
  616. return x
  617. def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
  618. """Turns a sequence to non overlapping blocks.
  619. Args:
  620. hidden_states: a tensor of [batch, time, ...].
  621. Returns:
  622. A tensor of [batch, num_blocks, block_size, ...], with necessary
  623. paddings,
  624. where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
  625. """
  626. shape = hidden_states.shape
  627. b, t = shape[:2]
  628. num_blocks = (t + self.chunk_size - 1) // self.chunk_size
  629. if (padding_len := num_blocks * self.chunk_size - t) > 0:
  630. hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
  631. permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
  632. hidden_states = hidden_states.reshape(permute_dims).contiguous()
  633. return hidden_states
  634. def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
  635. """Extracts temporal context for every block.
  636. Args:
  637. hidden_states: a tensor of [batch, time, ...].
  638. Returns:
  639. A tensor of [batch, num_blocks, context_size, ...], with necessary
  640. paddings,
  641. where context_size = block_size + left_context + right_context,
  642. and output[:, i, ...] are x[:, start-left_context:end+right_context,
  643. ...],
  644. start = i * block_size, end = (i + 1) * block_size.
  645. """
  646. pad_left = self.max_past_horizon
  647. # The JAX equivalent padding for signal.frame with pad_mode='valid' is
  648. # (left_context, right_context + block_size - 1) on the time dimension.
  649. # PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
  650. # or (pad_dim_start, pad_dim_end) if two are given.
  651. # Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
  652. # or dim 1 (time for [B,T]).
  653. # The current pad_right calculation matches the JAX effective padding.
  654. pad_right = self.max_future_horizon + self.chunk_size - 1
  655. hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)
  656. frame_len = self.context_size
  657. frame_step = self.chunk_size
  658. # Directly use unfold without the subframe_factor logic
  659. # x.unfold(dimension, size, step)
  660. # dimension=1 (time dimension, assuming x is [B, T_padded, ...])
  661. # size=frame_len (context_size)
  662. # step=frame_step (chunk_size)
  663. x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
  664. # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
  665. # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
  666. # We want to match JAX's typical output for such operations which might be
  667. # [B, num_blocks, frame_len, N, H] if N, H are present.
  668. # The relative_position_embedding expects keys as [B, U, C, N, H].
  669. # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
  670. if hidden_states.ndim > 2 and x_unfolded.ndim > 3: # Check if inner dimensions (like N, H) exist
  671. # Current shape after unfold for [B, T_pad, N, H] is [B, U, N, H, C]
  672. # Target shape for keys in RPE: [B, U, C, N, H]
  673. x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
  674. return x_unfolded.contiguous()
  675. def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
  676. # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
  677. qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
  678. query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
  679. key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
  680. value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()
  681. per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
  682. broadcast_shape = (1, 1, 1, self.head_dim)
  683. per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
  684. query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
  685. batch_size, q_time = query_states.shape[:2]
  686. query_blocks = self._convert_to_block(query_states)
  687. key_blocks = self._extract_block_context(key_states)
  688. value_blocks = self._extract_block_context(value_states)
  689. num_query_blocks = query_blocks.shape[1]
  690. # 1. Create a mask indicating originally valid positions.
  691. original_valid_mask = ~mask # True for valid, False for padded
  692. # 2. Extract blocks from this validity mask.
  693. extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
  694. # If subframe_factor was used in _extract_block_context for a [B, T] input mask,
  695. # the shape might be [B, U, C/SF, SF]. Reshape to [B, U, C].
  696. # batch_size and num_query_blocks are known from query_blocks.
  697. # self.context_size is C.
  698. if (
  699. extracted_valid_mask_blocks.ndim == 4
  700. and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
  701. ):
  702. extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
  703. batch_size, num_query_blocks, self.context_size
  704. )
  705. # After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
  706. # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
  707. # but for the mask case, this should hold.
  708. if extracted_valid_mask_blocks.shape != (
  709. batch_size,
  710. num_query_blocks,
  711. self.context_size,
  712. ):
  713. raise ValueError(
  714. "Shape of extracted_valid_mask_blocks"
  715. f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
  716. f" {num_query_blocks}, {self.context_size}) after potential reshape."
  717. )
  718. # 3. Expand dimensions for broadcasting with logits and causal mask.
  719. # Target shape for broadcasting with logits [B,N,U,W,C]
  720. # extracted_valid_mask_blocks to [B, 1, U, 1, C]
  721. condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)
  722. # self.local_causal_valid_mask is [W, C], True where allowed by local window.
  723. # Expand to [1, 1, 1, W, C]
  724. condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
  725. # 4. Combine the two conditions.
  726. # final_condition will be True where a key is *both* originally valid *and* causally accessible.
  727. # Broadcasts to [B, 1, U, W, C]
  728. final_condition_for_where = torch.logical_and(
  729. condition_from_input_validity,
  730. condition_from_causality.to(condition_from_input_validity.device), # Ensure same device
  731. )
  732. # Embed queries and keys
  733. logits = self.relative_position_embedding(query_blocks, key_blocks)
  734. # Apply attention logit softcap
  735. # Ensure softcap is on the same device as logits
  736. softcap_val = self.softcap.to(logits.device)
  737. logits = logits / softcap_val
  738. logits = torch.tanh(logits)
  739. logits = logits * softcap_val
  740. # Apply the combined mask.
  741. # final_condition_for_where will broadcast with logits [B,N,U,W,C]
  742. logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
  743. probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
  744. # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
  745. b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
  746. h_dim = value_blocks.shape[-1]
  747. prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
  748. v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
  749. result_bmm = torch.bmm(prob_bun, v_bun)
  750. context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
  751. context_vectors = context_vectors.reshape(
  752. (
  753. batch_size,
  754. num_query_blocks * self.chunk_size,
  755. self.num_heads,
  756. self.head_dim,
  757. )
  758. )
  759. context_vectors = context_vectors[:, :q_time]
  760. return context_vectors
  761. class Gemma3nAudioCumulativeGroupNorm(nn.Module):
  762. """Applies Group Normalization cumulatively over the time dimension.
  763. This layer normalizes the input by calculating the mean and variance
  764. cumulatively over the time dimension (dim 1). The statistics are computed
  765. over all feature dimensions (specified by `feature_dims` and `num_channels`)
  766. for elements marked as valid by the optional `mask`.
  767. If a `mask` is provided (True for valid, False for invalid/padded),
  768. invalid time steps do not contribute to the statistics calculation, and
  769. their corresponding output values are zeroed out.
  770. Scale and bias, if enabled, are applied per-channel (last dimension).
  771. This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
  772. and `cumulative=True`.
  773. """
  774. def __init__(
  775. self,
  776. num_channels: int, # Number of channels (size of the last dimension)
  777. feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
  778. eps: float = 1e-3,
  779. ):
  780. super().__init__()
  781. self.num_channels = num_channels
  782. self.feature_dims = tuple(feature_dims)
  783. self.eps = eps
  784. # Scale parameter depends only on the channel dimension
  785. self.weight = nn.Parameter(torch.ones(num_channels))
  786. # Axes for normalization: all dimensions except Batch (0) and Time (1).
  787. # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
  788. self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
  789. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  790. """Applies cumulative group norm, optionally using a mask.
  791. Args:
  792. hidden_states: Input tensor, shape [B, T, *feature_dims, C].
  793. Returns:
  794. Normalized tensor with the same shape as x.
  795. """
  796. expected_input_suffix = self.feature_dims + (self.num_channels,)
  797. if hidden_states.shape[2:] != expected_input_suffix:
  798. raise ValueError(
  799. f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
  800. f" suffix (feature_dims + num_channels) {expected_input_suffix}"
  801. )
  802. input_dtype = hidden_states.dtype
  803. # Calculations are performed in float32 for numerical stability.
  804. calc_dtype = torch.float32
  805. x_calc = hidden_states.to(calc_dtype)
  806. # Prepare a broadcastable mask (`mask_calc`).
  807. # If no mask is provided, treat all elements as valid
  808. # (mask_calc is all ones).
  809. # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
  810. mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
  811. # Cumulative Statistics Calculation
  812. # 1. Sum of values over reduction axes at each time step.
  813. sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
  814. # 2. Cumulative sum of values over time.
  815. cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
  816. # 3. Count of valid elements in the normalization group at each time step.
  817. # (A "group" here consists of all features at a given Batch, Time).
  818. elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
  819. # 4. Cumulative count of valid elements over time.
  820. cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
  821. # Avoid division by zero if all preceding elements were masked.
  822. safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
  823. # 5. Cumulative mean.
  824. cum_mean = cum_sum_values / safe_cum_count_elements
  825. # 6. Sum of squared differences from the cumulative mean.
  826. # Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
  827. # Using x_calc here for the difference, as cum_mean already accounts for masking.
  828. squared_diff_from_mean = (x_calc - cum_mean).pow(2)
  829. sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
  830. # 7. Cumulative sum of squared differences over time.
  831. cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
  832. # 8. Cumulative variance.
  833. cum_variance = cum_sum_sq_diff / safe_cum_count_elements
  834. # Normalize the input using the calculated cumulative statistics:
  835. # (x - E[x]) / sqrt(Var[x] + eps)
  836. normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
  837. # Apply affine transformation (scale and bias) if enabled.
  838. # Scale and bias are applied per-channel (last dimension).
  839. scale = self.weight.to(calc_dtype)
  840. # Reshape for broadcasting: [C] -> [1, ..., 1, C]
  841. scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
  842. normalized_x = normalized_x * scale.view(scale_view_shape)
  843. # Zero out outputs for time steps that were originally masked (where mask_calc is 0).
  844. # This ensures padded/invalid positions in the input result in zero output.
  845. final_output = normalized_x * mask_calc
  846. return final_output.to(input_dtype)
  847. class Gemma3nAudioSSCPConvBlock(nn.Module):
  848. """A single convolution block for the SubSampleConvProjection.
  849. This block consists of a 2D convolution, followed by CumulativeGroupNorm,
  850. and a ReLU activation. It handles manual padding for the convolution.
  851. """
  852. def __init__(
  853. self,
  854. config: Gemma3nAudioConfig,
  855. idx: int,
  856. input_freq_dim: int, # Changed from input_spatial_dim
  857. manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
  858. ):
  859. super().__init__()
  860. self.config = config
  861. self.manual_padding = manual_padding
  862. # in_channels is 1 for the first block, or C_out from previous block's conv
  863. in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
  864. out_channels = self.config.sscp_conv_channel_size[idx]
  865. kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
  866. stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
  867. self.conv = nn.Conv2d(
  868. in_channels=in_channels,
  869. out_channels=out_channels,
  870. kernel_size=(
  871. kernel_h,
  872. kernel_w,
  873. ), # Kernel (kH, kW) operates on (Time, Freq_dim)
  874. stride=(stride_h, stride_w),
  875. padding=(0, 0), # Manual padding is used
  876. bias=False,
  877. )
  878. # Calculate output frequency dimension (f_out_conv) after this convolution.
  879. # input_freq_dim is the unpadded width (feature dimension).
  880. # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
  881. f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
  882. f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
  883. self.norm = Gemma3nAudioCumulativeGroupNorm(
  884. num_channels=out_channels, # Channels of the conv output
  885. feature_dims=(f_out_conv,), # The frequency dimension size after conv
  886. eps=self.config.sscp_conv_group_norm_eps,
  887. )
  888. self.activation = nn.ReLU()
  889. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  890. # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
  891. # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
  892. # F.pad applies to last two dims: F_in then T_in
  893. audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0).to(
  894. self.conv.weight.dtype
  895. )
  896. # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
  897. # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
  898. audio_encodings_conv = self.conv(audio_encodings_padded)
  899. # Expected conv output shape: [B, C_out, T_out, F_out]
  900. # Input to norm is [B, T_out, F_out, C_out]
  901. x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
  902. x_normed = self.norm(x_for_norm)
  903. # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out]
  904. audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
  905. return self.activation(audio_encodings_normed)
  906. class Gemma3nAudioSubSampleConvProjection(nn.Module):
  907. def __init__(self, config: Gemma3nAudioConfig):
  908. super().__init__()
  909. self.config = config
  910. current_f_for_block_input = config.input_feat_size # Start with original feature dim
  911. calculated_block_padding = []
  912. calculated_f_out_dims = [] # Tracking frequency dimension output sizes
  913. for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays
  914. kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
  915. stride_h, stride_w = config.sscp_conv_stride_size[i]
  916. # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
  917. # JAX 'reverse_causal' padding is (0, kernel_size - 1)
  918. pad_t_top = 0
  919. pad_t_bottom = kernel_h - 1
  920. # Frequency Padding (Width for Conv2d)
  921. # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
  922. # and the successful test configuration.
  923. # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
  924. # to match generic JAX 'SAME' behavior if it differs.
  925. pad_f_left = 1
  926. pad_f_right = 1
  927. manual_padding_tuple = (
  928. pad_f_left,
  929. pad_f_right,
  930. pad_t_top,
  931. pad_t_bottom,
  932. )
  933. calculated_block_padding.append(manual_padding_tuple)
  934. # Calculate output frequency dimension after this convolution
  935. # This uses the actual padding applied and kernel/stride.
  936. f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
  937. f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1
  938. calculated_f_out_dims.append(f_out_after_conv)
  939. current_f_for_block_input = f_out_after_conv
  940. self.conv_0 = Gemma3nAudioSSCPConvBlock(
  941. idx=0,
  942. input_freq_dim=config.input_feat_size, # Pass original feature dim
  943. config=config,
  944. manual_padding=calculated_block_padding[0],
  945. )
  946. self.conv_1 = Gemma3nAudioSSCPConvBlock(
  947. idx=1,
  948. input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0
  949. config=config,
  950. manual_padding=calculated_block_padding[1],
  951. )
  952. final_c_out = config.sscp_conv_channel_size[-1]
  953. final_f_out = calculated_f_out_dims[-1] # Final frequency dimension
  954. self.input_proj_in_features = final_c_out * final_f_out
  955. self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)
  956. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  957. # audio_encodings is [B, T, F_in]
  958. # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in)
  959. audio_encodings_reshaped = audio_encodings.unsqueeze(1)
  960. x = self.conv_0(audio_encodings_reshaped)
  961. x = self.conv_1(x)
  962. # x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
  963. b, c_out, t_out, f_out = x.shape
  964. # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1
  965. x_permuted = x.permute(0, 2, 3, 1).contiguous()
  966. output_flattened = x_permuted.view(b, t_out, f_out * c_out)
  967. output = self.input_proj_linear(output_flattened)
  968. return output
  969. class Gemma3nAudioConformerAttention(nn.Module):
  970. def __init__(self, config: Gemma3nAudioConfig):
  971. super().__init__()
  972. self.config = config
  973. self.post_in_features = self.config.hidden_size
  974. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  975. self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
  976. self.attn = Gemma3nAudioAttention(config)
  977. self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
  978. self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
  979. def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
  980. audio_encodings_input_to_attn = audio_encodings
  981. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  982. audio_encodings_norm = self.pre_attn_norm(audio_encodings)
  983. # Output of self.attn is [B, T, NumHeads, HeadDim]
  984. audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
  985. # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim]
  986. # NumHeads * HeadDim = hidden_size
  987. b, t, num_heads, head_dim = audio_encodings_attn_out.shape
  988. audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim)
  989. audio_encodings = self.post(audio_encodings_reshaped)
  990. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  991. return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
  992. class Gemma3nAudioConformerFeedForward(nn.Module):
  993. def __init__(self, config: Gemma3nAudioConfig):
  994. super().__init__()
  995. self.config = config
  996. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  997. self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
  998. self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
  999. self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
  1000. self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
  1001. self.post_layer_scale = self.config.conf_residual_weight
  1002. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  1003. residual = audio_encodings
  1004. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  1005. audio_encodings = self.pre_layer_norm(audio_encodings)
  1006. audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings)
  1007. audio_encodings = nn.functional.silu(audio_encodings)
  1008. audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings)
  1009. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  1010. audio_encodings = self.post_layer_norm(audio_encodings)
  1011. return residual + (audio_encodings * self.post_layer_scale)
  1012. class Gemma3nAudioConformerLightConv1d(nn.Module):
  1013. def __init__(self, config: Gemma3nAudioConfig):
  1014. super().__init__()
  1015. self.config = config
  1016. self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  1017. self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False)
  1018. self.depthwise_conv1d = nn.Conv1d(
  1019. in_channels=self.config.hidden_size,
  1020. out_channels=self.config.hidden_size,
  1021. kernel_size=self.config.conf_conv_kernel_size,
  1022. stride=1,
  1023. padding=0, # Manual causal padding
  1024. groups=self.config.hidden_size, # Depthwise
  1025. bias=False,
  1026. )
  1027. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  1028. self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  1029. self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
  1030. self.causal_padding = self.config.conf_conv_kernel_size - 1
  1031. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  1032. audio_encodings_residual = audio_encodings # Save for residual connection
  1033. audio_encodings = self.pre_layer_norm(audio_encodings)
  1034. audio_encodings = self.linear_start(audio_encodings)
  1035. audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1)
  1036. # Permute for Conv1d: [B, T, D] -> [B, D, T]
  1037. audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
  1038. # Apply manual causal padding
  1039. audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0))
  1040. audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
  1041. # Permute back: [B, D, T_out] -> [B, T_out, D]
  1042. audio_encodings = audio_encodings.permute(0, 2, 1)
  1043. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  1044. audio_encodings = self.conv_norm(audio_encodings)
  1045. audio_encodings = nn.functional.silu(audio_encodings)
  1046. audio_encodings = self.linear_end(audio_encodings)
  1047. output = audio_encodings + audio_encodings_residual
  1048. return output
  1049. class Gemma3nAudioConformerBlock(nn.Module):
  1050. def __init__(self, config: Gemma3nAudioConfig):
  1051. super().__init__()
  1052. self.config = config
  1053. self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
  1054. self.attention = Gemma3nAudioConformerAttention(self.config)
  1055. self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
  1056. self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
  1057. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  1058. self.norm = Gemma3nRMSNorm(self.config.hidden_size)
  1059. def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
  1060. audio_encodings = self.ffw_layer_start(audio_encodings)
  1061. audio_encodings = self.attention(audio_encodings, audio_mel_mask)
  1062. validity_mask_for_lconv = ~audio_mel_mask # True for valid
  1063. audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to(
  1064. audio_encodings.dtype
  1065. )
  1066. audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
  1067. audio_encodings = self.ffw_layer_end(audio_encodings)
  1068. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  1069. output = self.norm(audio_encodings)
  1070. return output
  1071. # ==== Language Model ====
class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
    # Gemma3n-named subclass of `Gemma3TextScaledWordEmbedding`; nothing is added or overridden here.
    pass
  1074. class Gemma3nTextLaurelBlock(nn.Module):
  1075. """Learned Augmented Residual Layer"""
  1076. def __init__(self, config: Gemma3nTextConfig):
  1077. super().__init__()
  1078. self.config = config
  1079. self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False)
  1080. self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False)
  1081. self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  1082. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1083. laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states)
  1084. laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states)
  1085. normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
  1086. return hidden_states + normed_laurel_hidden_states
  1087. class Gemma3nTextMLP(Gemma2MLP):
  1088. def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
  1089. super().__init__(config)
  1090. self.intermediate_size = config.intermediate_size[layer_idx]
  1091. self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
  1092. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1093. gate_proj = self.gate_proj(hidden_states)
  1094. if self.activation_sparsity > 0.0:
  1095. gate_proj = self._gaussian_topk(gate_proj)
  1096. activations = self.act_fn(gate_proj)
  1097. up_proj = self.up_proj(hidden_states)
  1098. down_proj = self.down_proj(activations * up_proj)
  1099. return down_proj
  1100. def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
  1101. target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
  1102. # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
  1103. #
  1104. # References:
  1105. # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
  1106. # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
  1107. # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
  1108. normal_dist = torch.distributions.normal.Normal(0, 1)
  1109. std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor)
  1110. std_multiplier = std_multiplier.type(inputs.dtype)
  1111. inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
  1112. inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
  1113. cutoff_x = inputs_mean + inputs_std * std_multiplier
  1114. return nn.functional.relu(inputs - cutoff_x)
  1115. class Gemma3nTextAltUp(nn.Module):
  1116. """Alternating Updates (AltUp)
  1117. The AltUp module wraps transformer layers. The `predict` step modifies the
  1118. input to the transformer layer, and the `correct` step propagates the output
  1119. of the transformer layer to the sparsely updated dimensions.
  1120. See more in the research paper:
  1121. https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
  1122. """
  1123. def __init__(self, config: Gemma3nTextConfig):
  1124. super().__init__()
  1125. self.config = config
  1126. self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
  1127. self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
  1128. self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
  1129. self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
  1130. self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  1131. self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
  1132. def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
  1133. router_inputs = self.router_norm(x) * self.router_input_scale
  1134. routed = self.modality_router(router_inputs)
  1135. return torch.tanh(routed.float()).type_as(x)
  1136. def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1137. """Predicts the output of a layer using a trainable map.
  1138. Args:
  1139. hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
  1140. stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
  1141. Returns:
  1142. A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
  1143. """
  1144. modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
  1145. if self.training and self.config.altup_coef_clip is not None:
  1146. self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
  1147. # Project and then transpose all 2D matrices contained so that mulmat gives the correct result
  1148. all_coefs: torch.Tensor = (
  1149. self.prediction_coefs(modalities)
  1150. .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
  1151. .permute(0, 1, 3, 2)
  1152. )
  1153. # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
  1154. predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
  1155. predictions = predictions.permute(3, 0, 1, 2) # undo the permute
  1156. predictions += hidden_states # add the original input
  1157. return predictions.contiguous().type_as(hidden_states)
  1158. def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
  1159. """Corrects the predictions relative to the
  1160. Args:
  1161. predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
  1162. stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
  1163. activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
  1164. Returns:
  1165. A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
  1166. predictions relative to the activated input embeddings.
  1167. """
  1168. modalities = self.compute_router_modalities(activated)
  1169. innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
  1170. innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
  1171. if self.training and self.config.altup_coef_clip is not None:
  1172. weight = self.correction_coefs.weight.clamp(-self.config.altup_coef_clip, self.config.altup_coef_clip)
  1173. all_coefs = torch.nn.functional.linear(modalities, weight, bias=None) + 1.0
  1174. else:
  1175. all_coefs = self.correction_coefs(modalities) + 1.0
  1176. # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
  1177. # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
  1178. # and expand on dim1 for broadcastability
  1179. all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
  1180. corrected = torch.mul(innovation, all_coefs)
  1181. corrected += predictions # add the original input
  1182. return corrected.contiguous().type_as(activated)
  1183. def forward(self, corrected: torch.Tensor) -> torch.Tensor:
  1184. """
  1185. This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
  1186. (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
  1187. `scale_corrected_output`
  1188. """
  1189. return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
  1190. def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
  1191. """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
  1192. return self.forward(corrected)
  1193. def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
  1194. """Applies Rotary Position Embedding to the query and key tensors.
  1195. Args:
  1196. x (`torch.Tensor`): The tensor to embed.
  1197. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  1198. sin (`torch.Tensor`): The sine part of the rotary embedding.
  1199. unsqueeze_dim (`int`, *optional*, defaults to 1):
  1200. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  1201. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  1202. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  1203. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  1204. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  1205. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  1206. Returns:
  1207. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  1208. """
  1209. cos = cos.unsqueeze(unsqueeze_dim)
  1210. sin = sin.unsqueeze(unsqueeze_dim)
  1211. return (x * cos) + (rotate_half(x) * sin)
class Gemma3nTextAttention(Gemma3Attention):
    """Gemma3n text self-attention with optional key/value sharing between layers.

    The last `config.num_kv_shared_layers` layers are "KV-shared": when a cache is supplied they do not
    use their own keys/values but reuse the ones published by the last non-shared layer of the same
    layer type. That publishing layer is flagged with `store_full_length_kv`.
    """

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.is_causal = True
        # This class does not use logit soft-capping; the attribute (presumably set by the parent
        # constructor) is removed.
        del self.attn_logit_softcapping
        # No extra scaling factor is applied to the attention scores.
        self.scaling = 1.0
        self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
        first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
        # Chained comparison: shared only if this layer is at/after the boundary AND the boundary is > 0.
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        # Layer types of all layers that come before the sharing boundary.
        prev_layers = config.layer_types[:first_kv_shared_layer_idx]
        # NOTE(review): both `.index(...)` calls below raise ValueError when this layer's type does not
        # occur in `prev_layers` (e.g. `prev_layers` is empty because the boundary is 0) - confirm that
        # such configurations are rejected upstream.
        if self.is_kv_shared_layer:
            # For shared layers, find the last non-shared layer of the same type before sharing starts
            self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
            self.store_full_length_kv = False
        else:
            self.kv_shared_layer_index = None
            # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
            self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
                config.layer_types[layer_idx]
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Computes self-attention for one layer.

        Args:
            hidden_states: Layer input; its last dimension is projected and split into heads of size
                `config.head_dim`.
            position_embeddings: `(cos, sin)` pair for the rotary embedding. Although it defaults to
                `None`, it is unpacked unconditionally, so it must be provided.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            past_key_values: Optional cache. Non-shared layers update it; KV-shared layers read the
                keys/values stored in its `shared_layers` mapping.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)`: the output-projected attention result and whatever the
            selected attention implementation returns as weights.
        """
        input_shape = hidden_states.shape[:-1]
        # -1 lets `view` infer the number of heads from the projection width.
        hidden_shape = (*input_shape, -1, self.config.head_dim)
        cos, sin = position_embeddings
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        # Heads are still on dim 2 here, hence unsqueeze_dim=2; the transpose to heads-first comes after.
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)
        # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
        if self.is_kv_shared_layer and past_key_values is not None:
            key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
            # Device of past layer may be different from current one
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
            key_states = key_states.transpose(1, 2)
            # Values are normalized but not rotated.
            value_states = self.v_proj(hidden_states).view(hidden_shape)
            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)
        if past_key_values is not None:
            if not self.is_kv_shared_layer:
                key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
            if self.store_full_length_kv:
                # Publish the post-update keys/values for the KV-shared layers; the mapping is created
                # lazily as an ad-hoc attribute on the cache object.
                if not hasattr(past_key_values, "shared_layers"):
                    past_key_values.shared_layers = {}
                past_key_values.shared_layers[self.layer_idx] = key_states, value_states
        # Falls back to the eager implementation when the configured one is not registered.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        # Merge the head dimensions back into a single feature dimension before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
    """Gemma3n decoder layer: AltUp predict -> attention + LAuReL -> MLP -> AltUp correct -> per-layer input.

    Unlike a plain decoder layer, `forward` consumes and returns the whole AltUp stack of hidden states
    rather than a single hidden-state tensor.
    """

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the parent's MLP and attention with the Gemma3n variants.
        self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.act_fn = ACT2FN[config.hidden_activation]
        self.altup = Gemma3nTextAltUp(config)
        self.laurel = Gemma3nTextLaurelBlock(config)
        self.self_attn = Gemma3nTextAttention(config, layer_idx)
        # Down-projection / up-projection pair used to mix in the per-layer input.
        self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
        self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
        self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        per_layer_input: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Runs one decoder layer over the AltUp stack.

        Args:
            hidden_states: AltUp stack, indexed along dim 0 by AltUp input
                (`[num_altup_inputs, batch_size, num_tokens, hidden_size]` per `Gemma3nTextAltUp.predict`).
            position_embeddings: Rotary `(cos, sin)` pair, forwarded to the attention module.
            per_layer_input: Per-layer embedding multiplied into the gated active prediction. Although
                it defaults to `None`, it is used unconditionally.
            attention_mask: Forwarded to the attention module.
            position_ids: Forwarded to the attention module.
            past_key_values: Forwarded to the attention module.
            **kwargs: Forwarded to the attention module.

        Returns:
            The corrected AltUp stack, in which every slice except index 0 has additionally received the
            per-layer-input update.
        """
        predictions = self.altup.predict(hidden_states)
        # Only the "active" AltUp slice goes through the attention/MLP computation.
        active_prediction = predictions[self.config.altup_active_idx]
        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)
        # Attention weights are discarded here.
        attn, _ = self.self_attn(
            hidden_states=active_prediction_normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            past_key_values=past_key_values,
            **kwargs,
        )
        attn = self.post_attention_layernorm(attn)
        # Residual connection around attention (no gating is applied despite the variable name).
        attn_gated = active_prediction + attn
        # Sum of the attention branch and the LAuReL branch, divided by sqrt(2).
        attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        # Residual connection around the MLP.
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
        corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
        # Clone so the per-layer-input computation below does not alias the stack that is updated in
        # place at the end.
        first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
        if self.config.altup_correct_scale:
            first_prediction = self.altup.scale_corrected_output(first_prediction)
        # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.multiply(first_prediction, per_layer_input)
        # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        # In-place update of every AltUp slice except index 0.
        corrected_predictions[1:] += first_prediction
        return corrected_predictions
  1339. class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
  1340. config: Gemma3nConfig
  1341. input_modalities = ("image", "text", "audio")
  1342. _no_split_modules = ["Gemma3nTextDecoderLayer"]
  1343. _can_record_outputs = {
  1344. "hidden_states": Gemma3nTextDecoderLayer,
  1345. "attentions": Gemma3nTextAttention,
  1346. }
  1347. @torch.no_grad()
  1348. def _init_weights(self, module):
  1349. PreTrainedModel._init_weights(self, module)
  1350. if isinstance(module, Gemma3nAudioCumulativeGroupNorm):
  1351. init.ones_(module.weight)
  1352. elif isinstance(module, Gemma3nAudioAttention):
  1353. init.zeros_(module.per_dim_scale)
  1354. q_scale = module.head_dim**-0.5
  1355. r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
  1356. init.copy_(module.q_scale, q_scale * r_softplus_0)
  1357. init.constant_(module.softcap, module.attention_logits_soft_cap)
  1358. init.copy_(module.local_causal_valid_mask, module.create_local_causal_valid_mask())
  1359. elif isinstance(module, Gemma3nTextScaledWordEmbedding):
  1360. init.constant_(module.embed_scale, module.scalar_embed_scale)
  1361. elif isinstance(module, Gemma3nTextAltUp):
  1362. init.zeros_(module.correct_output_scale)
  1363. init.constant_(module.router_input_scale, self.config.hidden_size**-1.0)
  1364. elif isinstance(module, Gemma3nAudioRelativePositionEmbedding):
  1365. min_timescale, max_timescale = 1.0, 1.0e4
  1366. num_timescales = module.channels // 2
  1367. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
  1368. num_timescales - 1, 1
  1369. )
  1370. inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
  1371. init.copy_(module.inv_timescales, inv_timescales.float().unsqueeze(0).unsqueeze(0))
  1372. elif isinstance(module, Gemma3nTextModel):
  1373. init.constant_(module.per_layer_projection_scale, self.hidden_size**-0.5)
  1374. init.constant_(module.per_layer_input_scale, 1 / math.sqrt(2.0))
  1375. elif isinstance(module, Gemma3nRotaryEmbedding):
  1376. for layer_type in module.layer_types:
  1377. rope_init_fn = module.compute_default_rope_parameters
  1378. if module.rope_type[layer_type] != "default":
  1379. rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
  1380. curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
  1381. init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
  1382. init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
  1383. if hasattr(module, "gradient_clipping"):
  1384. init.constant_(module.gradient_clipping, self.config.gradient_clipping)
  1385. class Gemma3nAudioEncoder(Gemma3nPreTrainedModel):
  1386. """
  1387. An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.
  1388. """
  1389. config: Gemma3nAudioConfig
  1390. main_input_name = "audio_mel"
  1391. input_modalities = "audio"
  1392. def __init__(self, config: Gemma3nAudioConfig):
  1393. super().__init__(config)
  1394. self.config = config
  1395. self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
  1396. self.conformer = nn.ModuleList(
  1397. [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
  1398. )
  1399. self.post_init()
  1400. @merge_with_config_defaults
  1401. @capture_outputs
  1402. def forward(
  1403. self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs: Unpack[TransformersKwargs]
  1404. ) -> tuple | Gemma3nAudioEncoderModelOutput:
  1405. """Encodes a batch of MELs.
  1406. Args:
  1407. audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
  1408. mel_bins].
  1409. Returns:
  1410. audio_encodings: a torch.Tensor of shape
  1411. `[batch_size, self.config.audio_soft_tokens_per_image,
  1412. self.config.audio_config.hidden_size]`
  1413. audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
  1414. """
  1415. audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
  1416. # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
  1417. t_sub = audio_encodings.shape[1]
  1418. time_stride_product = 1
  1419. for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
  1420. time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
  1421. # Create indices for gathering from the original mask.
  1422. # These indices map to original time steps corresponding to the start of each
  1423. # receptive field in the subsampled output.
  1424. indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
  1425. indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
  1426. # Expand indices for batch compatibility if B > 1 and indices is 1D.
  1427. if audio_mel_mask.ndim > 1 and indices.ndim == 1:
  1428. indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
  1429. elif (
  1430. audio_mel_mask.ndim == indices.ndim
  1431. and audio_mel_mask.shape[0] == 1
  1432. and indices.shape[0] != 1
  1433. and t_sub == indices.shape[0]
  1434. ):
  1435. # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
  1436. indices = indices.unsqueeze(0)
  1437. current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
  1438. for block in self.conformer:
  1439. audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
  1440. if self.config.conf_reduction_factor > 1:
  1441. audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
  1442. # Reduce the mask as well
  1443. current_mask = current_mask[:, :: self.config.conf_reduction_factor]
  1444. audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
  1445. return Gemma3nAudioEncoderModelOutput(
  1446. last_hidden_state=audio_encodings,
  1447. audio_mel_mask=current_mask,
  1448. )
class Gemma3nRotaryEmbedding(Gemma3RotaryEmbedding):
    # Gemma3n-named subclass of `Gemma3RotaryEmbedding`; nothing is added or overridden here.
    pass
  1451. @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
  1452. class Gemma3nTextModel(Gemma3TextModel):
  1453. config: Gemma3nTextConfig
  1454. def __init__(self, config: Gemma3nTextConfig):
  1455. super().__init__(config)
  1456. self.hidden_size = config.hidden_size
  1457. self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
  1458. self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
  1459. config.vocab_size_per_layer_input,
  1460. config.num_hidden_layers * config.hidden_size_per_layer_input,
  1461. self.padding_idx,
  1462. embed_scale=config.hidden_size_per_layer_input**0.5,
  1463. )
  1464. self.per_layer_model_projection = nn.Linear(
  1465. self.hidden_size,
  1466. config.num_hidden_layers * config.hidden_size_per_layer_input,
  1467. bias=False,
  1468. )
  1469. self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
  1470. self.layers = nn.ModuleList(
  1471. [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  1472. )
  1473. self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1474. self.altup_projections = nn.ModuleList(
  1475. [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
  1476. )
  1477. self.altup_unembed_projections = nn.ModuleList(
  1478. [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
  1479. )
  1480. self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
  1481. self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
  1482. def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
  1483. return self.embed_tokens_per_layer(input_ids).reshape(
  1484. *input_ids.shape,
  1485. self.config.num_hidden_layers,
  1486. self.hidden_size_per_layer_input,
  1487. )
  1488. def project_per_layer_inputs(
  1489. self,
  1490. inputs_embeds: torch.Tensor,
  1491. per_layer_inputs: torch.Tensor | None = None,
  1492. ) -> torch.Tensor:
  1493. per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
  1494. per_layer_projection *= self.per_layer_projection_scale.to(
  1495. dtype=inputs_embeds.dtype, device=per_layer_projection.device
  1496. )
  1497. per_layer_projection = per_layer_projection.reshape(
  1498. *inputs_embeds.shape[:-1],
  1499. self.config.num_hidden_layers,
  1500. self.hidden_size_per_layer_input,
  1501. )
  1502. per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
  1503. if per_layer_inputs is None:
  1504. return per_layer_projection
  1505. if per_layer_projection.shape != per_layer_inputs.shape:
  1506. # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
  1507. per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
  1508. return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
  1509. dtype=inputs_embeds.dtype, device=per_layer_projection.device
  1510. )
  1511. # Last hidden states should be before reprojecting, to stay consistent with the other layer outputs
  1512. @merge_with_config_defaults
  1513. @capture_outputs(tie_last_hidden_states=False)
  1514. @auto_docstring
  1515. def forward(
  1516. self,
  1517. input_ids: torch.LongTensor | None = None,
  1518. per_layer_inputs: torch.Tensor | None = None,
  1519. attention_mask: torch.Tensor | None = None,
  1520. position_ids: torch.LongTensor | None = None,
  1521. past_key_values: Cache | None = None,
  1522. inputs_embeds: torch.FloatTensor | None = None,
  1523. use_cache: bool | None = None,
  1524. **kwargs: Unpack[TransformersKwargs],
  1525. ) -> BaseModelOutputWithPast:
  1526. r"""
  1527. per_layer_inputs (torch.Tensor, *optional*, defaults to None):
  1528. Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
  1529. """
  1530. if (input_ids is None) ^ (inputs_embeds is not None):
  1531. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1532. if input_ids is not None:
  1533. inputs_embeds = self.embed_tokens(input_ids)
  1534. per_layer_inputs = self.get_per_layer_inputs(input_ids)
  1535. per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
  1536. if use_cache and past_key_values is None:
  1537. past_key_values = DynamicCache(config=self.config)
  1538. if position_ids is None:
  1539. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1540. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  1541. position_ids = position_ids.unsqueeze(0)
  1542. # It may already have been prepared by e.g. `generate`
  1543. if not isinstance(causal_mask_mapping := attention_mask, dict):
  1544. # Prepare mask arguments
  1545. mask_kwargs = {
  1546. "config": self.config,
  1547. "inputs_embeds": inputs_embeds,
  1548. "attention_mask": attention_mask,
  1549. "past_key_values": past_key_values,
  1550. "position_ids": position_ids,
  1551. }
  1552. # Create the masks
  1553. causal_mask_mapping = {
  1554. "full_attention": create_causal_mask(**mask_kwargs),
  1555. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  1556. }
  1557. # embed positions
  1558. hidden_states_0 = inputs_embeds
  1559. # Expand hidden_states to support per-layer inputs
  1560. target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
  1561. epsilon_tensor = torch.tensor(1e-5)
  1562. temp_hidden_states = [hidden_states_0]
  1563. for i in range(1, self.config.altup_num_inputs):
  1564. # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
  1565. altup_proj = self.altup_projections[i - 1](hidden_states_0)
  1566. current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
  1567. new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
  1568. new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
  1569. current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
  1570. temp_hidden_states.append(current_hidden_state)
  1571. hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
  1572. position_embeddings = {}
  1573. for layer_type in self.config.layer_types:
  1574. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  1575. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  1576. causal_mask = causal_mask_mapping[self.config.layer_types[i]]
  1577. per_layer_input = per_layer_inputs[:, :, i, :]
  1578. hidden_states = decoder_layer(
  1579. hidden_states,
  1580. position_embeddings[self.config.layer_types[i]],
  1581. per_layer_input,
  1582. attention_mask=causal_mask,
  1583. position_ids=position_ids,
  1584. past_key_values=past_key_values,
  1585. **kwargs,
  1586. )
  1587. # Per-layer inputs to single output
  1588. target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
  1589. temp_hidden_states = [hidden_states[0]]
  1590. for i in range(1, self.config.altup_num_inputs):
  1591. # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
  1592. altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
  1593. current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
  1594. new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
  1595. new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
  1596. current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
  1597. temp_hidden_states.append(current_hidden_state)
  1598. hidden_states = torch.stack(temp_hidden_states)
  1599. hidden_states = torch.mean(hidden_states, dim=0)
  1600. hidden_states = self.norm(hidden_states)
  1601. return BaseModelOutputWithPast(
  1602. last_hidden_state=hidden_states,
  1603. past_key_values=past_key_values,
  1604. )
# Gemma 3n causal language model. The class body is intentionally empty: it adds no attributes or
# methods of its own and takes everything from `Gemma3ForCausalLM`; only the class name and the
# `custom_intro` text handed to `auto_docstring` differ.
@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
class Gemma3nForCausalLM(Gemma3ForCausalLM):
    pass
  1608. class Gemma3nMultimodalEmbedder(nn.Module):
  1609. """Embeds token ids or soft tokens for multimodal content into language model space."""
  1610. def __init__(
  1611. self,
  1612. multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
  1613. text_config: Gemma3nTextConfig,
  1614. ):
  1615. super().__init__()
  1616. self.multimodal_hidden_size = multimodal_config.hidden_size
  1617. self.eps = multimodal_config.rms_norm_eps
  1618. self.vocab_offset = multimodal_config.vocab_offset
  1619. self.vocab_size = multimodal_config.vocab_size
  1620. self.text_hidden_size = text_config.hidden_size
  1621. self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
  1622. self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
  1623. self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
  1624. self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
  1625. self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)
  1626. def forward(
  1627. self,
  1628. input_ids: torch.LongTensor | None = None,
  1629. inputs_embeds: torch.Tensor | None = None,
  1630. ) -> torch.Tensor:
  1631. """Embeds token ids or soft tokens for multimodal content into language model space.
  1632. Args:
  1633. input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
  1634. `[vocab_offset, vocab_offset + vocab_size)`.
  1635. inputs_embeds: A torch.Tensor containing the soft tokens to embed.
  1636. Returns:
  1637. A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
  1638. """
  1639. if (input_ids is None) ^ (inputs_embeds is not None):
  1640. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1641. if inputs_embeds is not None:
  1642. emb_norm = self.soft_embedding_norm(inputs_embeds)
  1643. else:
  1644. hard_emb = self.embedding(input_ids - self.vocab_offset)
  1645. emb_norm = self.hard_embedding_norm(hard_emb)
  1646. emb_norm_proj = self.embedding_projection(emb_norm)
  1647. return self.embedding_post_projection_norm(emb_norm_proj)
  1648. @auto_docstring(
  1649. custom_intro="""
  1650. The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
  1651. language modeling head.
  1652. """
  1653. )
  1654. class Gemma3nModel(PaliGemmaModel):
  1655. def __init__(self, config: Gemma3nConfig):
  1656. super().__init__(config)
  1657. del self.multi_modal_projector # Replaced by Gemma3nVisionEmbedder
  1658. del self.text_config_dtype
  1659. self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
  1660. self.audio_tower = AutoModel.from_config(config.audio_config)
  1661. self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
  1662. self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
  1663. @can_return_tuple
  1664. @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
  1665. def get_image_features(
  1666. self,
  1667. pixel_values: torch.FloatTensor,
  1668. **kwargs: Unpack[TransformersKwargs],
  1669. ) -> tuple | BaseModelOutputWithPooling:
  1670. vision_outputs = self.vision_tower(pixel_values=pixel_values, do_pooling=False, return_dict=True, **kwargs)
  1671. last_hidden_state = vision_outputs.last_hidden_state
  1672. # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
  1673. # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
  1674. last_hidden_state = last_hidden_state.reshape(
  1675. last_hidden_state.shape[0],
  1676. self.config.vision_config.hidden_size,
  1677. self.config.vision_soft_tokens_per_image,
  1678. ).permute(0, 2, 1)
  1679. # Normalize and embed the soft tokens into language model space.
  1680. last_hidden_state *= self.config.vision_config.hidden_size**0.5
  1681. vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
  1682. return vision_outputs
  1683. def get_placeholder_mask(
  1684. self,
  1685. input_ids: torch.LongTensor | None = None,
  1686. inputs_embeds: torch.FloatTensor | None = None,
  1687. image_features: torch.FloatTensor | None = None,
  1688. audio_features: torch.FloatTensor | None = None,
  1689. ):
  1690. """
  1691. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  1692. equal to the length of multimodal features. If the lengths are different, an error is raised.
  1693. """
  1694. if input_ids is None:
  1695. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  1696. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  1697. )
  1698. special_image_mask = special_image_mask.all(-1)
  1699. special_audio_mask = (
  1700. inputs_embeds
  1701. == self.get_input_embeddings()(
  1702. torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
  1703. )
  1704. ).all(-1)
  1705. else:
  1706. special_image_mask = input_ids == self.config.image_token_id
  1707. special_audio_mask = input_ids == self.config.audio_token_id
  1708. n_image_tokens = special_image_mask.sum()
  1709. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1710. if image_features is not None:
  1711. torch_compilable_check(
  1712. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  1713. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0] * image_features.shape[1]}",
  1714. )
  1715. n_audio_tokens = special_audio_mask.sum()
  1716. special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1717. if audio_features is not None:
  1718. torch_compilable_check(
  1719. inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
  1720. f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {audio_features.shape[0] * audio_features.shape[1]}",
  1721. )
  1722. return special_image_mask, special_audio_mask
  1723. @can_return_tuple
  1724. def forward(
  1725. self,
  1726. input_ids: torch.LongTensor | None = None, # text inputs
  1727. pixel_values: torch.FloatTensor | None = None, # vision inputs
  1728. input_features: torch.FloatTensor | None = None, # audio inputs
  1729. attention_mask: torch.Tensor | None = None,
  1730. input_features_mask: torch.Tensor | None = None,
  1731. position_ids: torch.LongTensor | None = None,
  1732. past_key_values: Cache | None = None,
  1733. token_type_ids: torch.LongTensor | None = None,
  1734. inputs_embeds: torch.FloatTensor | None = None,
  1735. labels: torch.LongTensor | None = None,
  1736. use_cache: bool | None = None,
  1737. **lm_kwargs: Unpack[TransformersKwargs],
  1738. ) -> Gemma3nModelOutputWithPast:
  1739. r"""
  1740. input_features_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1741. Attention mask for `input_features` where non-zero values mark valid audio frames.
  1742. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1743. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1744. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1745. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  1746. Example:
  1747. ```python
  1748. >>> from PIL import Image
  1749. >>> import httpx
  1750. >>> from io import BytesIO
  1751. >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
  1752. >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
  1753. >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
  1754. >>> prompt = "Where is the cat standing?"
  1755. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
  1756. >>> with httpx.stream("GET", url) as response:
  1757. ... image = Image.open(BytesIO(response.read()))
  1758. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  1759. >>> # Generate
  1760. >>> generate_ids = model.generate(**inputs,)
  1761. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1762. "Where is the cat standing?\nsnow"
  1763. ```
  1764. """
  1765. if (input_ids is None) ^ (inputs_embeds is not None):
  1766. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1767. if input_ids is not None:
  1768. inputs_embeds = self.get_input_embeddings()(input_ids)
  1769. # Prepare per-layer inputs from inputs_ids
  1770. per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
  1771. per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
  1772. per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
  1773. # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
  1774. vision_mask = torch.logical_and(
  1775. input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
  1776. )
  1777. dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
  1778. vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
  1779. vision_embeds = self.embed_vision(input_ids=vision_input_ids)
  1780. vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
  1781. expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
  1782. inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
  1783. # Handle audio tokens (>= embed_audio.vocab_offset)
  1784. audio_mask = input_ids >= self.embed_audio.vocab_offset
  1785. dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
  1786. audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
  1787. audio_embeds = self.embed_audio(input_ids=audio_input_ids)
  1788. audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
  1789. expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
  1790. inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
  1791. else:
  1792. per_layer_inputs = None
  1793. # Merge text and images
  1794. if pixel_values is not None:
  1795. image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
  1796. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1797. special_image_mask, _ = self.get_placeholder_mask(
  1798. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  1799. )
  1800. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  1801. # Merge text and audio
  1802. if input_features is not None and input_features_mask is not None:
  1803. audio_outputs = self.get_audio_features(input_features, ~input_features_mask, return_dict=True)
  1804. audio_features = audio_outputs.pooler_output
  1805. audio_mask = audio_outputs.audio_mel_mask
  1806. # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
  1807. # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
  1808. # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
  1809. # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
  1810. # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
  1811. audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
  1812. audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
  1813. audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
  1814. audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
  1815. extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
  1816. extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
  1817. audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
  1818. audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1819. _, special_audio_mask = self.get_placeholder_mask(
  1820. input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
  1821. )
  1822. inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
  1823. outputs = self.language_model(
  1824. input_ids=None,
  1825. per_layer_inputs=per_layer_inputs,
  1826. attention_mask=attention_mask,
  1827. position_ids=position_ids,
  1828. past_key_values=past_key_values,
  1829. inputs_embeds=inputs_embeds,
  1830. use_cache=use_cache,
  1831. return_dict=True,
  1832. **lm_kwargs,
  1833. )
  1834. return Gemma3nModelOutputWithPast(
  1835. last_hidden_state=outputs.last_hidden_state,
  1836. past_key_values=outputs.past_key_values if use_cache else None,
  1837. hidden_states=outputs.hidden_states,
  1838. attentions=outputs.attentions,
  1839. image_hidden_states=image_features if pixel_values is not None else None,
  1840. audio_hidden_states=audio_features if input_features is not None else None,
  1841. )
  1842. @can_return_tuple
  1843. @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
  1844. def get_audio_features(
  1845. self,
  1846. input_features: torch.Tensor,
  1847. input_features_mask: torch.Tensor,
  1848. **kwargs: Unpack[TransformersKwargs],
  1849. ) -> tuple | Gemma3nAudioEncoderModelOutput:
  1850. r"""
  1851. input_features (`torch.FloatTensor]` of shape `(num_images, seq_length, num_features)`):
  1852. The tensors corresponding to the input audio.
  1853. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  1854. The attention mask for the input audio.
  1855. """
  1856. audio_outputs: Gemma3nAudioEncoderModelOutput = self.audio_tower(
  1857. input_features, input_features_mask, return_dict=True, **kwargs
  1858. )
  1859. audio_embeds = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
  1860. audio_outputs.pooler_output = audio_embeds
  1861. return audio_outputs
  1862. @auto_docstring(
  1863. custom_intro="""
  1864. The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
  1865. head.
  1866. """
  1867. )
  1868. class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
  1869. @can_return_tuple
  1870. @auto_docstring
  1871. def forward(
  1872. self,
  1873. input_ids: torch.LongTensor | None = None, # text inputs
  1874. pixel_values: torch.FloatTensor | None = None, # vision inputs
  1875. input_features: torch.FloatTensor | None = None, # audio inputs
  1876. attention_mask: torch.Tensor | None = None,
  1877. input_features_mask: torch.Tensor | None = None,
  1878. position_ids: torch.LongTensor | None = None,
  1879. past_key_values: Cache | None = None,
  1880. token_type_ids: torch.LongTensor | None = None,
  1881. inputs_embeds: torch.FloatTensor | None = None,
  1882. labels: torch.LongTensor | None = None,
  1883. use_cache: bool | None = None,
  1884. logits_to_keep: int | torch.Tensor = 0,
  1885. **lm_kwargs: Unpack[TransformersKwargs],
  1886. ) -> Gemma3nCausalLMOutputWithPast:
  1887. r"""
  1888. input_features_mask (torch.Tensor, *optional*, defaults to None):
  1889. The attention mask for the input audio.
  1890. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1891. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1892. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
  1893. ignored (masked), the loss is only computed for the tokens with labels in
  1894. `[0, ..., config.text_config.vocab_size]`.
  1895. Example:
  1896. ```python
  1897. >>> from PIL import Image
  1898. >>> import httpx
  1899. >>> from io import BytesIO
  1900. >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  1901. >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
  1902. >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
  1903. >>> messages = [
  1904. ... {
  1905. ... "role": "system",
  1906. ... "content": [
  1907. ... {"type": "text", "text": "You are a helpful assistant."}
  1908. ... ]
  1909. ... },
  1910. ... {
  1911. ... "role": "user", "content": [
  1912. ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
  1913. ... {"type": "text", "text": "Where is the cat standing?"},
  1914. ... ]
  1915. ... },
  1916. ... ]
  1917. >>> inputs = processor.apply_chat_template(
  1918. ... messages,
  1919. ... tokenizer=True,
  1920. ... return_dict=True,
  1921. ... return_tensors="pt",
  1922. ... add_generation_prompt=True
  1923. ... )
  1924. >>> # Generate
  1925. >>> generate_ids = model.generate(**inputs)
  1926. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1927. "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
  1928. ```
  1929. """
  1930. outputs = self.model(
  1931. input_ids=input_ids,
  1932. pixel_values=pixel_values,
  1933. input_features=input_features,
  1934. attention_mask=attention_mask,
  1935. input_features_mask=input_features_mask,
  1936. position_ids=position_ids,
  1937. past_key_values=past_key_values,
  1938. token_type_ids=token_type_ids,
  1939. inputs_embeds=inputs_embeds,
  1940. labels=labels,
  1941. use_cache=use_cache,
  1942. return_dict=True,
  1943. **lm_kwargs,
  1944. )
  1945. hidden_states = outputs.last_hidden_state
  1946. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1947. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1948. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1949. if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
  1950. logits = logits / final_logit_softcapping
  1951. logits = torch.tanh(logits)
  1952. logits = logits * final_logit_softcapping
  1953. loss = None
  1954. if labels is not None:
  1955. # Upcast to float if we need to compute the loss to avoid potential precision issues
  1956. logits = logits.float()
  1957. shift_logits = logits[..., :-1, :]
  1958. shift_labels = labels[..., 1:]
  1959. if attention_mask is not None:
  1960. # we use the input attention mask to shift the logits and labels, because it is 2D.
  1961. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  1962. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  1963. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  1964. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  1965. else:
  1966. shift_logits = shift_logits.contiguous()
  1967. shift_labels = shift_labels.contiguous()
  1968. # Flatten the tokens
  1969. loss_fct = nn.CrossEntropyLoss()
  1970. flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
  1971. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  1972. loss = loss_fct(flat_logits, flat_labels)
  1973. return Gemma3nCausalLMOutputWithPast(
  1974. loss=loss,
  1975. logits=logits,
  1976. past_key_values=outputs.past_key_values,
  1977. hidden_states=outputs.hidden_states,
  1978. attentions=outputs.attentions,
  1979. image_hidden_states=outputs.image_hidden_states,
  1980. audio_hidden_states=outputs.audio_hidden_states,
  1981. )
  1982. def prepare_inputs_for_generation(
  1983. self,
  1984. input_ids,
  1985. past_key_values=None,
  1986. inputs_embeds=None,
  1987. position_ids=None,
  1988. pixel_values=None,
  1989. input_features=None,
  1990. attention_mask=None,
  1991. input_features_mask=None,
  1992. token_type_ids=None,
  1993. use_cache=True,
  1994. logits_to_keep=None,
  1995. labels=None,
  1996. is_first_iteration=False,
  1997. **kwargs,
  1998. ):
  1999. # Overwritten -- custom `position_ids` and `pixel_values` handling
  2000. model_inputs = super().prepare_inputs_for_generation(
  2001. input_ids,
  2002. past_key_values=past_key_values,
  2003. inputs_embeds=inputs_embeds,
  2004. attention_mask=attention_mask,
  2005. position_ids=position_ids,
  2006. use_cache=use_cache,
  2007. logits_to_keep=logits_to_keep,
  2008. token_type_ids=token_type_ids,
  2009. is_first_iteration=is_first_iteration,
  2010. **kwargs,
  2011. )
  2012. # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
  2013. # tokens anymore. Otherwise multimodal inputs should be passed to model.
  2014. # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
  2015. if is_first_iteration or not use_cache:
  2016. model_inputs["pixel_values"] = pixel_values
  2017. model_inputs["input_features"] = input_features
  2018. model_inputs["input_features_mask"] = input_features_mask
  2019. return model_inputs
  2020. def create_masks_for_generate(self, **super_kwargs):
  2021. raise AttributeError("Do not inherit create_masks_for_generate from PaliGemma")
# Public names exported by this module, listed in alphabetical order.
__all__ = [
    "Gemma3nAudioConfig",
    "Gemma3nAudioEncoder",
    "Gemma3nConfig",
    "Gemma3nForCausalLM",
    "Gemma3nForConditionalGeneration",
    "Gemma3nModel",
    "Gemma3nPreTrainedModel",
    "Gemma3nTextConfig",
    "Gemma3nTextModel",
    "Gemma3nVisionConfig",
]