modeling_gemma3n.py 110 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma3n.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from collections.abc import Callable, Sequence
  23. from dataclasses import dataclass
  24. from typing import Optional
  25. import torch
  26. import torch.nn as nn
  27. import torch.nn.functional as F
  28. from ... import initialization as init
  29. from ...activations import ACT2FN
  30. from ...cache_utils import Cache, DynamicCache
  31. from ...generation import GenerationMixin
  32. from ...integrations import use_kernelized_func
  33. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  34. from ...modeling_layers import GradientCheckpointingLayer
  35. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import (
  40. ModelOutput,
  41. TransformersKwargs,
  42. auto_docstring,
  43. can_return_tuple,
  44. torch_compilable_check,
  45. )
  46. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  47. from ...utils.output_capturing import capture_outputs
  48. from ..auto import AutoModel
  49. from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig
  50. @dataclass
  51. @auto_docstring
  52. class Gemma3nAudioEncoderModelOutput(BaseModelOutputWithPooling):
  53. r"""
  54. audio_mel_mask (`torch.BoolTensor`, *optional*):
  55. A torch.BoolTensor of shape `(batch_size, num_frames)`
  56. """
  57. audio_mel_mask: torch.BoolTensor | None = None
  58. @dataclass
  59. @auto_docstring(
  60. custom_intro="""
  61. Base class for Gemma3n outputs, with hidden states and attentions.
  62. """
  63. )
  64. class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
  65. r"""
  66. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  67. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  68. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  69. `past_key_values` input) to speed up sequential decoding.
  70. image_hidden_states (`torch.FloatTensor`, *optional*):
  71. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  72. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  73. audio_hidden_states (`torch.FloatTensor`, *optional*):
  74. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  75. audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
  76. """
  77. image_hidden_states: torch.FloatTensor | None = None
  78. audio_hidden_states: torch.FloatTensor | None = None
  79. @dataclass
  80. @auto_docstring(
  81. custom_intro="""
  82. Base class for Gemma3n causal language model (or autoregressive) outputs.
  83. """
  84. )
  85. class Gemma3nCausalLMOutputWithPast(ModelOutput):
  86. r"""
  87. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  88. Language modeling loss (for next-token prediction).
  89. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
  90. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  91. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  92. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  93. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  94. `past_key_values` input) to speed up sequential decoding.
  95. image_hidden_states (`torch.FloatTensor`, *optional*):
  96. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  97. image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
  98. audio_hidden_states (`torch.FloatTensor`, *optional*):
  99. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  100. audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
  101. """
  102. loss: torch.FloatTensor | None = None
  103. logits: torch.FloatTensor | None = None
  104. past_key_values: Cache | None = None
  105. hidden_states: tuple[torch.FloatTensor] | None = None
  106. attentions: tuple[torch.FloatTensor] | None = None
  107. image_hidden_states: torch.FloatTensor | None = None
  108. audio_hidden_states: torch.FloatTensor | None = None
  109. class Gemma3nRMSNorm(nn.Module):
  110. def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
  111. super().__init__()
  112. self.eps = eps
  113. self.with_scale = with_scale
  114. if self.with_scale:
  115. self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)
  116. def _norm(self, hidden_states: torch.Tensor):
  117. mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps
  118. # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX
  119. return hidden_states * torch.pow(mean_squared, -0.5)
  120. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  121. normed_output = self._norm(hidden_states.float())
  122. if self.with_scale:
  123. normed_output = normed_output * self.weight.float()
  124. return normed_output.type_as(hidden_states)
  125. # ==== Audio Encoder ====
  126. class Gemma3nAudioRelativePositionEmbedding(nn.Module):
  127. def __init__(self, config: Gemma3nAudioConfig):
  128. super().__init__()
  129. self.config = config
  130. self.num_heads = self.config.conf_num_attention_heads
  131. self.channels = self.config.hidden_size
  132. self.head_dim = self.channels // self.num_heads
  133. self.max_backward = max(0, self.config.conf_attention_context_left - 1)
  134. self.max_forward = self.config.conf_attention_context_right
  135. self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
  136. min_timescale = 1.0
  137. max_timescale = 1.0e4
  138. num_timescales = self.channels // 2
  139. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
  140. inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
  141. self.register_buffer(
  142. "inv_timescales",
  143. inv_timescales.float().unsqueeze(0).unsqueeze(0),
  144. persistent=False,
  145. )
  146. def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  147. position = position.float().unsqueeze(-1)
  148. scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
  149. timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
  150. return timing_signal.type(dtype)
  151. def _relative_shift(
  152. self,
  153. term_bd_before_shift: torch.Tensor,
  154. batch_size: int,
  155. num_heads: int,
  156. num_query_blocks: int,
  157. query_block_size: int,
  158. key_context_size: int,
  159. max_span_plus_1: int,
  160. ) -> torch.Tensor:
  161. """Performs the relative shift.
  162. Args:
  163. term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
  164. (B), num_heads (N), num_query_blocks (U), query_block_size (W),
  165. key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
  166. Returns:
  167. Tensor of shape [B, N, U, W, C].
  168. """
  169. # term_bd_before_shift shape: [B, N, U, W, F_span]
  170. # Target shape after shift: [B, N, U, W, C]
  171. # Padding amount for the last dimension (F_span) to become (C + 1)
  172. # C = key_context_size
  173. # F_span = max_span_plus_1
  174. pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
  175. # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
  176. # We only pad the last dimension on the right.
  177. padding_tuple = (0, pad_amount_last_dim)
  178. term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
  179. # Shape after pad: [B, N, U, W, C+1]
  180. # Reshape for slicing (emulating JAX's behavior)
  181. # [B, N, U, W * (C+1)]
  182. term_bd_reshaped = term_bd_padded.reshape(
  183. (
  184. batch_size,
  185. num_heads,
  186. num_query_blocks,
  187. query_block_size * (key_context_size + 1),
  188. )
  189. )
  190. # Slice to effective [B, N, U, W * C]
  191. term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
  192. # Reshape back to [B, N, U, W, C]
  193. term_bd_shifted = term_bd_sliced.reshape(
  194. (
  195. batch_size,
  196. num_heads,
  197. num_query_blocks,
  198. query_block_size,
  199. key_context_size,
  200. )
  201. )
  202. return term_bd_shifted
  203. def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
  204. # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
  205. # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
  206. # C = W + L + R (key_context_size)
  207. # F_span = L + R + 1 (max_span + 1)
  208. batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
  209. _, _, key_context_size, _, _ = keys.shape
  210. # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
  211. # Length is L+R+1 = self.max_span + 1
  212. pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
  213. 0
  214. ) # Shape [1, F_span]
  215. max_span_plus_1 = pos_indices.shape[1] # F_span
  216. sin_emb_timing_signal = self._get_timing_signal_1d_pos(
  217. pos_indices, dtype=queries.dtype
  218. ) # Shape [1, F_span, self.channels]
  219. # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
  220. projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
  221. # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
  222. sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
  223. 0
  224. ) # Shape [F, N, H]
  225. # term_ac: Query-Key content interaction
  226. # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
  227. # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
  228. queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
  229. keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
  230. term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
  231. # term_bd: Query-Position interaction
  232. # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
  233. # queries shape: [B, U, W, N, H]
  234. # sin_emb shape: [F, N, H]
  235. # Target output shape: [B, N, U, W, F]
  236. # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
  237. q_permuted = queries.permute(0, 3, 1, 2, 4)
  238. # Permute sin_emb to [N, H, F] to prepare for matmul
  239. # sin_emb original is [F, N, H]
  240. s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
  241. # Reshape queries for matmul: [B, N, U*W, H]
  242. q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
  243. # Perform matmul: [B, N, U*W, H] @ [N, H, F]
  244. # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
  245. # Result: [B, N, U*W, F]
  246. term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
  247. # Reshape to target [B, N, U, W, F]
  248. term_bd_unshifed = term_bd_unshifed_matmul.reshape(
  249. batch_size,
  250. num_heads,
  251. num_query_blocks,
  252. query_block_size,
  253. max_span_plus_1,
  254. )
  255. # Apply relative shift to term_bd_unshifed
  256. term_bd_shifted = self._relative_shift(
  257. term_bd_unshifed,
  258. batch_size,
  259. num_heads,
  260. num_query_blocks,
  261. query_block_size,
  262. key_context_size,
  263. max_span_plus_1,
  264. ) # Shape [B, N, U, W, C]
  265. return term_ac + term_bd_shifted
class Gemma3nAudioAttention(nn.Module):
    """Chunked local self-attention with relative-position logits, for the Gemma3n audio encoder.

    Queries are split into non-overlapping chunks of `chunk_size` (W) frames. Each chunk attends
    to a key/value window of `context_size` = W + L + R frames, where L = `max_past_horizon` and
    R = `max_future_horizon`. Logits come from `Gemma3nAudioRelativePositionEmbedding`, are
    soft-capped with tanh, and are masked by both input validity and the local causal window
    before the softmax.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        self.num_heads = self.config.conf_num_attention_heads
        self.hidden_size = self.config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.chunk_size = self.config.conf_attention_chunk_size
        self.max_future_horizon = self.config.conf_attention_context_right
        # NOTE(review): the `- 1` presumably means conf_attention_context_left counts the current
        # frame; confirm against the config documentation.
        self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
        # C = W + L + R: number of key/value frames visible to each query chunk.
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
        self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
        # Learnable per-dimension query scale, passed through softplus in forward().
        self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        q_scale = self.head_dim**-0.5
        # Dividing by softplus(0) cancels the softplus(per_dim_scale) factor applied in forward()
        # while per_dim_scale is at its zero initialisation, so the initial query scale is head_dim**-0.5.
        r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
        self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
        local_causal_valid_mask = self.create_local_causal_valid_mask()
        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
        self.register_buffer(
            "softcap",
            torch.tensor(self.attention_logits_soft_cap).float(),
            persistent=False,
        )

    def create_local_causal_valid_mask(self):
        """Return a [W, C] bool mask that is True at (w, c) iff w <= c <= w + L + R.

        Key-context index c lies L frames before the chunk start, so this admits keys whose
        offset from query w is between -L and +R frames.
        """
        # [C, W] lower-triangular, transposed to [W, C]: True where w <= c.
        lower_causal_mask = torch.tril(
            torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
            diagonal=0,
        ).T
        # [W, C]: True where c <= w + L + R.
        upper_causal_mask = torch.tril(
            torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
            diagonal=self.max_past_horizon + self.max_future_horizon,
        )
        local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
        local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
        return local_causal_valid_mask

    def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
        """Zero-pad dimension 1 of `x` with `pad_left` entries before and `pad_right` entries after."""
        batch, _, *tail_shape = x.shape
        left = x.new_zeros((batch, pad_left, *tail_shape))
        right = x.new_zeros((batch, pad_right, *tail_shape))
        x = torch.cat([left, x, right], dim=1)
        return x

    def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Turns a sequence to non overlapping blocks.
        Args:
            hidden_states: a tensor of [batch, time, ...].
        Returns:
            A tensor of [batch, num_blocks, block_size, ...], where block_size = chunk_size and the
            time axis is right-padded with zeros up to num_blocks * block_size,
            where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
        """
        shape = hidden_states.shape
        b, t = shape[:2]
        # ceil(t / chunk_size)
        num_blocks = (t + self.chunk_size - 1) // self.chunk_size
        if (padding_len := num_blocks * self.chunk_size - t) > 0:
            hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
        permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
        hidden_states = hidden_states.reshape(permute_dims).contiguous()
        return hidden_states

    def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Extracts temporal context for every block.
        Args:
            hidden_states: a tensor of [batch, time, ...].
        Returns:
            A tensor of [batch, num_blocks, context_size, ...], with necessary
            zero paddings,
            where context_size = block_size + left_context + right_context,
            and output[:, i, ...] are x[:, start-left_context:end+right_context,
            ...],
            start = i * block_size, end = (i + 1) * block_size.
        """
        pad_left = self.max_past_horizon
        # The JAX equivalent padding for signal.frame with pad_mode='valid' is
        # (left_context, right_context + block_size - 1) on the time dimension.
        # _pad_dim1(x, pad_left, pad_right) always takes both amounts and zero-pads dim 1,
        # which is the time axis for both [B, T] and [B, T, N, H] inputs.
        # With this pad_right, unfold yields ceil(T / chunk_size) windows, matching the number of
        # query blocks produced by _convert_to_block.
        pad_right = self.max_future_horizon + self.chunk_size - 1
        hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)
        frame_len = self.context_size
        frame_step = self.chunk_size
        # Sliding windows along time: x.unfold(dimension, size, step) with
        # dimension=1 (time axis of the padded input),
        # size=frame_len (context_size), and
        # step=frame_step (chunk_size).
        x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
        # unfold appends the window axis last:
        # for a [B, T_padded] input, x_unfolded is [B, num_blocks, frame_len];
        # for a [B, T_padded, N, H] input, x_unfolded is [B, num_blocks, N, H, frame_len].
        # The relative_position_embedding expects keys as [B, U, C, N, H], so in the
        # higher-rank case the window axis C is moved to position 2.
        if hidden_states.ndim > 2 and x_unfolded.ndim > 3:  # Check if inner dimensions (like N, H) exist
            # [B, U, N, H, C] -> [B, U, C, N, H]
            x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
        return x_unfolded.contiguous()

    def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        """Run chunked local attention.

        Args:
            hidden_states: presumably [batch, time, hidden_size]. The per-dim scale is reshaped
                to (1, 1, 1, head_dim), which suggests 4-D projected queries — TODO confirm.
            mask: bool mask that is True at padded positions. It is inverted below, and its block
                context must come out as [batch, num_blocks, context_size], otherwise a ValueError
                is raised — i.e. effectively [batch, time].
        Returns:
            Context vectors of shape [batch, time, num_heads, head_dim].
        """
        # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
        qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
        key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
        value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()
        per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
        broadcast_shape = (1, 1, 1, self.head_dim)
        per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
        query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
        batch_size, q_time = query_states.shape[:2]
        # Queries: [B, U, W, N, H]; keys/values: [B, U, C, N, H].
        query_blocks = self._convert_to_block(query_states)
        key_blocks = self._extract_block_context(key_states)
        value_blocks = self._extract_block_context(value_states)
        num_query_blocks = query_blocks.shape[1]
        # 1. Create a mask indicating originally valid positions.
        original_valid_mask = ~mask  # True for valid, False for padded
        # 2. Extract blocks from this validity mask.
        # Padding frames added by _extract_block_context are zeros, i.e. False (invalid).
        extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
        # Defensive reshape: a rank-4 result whose last two dims multiply to C is flattened to
        # [B, U, C]. A [B, T] mask already comes out as [B, U, C], so this branch does not fire for it.
        # batch_size and num_query_blocks are known from query_blocks.
        # self.context_size is C.
        if (
            extracted_valid_mask_blocks.ndim == 4
            and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
        ):
            extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
                batch_size, num_query_blocks, self.context_size
            )
        # The validity mask must now be [B, U, C]; anything else means the input mask had an
        # unsupported shape.
        if extracted_valid_mask_blocks.shape != (
            batch_size,
            num_query_blocks,
            self.context_size,
        ):
            raise ValueError(
                "Shape of extracted_valid_mask_blocks"
                f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
                f" {num_query_blocks}, {self.context_size}) after potential reshape."
            )
        # 3. Expand dimensions for broadcasting with logits and causal mask.
        # Target shape for broadcasting with logits [B,N,U,W,C]
        # extracted_valid_mask_blocks to [B, 1, U, 1, C]
        condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)
        # self.local_causal_valid_mask is [W, C], True where allowed by local window.
        # Expand to [1, 1, 1, W, C]
        condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        # 4. Combine the two conditions.
        # final_condition will be True where a key is *both* originally valid *and* causally accessible.
        # Broadcasts to [B, 1, U, W, C]
        final_condition_for_where = torch.logical_and(
            condition_from_input_validity,
            condition_from_causality.to(condition_from_input_validity.device),  # Ensure same device
        )
        # Embed queries and keys: content + relative-position logits, [B, N, U, W, C].
        logits = self.relative_position_embedding(query_blocks, key_blocks)
        # Apply attention logit softcap: softcap * tanh(logits / softcap).
        # Ensure softcap is on the same device as logits
        softcap_val = self.softcap.to(logits.device)
        logits = logits / softcap_val
        logits = torch.tanh(logits)
        logits = logits * softcap_val
        # Apply the combined mask.
        # final_condition_for_where will broadcast with logits [B,N,U,W,C]
        logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
        # Softmax is computed in float32, then cast back to the value dtype.
        probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
        # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
        h_dim = value_blocks.shape[-1]
        # probabilities [B, N, U, W, C] -> [B, U, N, W, C] -> [B*U*N, W, C]
        prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
        # values [B, U, C, N, H] -> [B, U, N, C, H] -> [B*U*N, C, H]
        v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
        result_bmm = torch.bmm(prob_bun, v_bun)
        # [B*U*N, W, H] -> [B, U, N, W, H] -> [B, U, W, N, H]
        context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
        # Merge blocks back into a time axis: [B, U*W, N, H]
        context_vectors = context_vectors.reshape(
            (
                batch_size,
                num_query_blocks * self.chunk_size,
                self.num_heads,
                self.head_dim,
            )
        )
        # Drop the right padding added by _convert_to_block.
        context_vectors = context_vectors[:, :q_time]
        return context_vectors
  455. class Gemma3nAudioCumulativeGroupNorm(nn.Module):
  456. """Applies Group Normalization cumulatively over the time dimension.
  457. This layer normalizes the input by calculating the mean and variance
  458. cumulatively over the time dimension (dim 1). The statistics are computed
  459. over all feature dimensions (specified by `feature_dims` and `num_channels`)
  460. for elements marked as valid by the optional `mask`.
  461. If a `mask` is provided (True for valid, False for invalid/padded),
  462. invalid time steps do not contribute to the statistics calculation, and
  463. their corresponding output values are zeroed out.
  464. Scale and bias, if enabled, are applied per-channel (last dimension).
  465. This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
  466. and `cumulative=True`.
  467. """
  468. def __init__(
  469. self,
  470. num_channels: int, # Number of channels (size of the last dimension)
  471. feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
  472. eps: float = 1e-3,
  473. ):
  474. super().__init__()
  475. self.num_channels = num_channels
  476. self.feature_dims = tuple(feature_dims)
  477. self.eps = eps
  478. # Scale parameter depends only on the channel dimension
  479. self.weight = nn.Parameter(torch.ones(num_channels))
  480. # Axes for normalization: all dimensions except Batch (0) and Time (1).
  481. # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
  482. self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
  483. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  484. """Applies cumulative group norm, optionally using a mask.
  485. Args:
  486. hidden_states: Input tensor, shape [B, T, *feature_dims, C].
  487. Returns:
  488. Normalized tensor with the same shape as x.
  489. """
  490. expected_input_suffix = self.feature_dims + (self.num_channels,)
  491. if hidden_states.shape[2:] != expected_input_suffix:
  492. raise ValueError(
  493. f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
  494. f" suffix (feature_dims + num_channels) {expected_input_suffix}"
  495. )
  496. input_dtype = hidden_states.dtype
  497. # Calculations are performed in float32 for numerical stability.
  498. calc_dtype = torch.float32
  499. x_calc = hidden_states.to(calc_dtype)
  500. # Prepare a broadcastable mask (`mask_calc`).
  501. # If no mask is provided, treat all elements as valid
  502. # (mask_calc is all ones).
  503. # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
  504. mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
  505. # Cumulative Statistics Calculation
  506. # 1. Sum of values over reduction axes at each time step.
  507. sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
  508. # 2. Cumulative sum of values over time.
  509. cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
  510. # 3. Count of valid elements in the normalization group at each time step.
  511. # (A "group" here consists of all features at a given Batch, Time).
  512. elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
  513. # 4. Cumulative count of valid elements over time.
  514. cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
  515. # Avoid division by zero if all preceding elements were masked.
  516. safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
  517. # 5. Cumulative mean.
  518. cum_mean = cum_sum_values / safe_cum_count_elements
  519. # 6. Sum of squared differences from the cumulative mean.
  520. # Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
  521. # Using x_calc here for the difference, as cum_mean already accounts for masking.
  522. squared_diff_from_mean = (x_calc - cum_mean).pow(2)
  523. sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
  524. # 7. Cumulative sum of squared differences over time.
  525. cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
  526. # 8. Cumulative variance.
  527. cum_variance = cum_sum_sq_diff / safe_cum_count_elements
  528. # Normalize the input using the calculated cumulative statistics:
  529. # (x - E[x]) / sqrt(Var[x] + eps)
  530. normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
  531. # Apply affine transformation (scale and bias) if enabled.
  532. # Scale and bias are applied per-channel (last dimension).
  533. scale = self.weight.to(calc_dtype)
  534. # Reshape for broadcasting: [C] -> [1, ..., 1, C]
  535. scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
  536. normalized_x = normalized_x * scale.view(scale_view_shape)
  537. # Zero out outputs for time steps that were originally masked (where mask_calc is 0).
  538. # This ensures padded/invalid positions in the input result in zero output.
  539. final_output = normalized_x * mask_calc
  540. return final_output.to(input_dtype)
  541. class Gemma3nAudioSSCPConvBlock(nn.Module):
  542. """A single convolution block for the SubSampleConvProjection.
  543. This block consists of a 2D convolution, followed by CumulativeGroupNorm,
  544. and a ReLU activation. It handles manual padding for the convolution.
  545. """
  546. def __init__(
  547. self,
  548. config: Gemma3nAudioConfig,
  549. idx: int,
  550. input_freq_dim: int, # Changed from input_spatial_dim
  551. manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
  552. ):
  553. super().__init__()
  554. self.config = config
  555. self.manual_padding = manual_padding
  556. # in_channels is 1 for the first block, or C_out from previous block's conv
  557. in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
  558. out_channels = self.config.sscp_conv_channel_size[idx]
  559. kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
  560. stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
  561. self.conv = nn.Conv2d(
  562. in_channels=in_channels,
  563. out_channels=out_channels,
  564. kernel_size=(
  565. kernel_h,
  566. kernel_w,
  567. ), # Kernel (kH, kW) operates on (Time, Freq_dim)
  568. stride=(stride_h, stride_w),
  569. padding=(0, 0), # Manual padding is used
  570. bias=False,
  571. )
  572. # Calculate output frequency dimension (f_out_conv) after this convolution.
  573. # input_freq_dim is the unpadded width (feature dimension).
  574. # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
  575. f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
  576. f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
  577. self.norm = Gemma3nAudioCumulativeGroupNorm(
  578. num_channels=out_channels, # Channels of the conv output
  579. feature_dims=(f_out_conv,), # The frequency dimension size after conv
  580. eps=self.config.sscp_conv_group_norm_eps,
  581. )
  582. self.activation = nn.ReLU()
  583. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  584. # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
  585. # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
  586. # F.pad applies to last two dims: F_in then T_in
  587. audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0).to(
  588. self.conv.weight.dtype
  589. )
  590. # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
  591. # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
  592. audio_encodings_conv = self.conv(audio_encodings_padded)
  593. # Expected conv output shape: [B, C_out, T_out, F_out]
  594. # Input to norm is [B, T_out, F_out, C_out]
  595. x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
  596. x_normed = self.norm(x_for_norm)
  597. # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out]
  598. audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
  599. return self.activation(audio_encodings_normed)
  600. class Gemma3nAudioSubSampleConvProjection(nn.Module):
  601. def __init__(self, config: Gemma3nAudioConfig):
  602. super().__init__()
  603. self.config = config
  604. current_f_for_block_input = config.input_feat_size # Start with original feature dim
  605. calculated_block_padding = []
  606. calculated_f_out_dims = [] # Tracking frequency dimension output sizes
  607. for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays
  608. kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
  609. stride_h, stride_w = config.sscp_conv_stride_size[i]
  610. # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
  611. # JAX 'reverse_causal' padding is (0, kernel_size - 1)
  612. pad_t_top = 0
  613. pad_t_bottom = kernel_h - 1
  614. # Frequency Padding (Width for Conv2d)
  615. # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
  616. # and the successful test configuration.
  617. # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
  618. # to match generic JAX 'SAME' behavior if it differs.
  619. pad_f_left = 1
  620. pad_f_right = 1
  621. manual_padding_tuple = (
  622. pad_f_left,
  623. pad_f_right,
  624. pad_t_top,
  625. pad_t_bottom,
  626. )
  627. calculated_block_padding.append(manual_padding_tuple)
  628. # Calculate output frequency dimension after this convolution
  629. # This uses the actual padding applied and kernel/stride.
  630. f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
  631. f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1
  632. calculated_f_out_dims.append(f_out_after_conv)
  633. current_f_for_block_input = f_out_after_conv
  634. self.conv_0 = Gemma3nAudioSSCPConvBlock(
  635. idx=0,
  636. input_freq_dim=config.input_feat_size, # Pass original feature dim
  637. config=config,
  638. manual_padding=calculated_block_padding[0],
  639. )
  640. self.conv_1 = Gemma3nAudioSSCPConvBlock(
  641. idx=1,
  642. input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0
  643. config=config,
  644. manual_padding=calculated_block_padding[1],
  645. )
  646. final_c_out = config.sscp_conv_channel_size[-1]
  647. final_f_out = calculated_f_out_dims[-1] # Final frequency dimension
  648. self.input_proj_in_features = final_c_out * final_f_out
  649. self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)
  650. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  651. # audio_encodings is [B, T, F_in]
  652. # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in)
  653. audio_encodings_reshaped = audio_encodings.unsqueeze(1)
  654. x = self.conv_0(audio_encodings_reshaped)
  655. x = self.conv_1(x)
  656. # x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
  657. b, c_out, t_out, f_out = x.shape
  658. # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1
  659. x_permuted = x.permute(0, 2, 3, 1).contiguous()
  660. output_flattened = x_permuted.view(b, t_out, f_out * c_out)
  661. output = self.input_proj_linear(output_flattened)
  662. return output
  663. class Gemma3nAudioConformerAttention(nn.Module):
  664. def __init__(self, config: Gemma3nAudioConfig):
  665. super().__init__()
  666. self.config = config
  667. self.post_in_features = self.config.hidden_size
  668. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  669. self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
  670. self.attn = Gemma3nAudioAttention(config)
  671. self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
  672. self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
  673. def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
  674. audio_encodings_input_to_attn = audio_encodings
  675. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  676. audio_encodings_norm = self.pre_attn_norm(audio_encodings)
  677. # Output of self.attn is [B, T, NumHeads, HeadDim]
  678. audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
  679. # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim]
  680. # NumHeads * HeadDim = hidden_size
  681. b, t, num_heads, head_dim = audio_encodings_attn_out.shape
  682. audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim)
  683. audio_encodings = self.post(audio_encodings_reshaped)
  684. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  685. return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
  686. class Gemma3nAudioConformerFeedForward(nn.Module):
  687. def __init__(self, config: Gemma3nAudioConfig):
  688. super().__init__()
  689. self.config = config
  690. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  691. self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
  692. self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
  693. self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
  694. self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
  695. self.post_layer_scale = self.config.conf_residual_weight
  696. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  697. residual = audio_encodings
  698. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  699. audio_encodings = self.pre_layer_norm(audio_encodings)
  700. audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings)
  701. audio_encodings = nn.functional.silu(audio_encodings)
  702. audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings)
  703. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  704. audio_encodings = self.post_layer_norm(audio_encodings)
  705. return residual + (audio_encodings * self.post_layer_scale)
  706. class Gemma3nAudioConformerLightConv1d(nn.Module):
  707. def __init__(self, config: Gemma3nAudioConfig):
  708. super().__init__()
  709. self.config = config
  710. self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  711. self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False)
  712. self.depthwise_conv1d = nn.Conv1d(
  713. in_channels=self.config.hidden_size,
  714. out_channels=self.config.hidden_size,
  715. kernel_size=self.config.conf_conv_kernel_size,
  716. stride=1,
  717. padding=0, # Manual causal padding
  718. groups=self.config.hidden_size, # Depthwise
  719. bias=False,
  720. )
  721. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  722. self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  723. self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
  724. self.causal_padding = self.config.conf_conv_kernel_size - 1
  725. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  726. audio_encodings_residual = audio_encodings # Save for residual connection
  727. audio_encodings = self.pre_layer_norm(audio_encodings)
  728. audio_encodings = self.linear_start(audio_encodings)
  729. audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1)
  730. # Permute for Conv1d: [B, T, D] -> [B, D, T]
  731. audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
  732. # Apply manual causal padding
  733. audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0))
  734. audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
  735. # Permute back: [B, D, T_out] -> [B, T_out, D]
  736. audio_encodings = audio_encodings.permute(0, 2, 1)
  737. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  738. audio_encodings = self.conv_norm(audio_encodings)
  739. audio_encodings = nn.functional.silu(audio_encodings)
  740. audio_encodings = self.linear_end(audio_encodings)
  741. output = audio_encodings + audio_encodings_residual
  742. return output
  743. class Gemma3nAudioConformerBlock(nn.Module):
  744. def __init__(self, config: Gemma3nAudioConfig):
  745. super().__init__()
  746. self.config = config
  747. self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
  748. self.attention = Gemma3nAudioConformerAttention(self.config)
  749. self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
  750. self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
  751. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  752. self.norm = Gemma3nRMSNorm(self.config.hidden_size)
  753. def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
  754. audio_encodings = self.ffw_layer_start(audio_encodings)
  755. audio_encodings = self.attention(audio_encodings, audio_mel_mask)
  756. validity_mask_for_lconv = ~audio_mel_mask # True for valid
  757. audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to(
  758. audio_encodings.dtype
  759. )
  760. audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
  761. audio_encodings = self.ffw_layer_end(audio_encodings)
  762. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  763. output = self.norm(audio_encodings)
  764. return output
  765. class Gemma3nTextScaledWordEmbedding(nn.Embedding):
  766. """
  767. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  768. """
  769. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  770. super().__init__(num_embeddings, embedding_dim, padding_idx)
  771. self.scalar_embed_scale = embed_scale
  772. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  773. def forward(self, input_ids: torch.Tensor):
  774. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  775. class Gemma3nTextLaurelBlock(nn.Module):
  776. """Learned Augmented Residual Layer"""
  777. def __init__(self, config: Gemma3nTextConfig):
  778. super().__init__()
  779. self.config = config
  780. self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False)
  781. self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False)
  782. self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  783. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  784. laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states)
  785. laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states)
  786. normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
  787. return hidden_states + normed_laurel_hidden_states
  788. class Gemma3nTextMLP(nn.Module):
  789. def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
  790. super().__init__()
  791. self.config = config
  792. self.hidden_size = config.hidden_size
  793. self.intermediate_size = config.intermediate_size[layer_idx]
  794. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  795. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  796. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  797. self.act_fn = ACT2FN[config.hidden_activation]
  798. self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
  799. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  800. gate_proj = self.gate_proj(hidden_states)
  801. if self.activation_sparsity > 0.0:
  802. gate_proj = self._gaussian_topk(gate_proj)
  803. activations = self.act_fn(gate_proj)
  804. up_proj = self.up_proj(hidden_states)
  805. down_proj = self.down_proj(activations * up_proj)
  806. return down_proj
  807. def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
  808. target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
  809. # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
  810. #
  811. # References:
  812. # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
  813. # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
  814. # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
  815. normal_dist = torch.distributions.normal.Normal(0, 1)
  816. std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor)
  817. std_multiplier = std_multiplier.type(inputs.dtype)
  818. inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
  819. inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
  820. cutoff_x = inputs_mean + inputs_std * std_multiplier
  821. return nn.functional.relu(inputs - cutoff_x)
  822. class Gemma3nTextAltUp(nn.Module):
  823. """Alternating Updates (AltUp)
  824. The AltUp module wraps transformer layers. The `predict` step modifies the
  825. input to the transformer layer, and the `correct` step propagates the output
  826. of the transformer layer to the sparsely updated dimensions.
  827. See more in the research paper:
  828. https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
  829. """
  830. def __init__(self, config: Gemma3nTextConfig):
  831. super().__init__()
  832. self.config = config
  833. self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
  834. self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
  835. self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
  836. self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
  837. self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  838. self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
  839. def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
  840. router_inputs = self.router_norm(x) * self.router_input_scale
  841. routed = self.modality_router(router_inputs)
  842. return torch.tanh(routed.float()).type_as(x)
  843. def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
  844. """Predicts the output of a layer using a trainable map.
  845. Args:
  846. hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
  847. stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
  848. Returns:
  849. A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
  850. """
  851. modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
  852. if self.training and self.config.altup_coef_clip is not None:
  853. self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
  854. # Project and then transpose all 2D matrices contained so that mulmat gives the correct result
  855. all_coefs: torch.Tensor = (
  856. self.prediction_coefs(modalities)
  857. .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
  858. .permute(0, 1, 3, 2)
  859. )
  860. # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
  861. predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
  862. predictions = predictions.permute(3, 0, 1, 2) # undo the permute
  863. predictions += hidden_states # add the original input
  864. return predictions.contiguous().type_as(hidden_states)
  865. def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
  866. """Corrects the predictions relative to the
  867. Args:
  868. predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
  869. stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
  870. activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
  871. Returns:
  872. A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
  873. predictions relative to the activated input embeddings.
  874. """
  875. modalities = self.compute_router_modalities(activated)
  876. innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
  877. innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
  878. if self.training and self.config.altup_coef_clip is not None:
  879. weight = self.correction_coefs.weight.clamp(-self.config.altup_coef_clip, self.config.altup_coef_clip)
  880. all_coefs = torch.nn.functional.linear(modalities, weight, bias=None) + 1.0
  881. else:
  882. all_coefs = self.correction_coefs(modalities) + 1.0
  883. # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
  884. # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
  885. # and expand on dim1 for broadcastability
  886. all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
  887. corrected = torch.mul(innovation, all_coefs)
  888. corrected += predictions # add the original input
  889. return corrected.contiguous().type_as(activated)
  890. def forward(self, corrected: torch.Tensor) -> torch.Tensor:
  891. """
  892. This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
  893. (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
  894. `scale_corrected_output`
  895. """
  896. return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
  897. def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
  898. """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
  899. return self.forward(corrected)
  900. def rotate_half(x):
  901. """Rotates half the hidden dims of the input."""
  902. x1 = x[..., : x.shape[-1] // 2]
  903. x2 = x[..., x.shape[-1] // 2 :]
  904. return torch.cat((-x2, x1), dim=-1)
  905. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  906. """
  907. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  908. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  909. """
  910. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  911. if n_rep == 1:
  912. return hidden_states
  913. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  914. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  915. def eager_attention_forward(
  916. module: nn.Module,
  917. query: torch.Tensor,
  918. key: torch.Tensor,
  919. value: torch.Tensor,
  920. attention_mask: torch.Tensor | None,
  921. dropout: float | int = 0.0,
  922. scaling: float | None = None,
  923. softcap: float | None = None,
  924. **kwargs,
  925. ) -> tuple[torch.Tensor, torch.Tensor]:
  926. if scaling is None:
  927. scaling = module.head_dim**-0.5
  928. key_states = repeat_kv(key, module.num_key_value_groups)
  929. value_states = repeat_kv(value, module.num_key_value_groups)
  930. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  931. if softcap is not None:
  932. attn_weights = attn_weights / softcap
  933. attn_weights = torch.tanh(attn_weights)
  934. attn_weights = attn_weights * softcap
  935. if attention_mask is not None:
  936. attn_weights = attn_weights + attention_mask
  937. # upcast attention to fp32
  938. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  939. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  940. attn_output = torch.matmul(attn_weights, value_states)
  941. attn_output = attn_output.transpose(1, 2).contiguous()
  942. return attn_output, attn_weights
  943. def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
  944. """Applies Rotary Position Embedding to the query and key tensors.
  945. Args:
  946. x (`torch.Tensor`): The tensor to embed.
  947. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  948. sin (`torch.Tensor`): The sine part of the rotary embedding.
  949. unsqueeze_dim (`int`, *optional*, defaults to 1):
  950. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  951. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  952. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  953. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  954. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  955. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  956. Returns:
  957. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  958. """
  959. cos = cos.unsqueeze(unsqueeze_dim)
  960. sin = sin.unsqueeze(unsqueeze_dim)
  961. return (x * cos) + (rotate_half(x) * sin)
  962. @use_kernelized_func(apply_rotary_pos_emb)
  963. class Gemma3nTextAttention(nn.Module):
  964. """Multi-headed attention from 'Attention Is All You Need' paper"""
  965. def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
  966. super().__init__()
  967. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  968. self.config = config
  969. self.layer_idx = layer_idx
  970. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  971. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  972. self.scaling = 1.0
  973. self.attention_dropout = self.config.attention_dropout
  974. self.is_causal = True
  975. self.q_proj = nn.Linear(
  976. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  977. )
  978. self.k_proj = nn.Linear(
  979. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  980. )
  981. self.v_proj = nn.Linear(
  982. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  983. )
  984. self.o_proj = nn.Linear(
  985. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  986. )
  987. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  988. self.is_sliding = self.layer_type == "sliding_attention"
  989. self.q_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  990. self.k_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  991. self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
  992. first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
  993. self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
  994. prev_layers = config.layer_types[:first_kv_shared_layer_idx]
  995. if self.is_kv_shared_layer:
  996. # For shared layers, find the last non-shared layer of the same type before sharing starts
  997. self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
  998. self.store_full_length_kv = False
  999. else:
  1000. self.kv_shared_layer_index = None
  1001. # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
  1002. self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
  1003. config.layer_types[layer_idx]
  1004. )
  1005. def forward(
  1006. self,
  1007. hidden_states: torch.Tensor,
  1008. position_embeddings: torch.Tensor = None,
  1009. attention_mask: torch.Tensor | None = None,
  1010. past_key_values: Cache | None = None,
  1011. **kwargs: Unpack[TransformersKwargs],
  1012. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  1013. input_shape = hidden_states.shape[:-1]
  1014. hidden_shape = (*input_shape, -1, self.config.head_dim)
  1015. cos, sin = position_embeddings
  1016. query_states = self.q_proj(hidden_states).view(hidden_shape)
  1017. query_states = self.q_norm(query_states)
  1018. query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
  1019. query_states = query_states.transpose(1, 2)
  1020. # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
  1021. if self.is_kv_shared_layer and past_key_values is not None:
  1022. key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
  1023. # Device of past layer may be different from current one
  1024. key_states = key_states.to(query_states.device)
  1025. value_states = value_states.to(query_states.device)
  1026. else:
  1027. key_states = self.k_proj(hidden_states).view(hidden_shape)
  1028. key_states = self.k_norm(key_states)
  1029. key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
  1030. key_states = key_states.transpose(1, 2)
  1031. value_states = self.v_proj(hidden_states).view(hidden_shape)
  1032. value_states = self.v_norm(value_states)
  1033. value_states = value_states.transpose(1, 2)
  1034. if past_key_values is not None:
  1035. if not self.is_kv_shared_layer:
  1036. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  1037. if self.store_full_length_kv:
  1038. if not hasattr(past_key_values, "shared_layers"):
  1039. past_key_values.shared_layers = {}
  1040. past_key_values.shared_layers[self.layer_idx] = key_states, value_states
  1041. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  1042. self.config._attn_implementation, eager_attention_forward
  1043. )
  1044. attn_output, attn_weights = attention_interface(
  1045. self,
  1046. query_states,
  1047. key_states,
  1048. value_states,
  1049. attention_mask,
  1050. dropout=self.attention_dropout if self.training else 0.0,
  1051. scaling=self.scaling,
  1052. sliding_window=self.sliding_window,
  1053. **kwargs,
  1054. )
  1055. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  1056. attn_output = self.o_proj(attn_output)
  1057. return attn_output, attn_weights
  1058. class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
  1059. def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
  1060. super().__init__()
  1061. self.config = config
  1062. self.hidden_size = config.hidden_size
  1063. self.layer_idx = layer_idx
  1064. self.self_attn = Gemma3nTextAttention(config, layer_idx)
  1065. self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
  1066. self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1067. self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1068. self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1069. self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1070. self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
  1071. self.act_fn = ACT2FN[config.hidden_activation]
  1072. self.altup = Gemma3nTextAltUp(config)
  1073. self.laurel = Gemma3nTextLaurelBlock(config)
  1074. self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
  1075. self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
  1076. self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  1077. def forward(
  1078. self,
  1079. hidden_states: torch.Tensor,
  1080. position_embeddings: torch.Tensor = None,
  1081. per_layer_input: torch.Tensor = None,
  1082. attention_mask: torch.Tensor | None = None,
  1083. position_ids: torch.LongTensor | None = None,
  1084. past_key_values: Cache | None = None,
  1085. **kwargs: Unpack[TransformersKwargs],
  1086. ) -> tuple[torch.Tensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  1087. predictions = self.altup.predict(hidden_states)
  1088. active_prediction = predictions[self.config.altup_active_idx]
  1089. active_prediction_normed = self.input_layernorm(active_prediction)
  1090. laurel_output = self.laurel(active_prediction_normed)
  1091. attn, _ = self.self_attn(
  1092. hidden_states=active_prediction_normed,
  1093. attention_mask=attention_mask,
  1094. position_ids=position_ids,
  1095. position_embeddings=position_embeddings,
  1096. past_key_values=past_key_values,
  1097. **kwargs,
  1098. )
  1099. attn = self.post_attention_layernorm(attn)
  1100. attn_gated = active_prediction + attn
  1101. attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
  1102. attn_norm = self.pre_feedforward_layernorm(attn_laurel)
  1103. attn_ffw = self.mlp(attn_norm)
  1104. attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
  1105. attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
  1106. corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
  1107. first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
  1108. if self.config.altup_correct_scale:
  1109. first_prediction = self.altup.scale_corrected_output(first_prediction)
  1110. # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
  1111. first_prediction = self.per_layer_input_gate(first_prediction)
  1112. first_prediction = self.act_fn(first_prediction)
  1113. first_prediction = torch.multiply(first_prediction, per_layer_input)
  1114. # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
  1115. first_prediction = self.per_layer_projection(first_prediction)
  1116. first_prediction = self.post_per_layer_input_norm(first_prediction)
  1117. corrected_predictions[1:] += first_prediction
  1118. return corrected_predictions
  1119. @auto_docstring
  1120. class Gemma3nPreTrainedModel(PreTrainedModel):
  1121. config: Gemma3nConfig
  1122. base_model_prefix = "model"
  1123. supports_gradient_checkpointing = True
  1124. _no_split_modules = ["Gemma3nTextDecoderLayer"]
  1125. _skip_keys_device_placement = ["past_key_values"]
  1126. _supports_flash_attn = True
  1127. _supports_sdpa = True
  1128. _supports_flex_attn = True
  1129. _can_compile_fullgraph = True
  1130. _supports_attention_backend = True
  1131. _can_record_outputs = {
  1132. "hidden_states": Gemma3nTextDecoderLayer,
  1133. "attentions": Gemma3nTextAttention,
  1134. }
  1135. input_modalities = ("image", "text", "audio")
  1136. @torch.no_grad()
  1137. def _init_weights(self, module):
  1138. super()._init_weights(module)
  1139. if isinstance(module, Gemma3nAudioCumulativeGroupNorm):
  1140. init.ones_(module.weight)
  1141. elif isinstance(module, Gemma3nAudioAttention):
  1142. init.zeros_(module.per_dim_scale)
  1143. q_scale = module.head_dim**-0.5
  1144. r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
  1145. init.copy_(module.q_scale, q_scale * r_softplus_0)
  1146. init.constant_(module.softcap, module.attention_logits_soft_cap)
  1147. init.copy_(module.local_causal_valid_mask, module.create_local_causal_valid_mask())
  1148. elif isinstance(module, Gemma3nTextScaledWordEmbedding):
  1149. init.constant_(module.embed_scale, module.scalar_embed_scale)
  1150. elif isinstance(module, Gemma3nTextAltUp):
  1151. init.zeros_(module.correct_output_scale)
  1152. init.constant_(module.router_input_scale, self.config.hidden_size**-1.0)
  1153. elif isinstance(module, Gemma3nAudioRelativePositionEmbedding):
  1154. min_timescale, max_timescale = 1.0, 1.0e4
  1155. num_timescales = module.channels // 2
  1156. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
  1157. num_timescales - 1, 1
  1158. )
  1159. inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
  1160. init.copy_(module.inv_timescales, inv_timescales.float().unsqueeze(0).unsqueeze(0))
  1161. elif isinstance(module, Gemma3nTextModel):
  1162. init.constant_(module.per_layer_projection_scale, self.hidden_size**-0.5)
  1163. init.constant_(module.per_layer_input_scale, 1 / math.sqrt(2.0))
  1164. elif isinstance(module, Gemma3nRotaryEmbedding):
  1165. for layer_type in module.layer_types:
  1166. rope_init_fn = module.compute_default_rope_parameters
  1167. if module.rope_type[layer_type] != "default":
  1168. rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
  1169. curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
  1170. init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
  1171. init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
  1172. if hasattr(module, "gradient_clipping"):
  1173. init.constant_(module.gradient_clipping, self.config.gradient_clipping)
  1174. class Gemma3nAudioEncoder(Gemma3nPreTrainedModel):
  1175. """
  1176. An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.
  1177. """
  1178. config: Gemma3nAudioConfig
  1179. main_input_name = "audio_mel"
  1180. input_modalities = "audio"
  1181. def __init__(self, config: Gemma3nAudioConfig):
  1182. super().__init__(config)
  1183. self.config = config
  1184. self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
  1185. self.conformer = nn.ModuleList(
  1186. [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
  1187. )
  1188. self.post_init()
  1189. @merge_with_config_defaults
  1190. @capture_outputs
  1191. def forward(
  1192. self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs: Unpack[TransformersKwargs]
  1193. ) -> tuple | Gemma3nAudioEncoderModelOutput:
  1194. """Encodes a batch of MELs.
  1195. Args:
  1196. audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
  1197. mel_bins].
  1198. Returns:
  1199. audio_encodings: a torch.Tensor of shape
  1200. `[batch_size, self.config.audio_soft_tokens_per_image,
  1201. self.config.audio_config.hidden_size]`
  1202. audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
  1203. """
  1204. audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
  1205. # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
  1206. t_sub = audio_encodings.shape[1]
  1207. time_stride_product = 1
  1208. for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
  1209. time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
  1210. # Create indices for gathering from the original mask.
  1211. # These indices map to original time steps corresponding to the start of each
  1212. # receptive field in the subsampled output.
  1213. indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
  1214. indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
  1215. # Expand indices for batch compatibility if B > 1 and indices is 1D.
  1216. if audio_mel_mask.ndim > 1 and indices.ndim == 1:
  1217. indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
  1218. elif (
  1219. audio_mel_mask.ndim == indices.ndim
  1220. and audio_mel_mask.shape[0] == 1
  1221. and indices.shape[0] != 1
  1222. and t_sub == indices.shape[0]
  1223. ):
  1224. # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
  1225. indices = indices.unsqueeze(0)
  1226. current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
  1227. for block in self.conformer:
  1228. audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
  1229. if self.config.conf_reduction_factor > 1:
  1230. audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
  1231. # Reduce the mask as well
  1232. current_mask = current_mask[:, :: self.config.conf_reduction_factor]
  1233. audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
  1234. return Gemma3nAudioEncoderModelOutput(
  1235. last_hidden_state=audio_encodings,
  1236. audio_mel_mask=current_mask,
  1237. )
  1238. class Gemma3nRotaryEmbedding(nn.Module):
  1239. inv_freq: torch.Tensor # fix linting for `register_buffer`
  1240. def __init__(self, config: Gemma3nTextConfig, device=None, layer_type=None):
  1241. super().__init__()
  1242. self.max_seq_len_cached = config.max_position_embeddings
  1243. self.original_max_seq_len = config.max_position_embeddings
  1244. self.config = config
  1245. self.layer_types = list(set(config.layer_types))
  1246. self.rope_type = {}
  1247. for layer_type in self.layer_types:
  1248. rope_params = self.config.rope_parameters[layer_type]
  1249. if rope_params is None:
  1250. continue
  1251. self.rope_type[layer_type] = rope_params["rope_type"]
  1252. rope_init_fn: Callable = self.compute_default_rope_parameters
  1253. if self.rope_type[layer_type] != "default":
  1254. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
  1255. curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
  1256. self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
  1257. self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
  1258. setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
  1259. @staticmethod
  1260. def compute_default_rope_parameters(
  1261. config: Gemma3nTextConfig | None = None,
  1262. device: Optional["torch.device"] = None,
  1263. seq_len: int | None = None,
  1264. layer_type: str | None = None,
  1265. ) -> tuple["torch.Tensor", float]:
  1266. """
  1267. Computes the inverse frequencies according to the original RoPE implementation
  1268. Args:
  1269. config ([`~transformers.PreTrainedConfig`]):
  1270. The model configuration.
  1271. device (`torch.device`):
  1272. The device to use for initialization of the inverse frequencies.
  1273. seq_len (`int`, *optional*):
  1274. The current sequence length. Unused for this type of RoPE.
  1275. layer_type (`str`, *optional*):
  1276. The current layer type if the model has different RoPE parameters per type.
  1277. Should not be used unless `config.layer_types is not None`
  1278. Returns:
  1279. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  1280. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  1281. """
  1282. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  1283. base = config.rope_parameters[layer_type]["rope_theta"]
  1284. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  1285. attention_factor = 1.0 # Unused in this type of RoPE
  1286. # Compute the inverse frequencies
  1287. inv_freq = 1.0 / (
  1288. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  1289. )
  1290. return inv_freq, attention_factor
  1291. @torch.no_grad()
  1292. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  1293. def forward(self, x, position_ids, layer_type=None):
  1294. inv_freq = getattr(self, f"{layer_type}_inv_freq")
  1295. attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
  1296. inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  1297. position_ids_expanded = position_ids[:, None, :].float()
  1298. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  1299. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  1300. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  1301. emb = torch.cat((freqs, freqs), dim=-1)
  1302. cos = emb.cos() * attention_scaling
  1303. sin = emb.sin() * attention_scaling
  1304. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  1305. @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
  1306. class Gemma3nTextModel(Gemma3nPreTrainedModel):
  1307. config: Gemma3nTextConfig
  1308. input_modalities = ("text",)
  1309. def __init__(self, config: Gemma3nTextConfig):
  1310. super().__init__(config)
  1311. self.padding_idx = config.pad_token_id
  1312. self.vocab_size = config.vocab_size
  1313. # Gemma3n downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  1314. self.embed_tokens = Gemma3nTextScaledWordEmbedding(
  1315. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  1316. )
  1317. self.layers = nn.ModuleList(
  1318. [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  1319. )
  1320. self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1321. self.rotary_emb = Gemma3nRotaryEmbedding(config)
  1322. self.gradient_checkpointing = False
  1323. self.hidden_size = config.hidden_size
  1324. self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
  1325. self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
  1326. config.vocab_size_per_layer_input,
  1327. config.num_hidden_layers * config.hidden_size_per_layer_input,
  1328. self.padding_idx,
  1329. embed_scale=config.hidden_size_per_layer_input**0.5,
  1330. )
  1331. self.per_layer_model_projection = nn.Linear(
  1332. self.hidden_size,
  1333. config.num_hidden_layers * config.hidden_size_per_layer_input,
  1334. bias=False,
  1335. )
  1336. self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
  1337. self.altup_projections = nn.ModuleList(
  1338. [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
  1339. )
  1340. self.altup_unembed_projections = nn.ModuleList(
  1341. [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
  1342. )
  1343. self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
  1344. self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
  1345. # Initialize weights and apply final processing
  1346. self.post_init()
  1347. @merge_with_config_defaults
  1348. @capture_outputs(tie_last_hidden_states=False)
  1349. @auto_docstring
  1350. def forward(
  1351. self,
  1352. input_ids: torch.LongTensor | None = None,
  1353. per_layer_inputs: torch.Tensor | None = None,
  1354. attention_mask: torch.Tensor | None = None,
  1355. position_ids: torch.LongTensor | None = None,
  1356. past_key_values: Cache | None = None,
  1357. inputs_embeds: torch.FloatTensor | None = None,
  1358. use_cache: bool | None = None,
  1359. **kwargs: Unpack[TransformersKwargs],
  1360. ) -> BaseModelOutputWithPast:
  1361. r"""
  1362. per_layer_inputs (torch.Tensor, *optional*, defaults to None):
  1363. Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
  1364. """
  1365. if (input_ids is None) ^ (inputs_embeds is not None):
  1366. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1367. if input_ids is not None:
  1368. inputs_embeds = self.embed_tokens(input_ids)
  1369. per_layer_inputs = self.get_per_layer_inputs(input_ids)
  1370. per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
  1371. if use_cache and past_key_values is None:
  1372. past_key_values = DynamicCache(config=self.config)
  1373. if position_ids is None:
  1374. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1375. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  1376. position_ids = position_ids.unsqueeze(0)
  1377. # It may already have been prepared by e.g. `generate`
  1378. if not isinstance(causal_mask_mapping := attention_mask, dict):
  1379. # Prepare mask arguments
  1380. mask_kwargs = {
  1381. "config": self.config,
  1382. "inputs_embeds": inputs_embeds,
  1383. "attention_mask": attention_mask,
  1384. "past_key_values": past_key_values,
  1385. "position_ids": position_ids,
  1386. }
  1387. # Create the masks
  1388. causal_mask_mapping = {
  1389. "full_attention": create_causal_mask(**mask_kwargs),
  1390. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  1391. }
  1392. # embed positions
  1393. hidden_states_0 = inputs_embeds
  1394. # Expand hidden_states to support per-layer inputs
  1395. target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
  1396. epsilon_tensor = torch.tensor(1e-5)
  1397. temp_hidden_states = [hidden_states_0]
  1398. for i in range(1, self.config.altup_num_inputs):
  1399. # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
  1400. altup_proj = self.altup_projections[i - 1](hidden_states_0)
  1401. current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
  1402. new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
  1403. new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
  1404. current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
  1405. temp_hidden_states.append(current_hidden_state)
  1406. hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
  1407. position_embeddings = {}
  1408. for layer_type in self.config.layer_types:
  1409. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  1410. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  1411. causal_mask = causal_mask_mapping[self.config.layer_types[i]]
  1412. per_layer_input = per_layer_inputs[:, :, i, :]
  1413. hidden_states = decoder_layer(
  1414. hidden_states,
  1415. position_embeddings[self.config.layer_types[i]],
  1416. per_layer_input,
  1417. attention_mask=causal_mask,
  1418. position_ids=position_ids,
  1419. past_key_values=past_key_values,
  1420. **kwargs,
  1421. )
  1422. # Per-layer inputs to single output
  1423. target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
  1424. temp_hidden_states = [hidden_states[0]]
  1425. for i in range(1, self.config.altup_num_inputs):
  1426. # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
  1427. altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
  1428. current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
  1429. new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
  1430. new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
  1431. current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
  1432. temp_hidden_states.append(current_hidden_state)
  1433. hidden_states = torch.stack(temp_hidden_states)
  1434. hidden_states = torch.mean(hidden_states, dim=0)
  1435. hidden_states = self.norm(hidden_states)
  1436. return BaseModelOutputWithPast(
  1437. last_hidden_state=hidden_states,
  1438. past_key_values=past_key_values,
  1439. )
  1440. def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
  1441. return self.embed_tokens_per_layer(input_ids).reshape(
  1442. *input_ids.shape,
  1443. self.config.num_hidden_layers,
  1444. self.hidden_size_per_layer_input,
  1445. )
  1446. def project_per_layer_inputs(
  1447. self,
  1448. inputs_embeds: torch.Tensor,
  1449. per_layer_inputs: torch.Tensor | None = None,
  1450. ) -> torch.Tensor:
  1451. per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
  1452. per_layer_projection *= self.per_layer_projection_scale.to(
  1453. dtype=inputs_embeds.dtype, device=per_layer_projection.device
  1454. )
  1455. per_layer_projection = per_layer_projection.reshape(
  1456. *inputs_embeds.shape[:-1],
  1457. self.config.num_hidden_layers,
  1458. self.hidden_size_per_layer_input,
  1459. )
  1460. per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
  1461. if per_layer_inputs is None:
  1462. return per_layer_projection
  1463. if per_layer_projection.shape != per_layer_inputs.shape:
  1464. # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
  1465. per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
  1466. return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
  1467. dtype=inputs_embeds.dtype, device=per_layer_projection.device
  1468. )
  1469. @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
  1470. class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
  1471. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  1472. _tp_plan = {"lm_head": "colwise_gather_output"}
  1473. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  1474. config: Gemma3nTextConfig
  1475. def __init__(self, config: Gemma3nTextConfig):
  1476. super().__init__(config)
  1477. self.model = Gemma3nTextModel(config)
  1478. self.vocab_size = config.vocab_size
  1479. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1480. # Initialize weights and apply final processing
  1481. self.post_init()
  1482. @can_return_tuple
  1483. @auto_docstring
  1484. def forward(
  1485. self,
  1486. input_ids: torch.LongTensor | None = None,
  1487. attention_mask: torch.Tensor | None = None,
  1488. position_ids: torch.LongTensor | None = None,
  1489. past_key_values: Cache | None = None,
  1490. inputs_embeds: torch.FloatTensor | None = None,
  1491. labels: torch.LongTensor | None = None,
  1492. use_cache: bool | None = None,
  1493. logits_to_keep: int | torch.Tensor = 0,
  1494. **kwargs: Unpack[TransformersKwargs],
  1495. ) -> CausalLMOutputWithPast:
  1496. r"""
  1497. Example:
  1498. ```python
  1499. >>> from transformers import AutoTokenizer, Gemma3nForCausalLM
  1500. >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b")
  1501. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  1502. >>> prompt = "What is your favorite condiment?"
  1503. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1504. >>> # Generate
  1505. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1506. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1507. "What is your favorite condiment?"
  1508. ```"""
  1509. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1510. outputs: BaseModelOutputWithPast = self.model(
  1511. input_ids=input_ids,
  1512. attention_mask=attention_mask,
  1513. position_ids=position_ids,
  1514. past_key_values=past_key_values,
  1515. inputs_embeds=inputs_embeds,
  1516. use_cache=use_cache,
  1517. **kwargs,
  1518. )
  1519. hidden_states = outputs.last_hidden_state
  1520. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1521. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1522. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1523. if self.config.final_logit_softcapping is not None:
  1524. logits = logits / self.config.final_logit_softcapping
  1525. logits = torch.tanh(logits)
  1526. logits = logits * self.config.final_logit_softcapping
  1527. loss = None
  1528. if labels is not None:
  1529. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  1530. return CausalLMOutputWithPast(
  1531. loss=loss,
  1532. logits=logits,
  1533. past_key_values=outputs.past_key_values,
  1534. hidden_states=outputs.hidden_states,
  1535. attentions=outputs.attentions,
  1536. )
  1537. class Gemma3nMultimodalEmbedder(nn.Module):
  1538. """Embeds token ids or soft tokens for multimodal content into language model space."""
  1539. def __init__(
  1540. self,
  1541. multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
  1542. text_config: Gemma3nTextConfig,
  1543. ):
  1544. super().__init__()
  1545. self.multimodal_hidden_size = multimodal_config.hidden_size
  1546. self.eps = multimodal_config.rms_norm_eps
  1547. self.vocab_offset = multimodal_config.vocab_offset
  1548. self.vocab_size = multimodal_config.vocab_size
  1549. self.text_hidden_size = text_config.hidden_size
  1550. self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
  1551. self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
  1552. self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
  1553. self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
  1554. self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)
  1555. def forward(
  1556. self,
  1557. input_ids: torch.LongTensor | None = None,
  1558. inputs_embeds: torch.Tensor | None = None,
  1559. ) -> torch.Tensor:
  1560. """Embeds token ids or soft tokens for multimodal content into language model space.
  1561. Args:
  1562. input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
  1563. `[vocab_offset, vocab_offset + vocab_size)`.
  1564. inputs_embeds: A torch.Tensor containing the soft tokens to embed.
  1565. Returns:
  1566. A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
  1567. """
  1568. if (input_ids is None) ^ (inputs_embeds is not None):
  1569. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1570. if inputs_embeds is not None:
  1571. emb_norm = self.soft_embedding_norm(inputs_embeds)
  1572. else:
  1573. hard_emb = self.embedding(input_ids - self.vocab_offset)
  1574. emb_norm = self.hard_embedding_norm(hard_emb)
  1575. emb_norm_proj = self.embedding_projection(emb_norm)
  1576. return self.embedding_post_projection_norm(emb_norm_proj)
  1577. @auto_docstring(
  1578. custom_intro="""
  1579. The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
  1580. language modeling head.
  1581. """
  1582. )
  1583. class Gemma3nModel(Gemma3nPreTrainedModel):
  1584. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  1585. accepts_loss_kwargs = False
  1586. def __init__(self, config: Gemma3nConfig):
  1587. super().__init__(config)
  1588. self.vision_tower = AutoModel.from_config(config=config.vision_config)
  1589. self.vocab_size = config.text_config.vocab_size
  1590. language_model = AutoModel.from_config(config=config.text_config)
  1591. self.language_model = language_model
  1592. self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
  1593. self.audio_tower = AutoModel.from_config(config.audio_config)
  1594. self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
  1595. self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
  1596. self.post_init()
  1597. def get_input_embeddings(self):
  1598. return self.language_model.get_input_embeddings()
  1599. def set_input_embeddings(self, value):
  1600. self.language_model.set_input_embeddings(value)
  1601. @can_return_tuple
  1602. @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
  1603. def get_image_features(
  1604. self,
  1605. pixel_values: torch.FloatTensor,
  1606. **kwargs: Unpack[TransformersKwargs],
  1607. ) -> tuple | BaseModelOutputWithPooling:
  1608. vision_outputs = self.vision_tower(pixel_values=pixel_values, do_pooling=False, return_dict=True, **kwargs)
  1609. last_hidden_state = vision_outputs.last_hidden_state
  1610. # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
  1611. # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
  1612. last_hidden_state = last_hidden_state.reshape(
  1613. last_hidden_state.shape[0],
  1614. self.config.vision_config.hidden_size,
  1615. self.config.vision_soft_tokens_per_image,
  1616. ).permute(0, 2, 1)
  1617. # Normalize and embed the soft tokens into language model space.
  1618. last_hidden_state *= self.config.vision_config.hidden_size**0.5
  1619. vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
  1620. return vision_outputs
  1621. def get_placeholder_mask(
  1622. self,
  1623. input_ids: torch.LongTensor | None = None,
  1624. inputs_embeds: torch.FloatTensor | None = None,
  1625. image_features: torch.FloatTensor | None = None,
  1626. audio_features: torch.FloatTensor | None = None,
  1627. ):
  1628. """
  1629. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  1630. equal to the length of multimodal features. If the lengths are different, an error is raised.
  1631. """
  1632. if input_ids is None:
  1633. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  1634. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  1635. )
  1636. special_image_mask = special_image_mask.all(-1)
  1637. special_audio_mask = (
  1638. inputs_embeds
  1639. == self.get_input_embeddings()(
  1640. torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
  1641. )
  1642. ).all(-1)
  1643. else:
  1644. special_image_mask = input_ids == self.config.image_token_id
  1645. special_audio_mask = input_ids == self.config.audio_token_id
  1646. n_image_tokens = special_image_mask.sum()
  1647. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1648. if image_features is not None:
  1649. torch_compilable_check(
  1650. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  1651. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0] * image_features.shape[1]}",
  1652. )
  1653. n_audio_tokens = special_audio_mask.sum()
  1654. special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1655. if audio_features is not None:
  1656. torch_compilable_check(
  1657. inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
  1658. f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {audio_features.shape[0] * audio_features.shape[1]}",
  1659. )
  1660. return special_image_mask, special_audio_mask
  1661. @can_return_tuple
  1662. def forward(
  1663. self,
  1664. input_ids: torch.LongTensor | None = None, # text inputs
  1665. pixel_values: torch.FloatTensor | None = None, # vision inputs
  1666. input_features: torch.FloatTensor | None = None, # audio inputs
  1667. attention_mask: torch.Tensor | None = None,
  1668. input_features_mask: torch.Tensor | None = None,
  1669. position_ids: torch.LongTensor | None = None,
  1670. past_key_values: Cache | None = None,
  1671. token_type_ids: torch.LongTensor | None = None,
  1672. inputs_embeds: torch.FloatTensor | None = None,
  1673. labels: torch.LongTensor | None = None,
  1674. use_cache: bool | None = None,
  1675. **lm_kwargs: Unpack[TransformersKwargs],
  1676. ) -> Gemma3nModelOutputWithPast:
  1677. r"""
  1678. input_features_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1679. Attention mask for `input_features` where non-zero values mark valid audio frames.
  1680. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1681. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1682. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1683. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  1684. Example:
  1685. ```python
  1686. >>> from PIL import Image
  1687. >>> import httpx
  1688. >>> from io import BytesIO
  1689. >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
  1690. >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
  1691. >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
  1692. >>> prompt = "Where is the cat standing?"
  1693. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
  1694. >>> with httpx.stream("GET", url) as response:
  1695. ... image = Image.open(BytesIO(response.read()))
  1696. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  1697. >>> # Generate
  1698. >>> generate_ids = model.generate(**inputs,)
  1699. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1700. "Where is the cat standing?\nsnow"
  1701. ```
  1702. """
  1703. if (input_ids is None) ^ (inputs_embeds is not None):
  1704. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1705. if input_ids is not None:
  1706. inputs_embeds = self.get_input_embeddings()(input_ids)
  1707. # Prepare per-layer inputs from inputs_ids
  1708. per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
  1709. per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
  1710. per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
  1711. # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
  1712. vision_mask = torch.logical_and(
  1713. input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
  1714. )
  1715. dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
  1716. vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
  1717. vision_embeds = self.embed_vision(input_ids=vision_input_ids)
  1718. vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
  1719. expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
  1720. inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
  1721. # Handle audio tokens (>= embed_audio.vocab_offset)
  1722. audio_mask = input_ids >= self.embed_audio.vocab_offset
  1723. dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
  1724. audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
  1725. audio_embeds = self.embed_audio(input_ids=audio_input_ids)
  1726. audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
  1727. expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
  1728. inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
  1729. else:
  1730. per_layer_inputs = None
  1731. # Merge text and images
  1732. if pixel_values is not None:
  1733. image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
  1734. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1735. special_image_mask, _ = self.get_placeholder_mask(
  1736. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  1737. )
  1738. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  1739. # Merge text and audio
  1740. if input_features is not None and input_features_mask is not None:
  1741. audio_outputs = self.get_audio_features(input_features, ~input_features_mask, return_dict=True)
  1742. audio_features = audio_outputs.pooler_output
  1743. audio_mask = audio_outputs.audio_mel_mask
  1744. # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
  1745. # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
  1746. # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
  1747. # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
  1748. # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
  1749. audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
  1750. audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
  1751. audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
  1752. audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
  1753. extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
  1754. extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
  1755. audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
  1756. audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1757. _, special_audio_mask = self.get_placeholder_mask(
  1758. input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
  1759. )
  1760. inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
  1761. outputs = self.language_model(
  1762. input_ids=None,
  1763. per_layer_inputs=per_layer_inputs,
  1764. attention_mask=attention_mask,
  1765. position_ids=position_ids,
  1766. past_key_values=past_key_values,
  1767. inputs_embeds=inputs_embeds,
  1768. use_cache=use_cache,
  1769. return_dict=True,
  1770. **lm_kwargs,
  1771. )
  1772. return Gemma3nModelOutputWithPast(
  1773. last_hidden_state=outputs.last_hidden_state,
  1774. past_key_values=outputs.past_key_values if use_cache else None,
  1775. hidden_states=outputs.hidden_states,
  1776. attentions=outputs.attentions,
  1777. image_hidden_states=image_features if pixel_values is not None else None,
  1778. audio_hidden_states=audio_features if input_features is not None else None,
  1779. )
  1780. @can_return_tuple
  1781. @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
  1782. def get_audio_features(
  1783. self,
  1784. input_features: torch.Tensor,
  1785. input_features_mask: torch.Tensor,
  1786. **kwargs: Unpack[TransformersKwargs],
  1787. ) -> tuple | Gemma3nAudioEncoderModelOutput:
  1788. r"""
  1789. input_features (`torch.FloatTensor]` of shape `(num_images, seq_length, num_features)`):
  1790. The tensors corresponding to the input audio.
  1791. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  1792. The attention mask for the input audio.
  1793. """
  1794. audio_outputs: Gemma3nAudioEncoderModelOutput = self.audio_tower(
  1795. input_features, input_features_mask, return_dict=True, **kwargs
  1796. )
  1797. audio_embeds = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
  1798. audio_outputs.pooler_output = audio_embeds
  1799. return audio_outputs
  1800. @auto_docstring(
  1801. custom_intro="""
  1802. The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
  1803. head.
  1804. """
  1805. )
  1806. class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
  1807. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  1808. def __init__(self, config: Gemma3nConfig):
  1809. super().__init__(config)
  1810. self.model = Gemma3nModel(config)
  1811. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  1812. self.post_init()
  1813. def get_input_embeddings(self):
  1814. return self.model.get_input_embeddings()
  1815. def set_input_embeddings(self, value):
  1816. self.model.set_input_embeddings(value)
  1817. @auto_docstring
  1818. def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
  1819. return self.model.get_image_features(pixel_values, **kwargs)
  1820. @can_return_tuple
  1821. @auto_docstring
  1822. def forward(
  1823. self,
  1824. input_ids: torch.LongTensor | None = None, # text inputs
  1825. pixel_values: torch.FloatTensor | None = None, # vision inputs
  1826. input_features: torch.FloatTensor | None = None, # audio inputs
  1827. attention_mask: torch.Tensor | None = None,
  1828. input_features_mask: torch.Tensor | None = None,
  1829. position_ids: torch.LongTensor | None = None,
  1830. past_key_values: Cache | None = None,
  1831. token_type_ids: torch.LongTensor | None = None,
  1832. inputs_embeds: torch.FloatTensor | None = None,
  1833. labels: torch.LongTensor | None = None,
  1834. use_cache: bool | None = None,
  1835. logits_to_keep: int | torch.Tensor = 0,
  1836. **lm_kwargs: Unpack[TransformersKwargs],
  1837. ) -> Gemma3nCausalLMOutputWithPast:
  1838. r"""
  1839. input_features_mask (torch.Tensor, *optional*, defaults to None):
  1840. The attention mask for the input audio.
  1841. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1842. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1843. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
  1844. ignored (masked), the loss is only computed for the tokens with labels in
  1845. `[0, ..., config.text_config.vocab_size]`.
  1846. Example:
  1847. ```python
  1848. >>> from PIL import Image
  1849. >>> import httpx
  1850. >>> from io import BytesIO
  1851. >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  1852. >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
  1853. >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
  1854. >>> messages = [
  1855. ... {
  1856. ... "role": "system",
  1857. ... "content": [
  1858. ... {"type": "text", "text": "You are a helpful assistant."}
  1859. ... ]
  1860. ... },
  1861. ... {
  1862. ... "role": "user", "content": [
  1863. ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
  1864. ... {"type": "text", "text": "Where is the cat standing?"},
  1865. ... ]
  1866. ... },
  1867. ... ]
  1868. >>> inputs = processor.apply_chat_template(
  1869. ... messages,
  1870. ... tokenizer=True,
  1871. ... return_dict=True,
  1872. ... return_tensors="pt",
  1873. ... add_generation_prompt=True
  1874. ... )
  1875. >>> # Generate
  1876. >>> generate_ids = model.generate(**inputs)
  1877. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1878. "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
  1879. ```
  1880. """
  1881. outputs = self.model(
  1882. input_ids=input_ids,
  1883. pixel_values=pixel_values,
  1884. input_features=input_features,
  1885. attention_mask=attention_mask,
  1886. input_features_mask=input_features_mask,
  1887. position_ids=position_ids,
  1888. past_key_values=past_key_values,
  1889. token_type_ids=token_type_ids,
  1890. inputs_embeds=inputs_embeds,
  1891. labels=labels,
  1892. use_cache=use_cache,
  1893. return_dict=True,
  1894. **lm_kwargs,
  1895. )
  1896. hidden_states = outputs.last_hidden_state
  1897. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1898. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1899. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1900. if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
  1901. logits = logits / final_logit_softcapping
  1902. logits = torch.tanh(logits)
  1903. logits = logits * final_logit_softcapping
  1904. loss = None
  1905. if labels is not None:
  1906. # Upcast to float if we need to compute the loss to avoid potential precision issues
  1907. logits = logits.float()
  1908. shift_logits = logits[..., :-1, :]
  1909. shift_labels = labels[..., 1:]
  1910. if attention_mask is not None:
  1911. # we use the input attention mask to shift the logits and labels, because it is 2D.
  1912. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  1913. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  1914. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  1915. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  1916. else:
  1917. shift_logits = shift_logits.contiguous()
  1918. shift_labels = shift_labels.contiguous()
  1919. # Flatten the tokens
  1920. loss_fct = nn.CrossEntropyLoss()
  1921. flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
  1922. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  1923. loss = loss_fct(flat_logits, flat_labels)
  1924. return Gemma3nCausalLMOutputWithPast(
  1925. loss=loss,
  1926. logits=logits,
  1927. past_key_values=outputs.past_key_values,
  1928. hidden_states=outputs.hidden_states,
  1929. attentions=outputs.attentions,
  1930. image_hidden_states=outputs.image_hidden_states,
  1931. audio_hidden_states=outputs.audio_hidden_states,
  1932. )
  1933. def prepare_inputs_for_generation(
  1934. self,
  1935. input_ids,
  1936. past_key_values=None,
  1937. inputs_embeds=None,
  1938. position_ids=None,
  1939. pixel_values=None,
  1940. input_features=None,
  1941. attention_mask=None,
  1942. input_features_mask=None,
  1943. token_type_ids=None,
  1944. use_cache=True,
  1945. logits_to_keep=None,
  1946. labels=None,
  1947. is_first_iteration=False,
  1948. **kwargs,
  1949. ):
  1950. # Overwritten -- custom `position_ids` and `pixel_values` handling
  1951. model_inputs = super().prepare_inputs_for_generation(
  1952. input_ids,
  1953. past_key_values=past_key_values,
  1954. inputs_embeds=inputs_embeds,
  1955. attention_mask=attention_mask,
  1956. position_ids=position_ids,
  1957. use_cache=use_cache,
  1958. logits_to_keep=logits_to_keep,
  1959. token_type_ids=token_type_ids,
  1960. is_first_iteration=is_first_iteration,
  1961. **kwargs,
  1962. )
  1963. # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
  1964. # tokens anymore. Otherwise multimodal inputs should be passed to model.
  1965. # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
  1966. if is_first_iteration or not use_cache:
  1967. model_inputs["pixel_values"] = pixel_values
  1968. model_inputs["input_features"] = input_features
  1969. model_inputs["input_features_mask"] = input_features_mask
  1970. return model_inputs
  1971. __all__ = [
  1972. "Gemma3nAudioEncoder",
  1973. "Gemma3nForCausalLM",
  1974. "Gemma3nForConditionalGeneration",
  1975. "Gemma3nModel",
  1976. "Gemma3nPreTrainedModel",
  1977. "Gemma3nTextModel",
  1978. ]