| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_gemma3n.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
- #
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable, Sequence
- from dataclasses import dataclass
- from typing import Optional
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...integrations import use_kernelized_func
- from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import (
- ModelOutput,
- TransformersKwargs,
- auto_docstring,
- can_return_tuple,
- torch_compilable_check,
- )
- from ...utils.generic import maybe_autocast, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from ..auto import AutoModel
- from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig
@dataclass
@auto_docstring
class Gemma3nAudioEncoderModelOutput(BaseModelOutputWithPooling):
    r"""
    audio_mel_mask (`torch.BoolTensor`, *optional*):
        Boolean mask returned alongside the audio encoder outputs, a `torch.BoolTensor` of shape
        `(batch_size, num_frames)`.
    """

    # Extra field on top of BaseModelOutputWithPooling; defaults to None when not produced.
    audio_mel_mask: torch.BoolTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Gemma3n outputs, with hidden states and attentions.
    """
)
class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` whose last dimension is `hidden_size`.
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # Multimodal extras on top of BaseModelOutputWithPast; None when the modality was not provided.
    image_hidden_states: torch.FloatTensor | None = None
    audio_hidden_states: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Gemma3n causal language model (or autoregressive) outputs.
    """
)
class Gemma3nCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` whose last dimension is `hidden_size`.
        audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
    """

    # Every field is optional so partially-populated outputs (e.g. no loss without labels) can be built.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None
    audio_hidden_states: torch.FloatTensor | None = None
class Gemma3nRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, with an optional learned scale.

    The input is upcast to float32 and multiplied by ``(mean(x**2) + eps) ** -0.5`` computed over
    the last dimension. If ``with_scale`` is set, the result is multiplied by a per-feature
    ``weight`` (initialised to ones). The result is then cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale
        if with_scale:
            self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)

    def _norm(self, hidden_states: torch.Tensor):
        # torch.pow(..., -0.5) is used rather than torch.sqrt()/torch.rsqrt() to address
        # compiler differences between Torch and JAX.
        denominator = hidden_states.pow(2).mean(dim=-1, keepdim=True) + self.eps
        inverse_rms = torch.pow(denominator, -0.5)
        return hidden_states * inverse_rms

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        result = self._norm(hidden_states.float())
        if self.with_scale:
            result = result * self.weight.float()
        # Restore the caller's dtype (e.g. bf16/fp16) after the float32 computation.
        return result.type_as(hidden_states)
- # ==== Audio Encoder ====
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
    """Attention logits for blocked local audio attention, with a relative-position term.

    ``forward`` returns logits of shape ``[B, N, U, W, C]``, where:

    * B is the batch size, N the number of heads and U the number of query blocks.
    * W is the query block size and C the key context size.

    The logits are the sum of two terms:

    * a content term, ``query . key``;
    * a position term, ``query . pos_proj(sinusoid(offset))``. Offsets run from
      ``+max_backward`` down to ``-max_forward``; the term is realigned from offset index to
      key-context slot by ``_relative_shift``.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        self.num_heads = self.config.conf_num_attention_heads
        self.channels = self.config.hidden_size
        self.head_dim = self.channels // self.num_heads
        # Range of relative offsets covered by the sinusoidal table (see `forward`).
        self.max_backward = max(0, self.config.conf_attention_context_left - 1)
        self.max_forward = self.config.conf_attention_context_right
        self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)

        # Geometric ladder of inverse timescales from 1/min_timescale down to 1/max_timescale.
        min_timescale, max_timescale = 1.0, 1.0e4
        half_channels = self.channels // 2
        log_step = math.log(float(max_timescale) / float(min_timescale)) / max(half_channels - 1, 1)
        inv_timescales = min_timescale * torch.exp(torch.arange(half_channels) * -log_step)
        # Stored as [1, 1, channels // 2] so it broadcasts against positions of shape [..., 1].
        self.register_buffer("inv_timescales", inv_timescales.float().unsqueeze(0).unsqueeze(0), persistent=False)

    def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Return ``concat(sin(p * inv_ts), cos(p * inv_ts))`` per position, computed in float32 and cast to ``dtype``."""
        angles = position.float().unsqueeze(-1) * self.inv_timescales.to(device=position.device, dtype=torch.float32)
        return torch.cat([angles.sin(), angles.cos()], dim=-1).type(dtype)

    def _relative_shift(
        self,
        term_bd_before_shift: torch.Tensor,
        batch_size: int,
        num_heads: int,
        num_query_blocks: int,
        query_block_size: int,
        key_context_size: int,
        max_span_plus_1: int,
    ) -> torch.Tensor:
        """Re-index position scores from relative-offset index to key-context slot.

        Args:
            term_bd_before_shift: ``[B, N, U, W, F_span]`` scores, with F_span = L + R + 1 offsets.
            batch_size, num_heads, num_query_blocks, query_block_size: B, N, U, W.
            key_context_size: C = W + L + R.
            max_span_plus_1: F_span.

        Returns:
            ``[B, N, U, W, C]`` tensor. For query row ``w``, slot ``c`` holds the score at offset
            index ``c - w`` whenever ``0 <= c - w < F_span``. Other slots hold zero padding or values
            spilled over from the previous row. Gemma3nAudioAttention masks those slots with its
            local window mask.
        """
        # Right-pad the offset axis to C + 1. Flattening W rows of length C + 1 and keeping the
        # first W * C entries then shifts row w to the right by w positions.
        padded = nn.functional.pad(term_bd_before_shift, (0, (key_context_size + 1) - max_span_plus_1))
        flat = padded.reshape(
            (batch_size, num_heads, num_query_blocks, query_block_size * (key_context_size + 1))
        )
        trimmed = flat[:, :, :, : query_block_size * key_context_size]
        return trimmed.reshape((batch_size, num_heads, num_query_blocks, query_block_size, key_context_size))

    def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        """Combine content and relative-position scores.

        Args:
            queries: ``[B, U, W, N, H]`` query blocks.
            keys: ``[B, U, C, N, H]`` key context windows.

        Returns:
            ``[B, N, U, W, C]`` attention logits (before soft-cap and masking).
        """
        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
        _, _, key_context_size, _, _ = keys.shape

        # Offsets max_backward, max_backward - 1, ..., -max_forward -> [1, F_span]
        offsets = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(0)
        max_span_plus_1 = offsets.shape[1]
        timing = self._get_timing_signal_1d_pos(offsets, dtype=queries.dtype)  # [1, F_span, channels]
        # [1, F_span, N*H] -> [1, F_span, N, H] -> [F_span, N, H]
        sin_emb = (
            self.pos_proj(timing).reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(0)
        )

        # Content term: [B, N, U, W, H] @ [B, N, U, H, C] -> [B, N, U, W, C]
        queries_bnuwh = queries.permute(0, 3, 1, 2, 4)
        content_term = torch.matmul(queries_bnuwh, keys.permute(0, 3, 1, 4, 2))

        # Position term, equivalent to einsum("buwnh,fnh->bnuwf", queries, sin_emb):
        # [B, N, U*W, H] @ [N, H, F_span] (broadcast over B) -> [B, N, U*W, F_span]
        queries_flat = queries_bnuwh.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
        position_term_unshifted = torch.matmul(queries_flat, sin_emb.permute(1, 2, 0)).reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            max_span_plus_1,
        )
        position_term = self._relative_shift(
            position_term_unshifted,
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            key_context_size,
            max_span_plus_1,
        )  # [B, N, U, W, C]
        return content_term + position_term
class Gemma3nAudioAttention(nn.Module):
    """Blocked local self-attention over audio frames.

    The time axis is cut into query blocks of ``chunk_size`` frames. Each block attends to a
    key/value window of ``context_size = chunk_size + max_past_horizon + max_future_horizon``
    frames. Logits come from ``Gemma3nAudioRelativePositionEmbedding`` and are soft-capped with
    ``cap * tanh(logits / cap)``. They are then masked by both input validity and the local
    window before the softmax.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        self.num_heads = self.config.conf_num_attention_heads
        self.hidden_size = self.config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.chunk_size = self.config.conf_attention_chunk_size
        self.max_future_horizon = self.config.conf_attention_context_right
        self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon

        self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
        self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))

        projected_dim = self.num_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, projected_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, projected_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, projected_dim, bias=False)

        # Dividing by softplus(0) means that softplus(per_dim_scale) at its zero init reduces the
        # query scaling to plain head_dim ** -0.5.
        base_scale = self.head_dim**-0.5
        inverse_softplus_zero = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
        self.register_buffer("q_scale", (base_scale * inverse_softplus_zero).clone().detach(), persistent=False)
        self.register_buffer("local_causal_valid_mask", self.create_local_causal_valid_mask(), persistent=False)
        self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap).float(), persistent=False)

    def create_local_causal_valid_mask(self):
        """Return a ``[chunk_size, context_size]`` bool mask whose entry ``[w, c]`` is True iff ``w <= c <= w + past + future``."""
        # Transposing a lower-triangular [context, chunk] matrix gives c >= w.
        at_or_after_query = torch.tril(
            torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
            diagonal=0,
        ).T
        # Shifting the diagonal by past + future gives c <= w + past + future.
        within_horizon = torch.tril(
            torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
            diagonal=self.max_past_horizon + self.max_future_horizon,
        )
        all_true = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
        return all_true * at_or_after_query * within_horizon

    def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
        """Zero-pad ``x`` along dim 1 with ``pad_left`` entries before and ``pad_right`` after (False for bool tensors)."""
        batch, _, *tail_shape = x.shape
        return torch.cat(
            [x.new_zeros((batch, pad_left, *tail_shape)), x, x.new_zeros((batch, pad_right, *tail_shape))],
            dim=1,
        )

    def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Split ``[batch, time, ...]`` into non-overlapping ``[batch, num_blocks, chunk_size, ...]`` blocks.

        The time axis is zero-padded at the end up to a multiple of ``chunk_size``. Block ``i``
        holds frames ``i * chunk_size`` through ``(i + 1) * chunk_size - 1``.
        """
        full_shape = hidden_states.shape
        batch, time = full_shape[:2]
        num_blocks = (time + self.chunk_size - 1) // self.chunk_size
        padding_len = num_blocks * self.chunk_size - time
        if padding_len > 0:
            hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
        block_shape = (batch, num_blocks, self.chunk_size) + full_shape[2:]
        return hidden_states.reshape(block_shape).contiguous()

    def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Cut one overlapping context window per query block.

        Args:
            hidden_states: ``[batch, time, ...]`` tensor.

        Returns:
            ``[batch, num_blocks, context_size, ...]`` tensor. Window ``i`` covers original frames
            ``i * chunk_size - max_past_horizon`` up to (excluding)
            ``(i + 1) * chunk_size + max_future_horizon``. Out-of-range frames are zeros.
        """
        # Pad max_past_horizon frames on the left and max_future_horizon + chunk_size - 1 on the right.
        # This mirrors JAX signal.frame padding. It makes unfold yield exactly
        # ceil(time / chunk_size) windows, the same block count as _convert_to_block.
        hidden_states = self._pad_dim1(
            hidden_states, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1
        )
        windows = hidden_states.unfold(dimension=1, size=self.context_size, step=self.chunk_size)
        # unfold appends the window axis last ([B, U, ..., C]). When trailing dims exist
        # (e.g. [B, T, N, H] -> [B, U, N, H, C]), move it to position 2 -> [B, U, C, N, H].
        if hidden_states.ndim > 2 and windows.ndim > 3:
            windows = torch.movedim(windows, source=-1, destination=2)
        return windows.contiguous()

    def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        """Run blocked local attention.

        Args:
            hidden_states: ``[batch, time, hidden_size]`` input. The projections reshape the
                last dim into (num_heads, head_dim).
            mask: ``[batch, time]`` bool tensor that is True at padded (invalid) frames.

        Returns:
            ``[batch, time, num_heads, head_dim]`` attention output.
        """
        # Per-head projections, each [B, T, N, H].
        qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
        key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
        value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()

        # Learned per-dimension query scaling.
        dim_scale = torch.nn.functional.softplus(self.per_dim_scale).view((1, 1, 1, self.head_dim))
        query_states = query_states * self.q_scale * dim_scale

        batch_size, q_time = query_states.shape[:2]
        query_blocks = self._convert_to_block(query_states)  # [B, U, W, N, H]
        key_blocks = self._extract_block_context(key_states)  # [B, U, C, N, H]
        value_blocks = self._extract_block_context(value_states)  # [B, U, C, N, H]
        num_query_blocks = query_blocks.shape[1]

        # Validity per context slot: invert `mask` (True = padded). Slots introduced by the
        # window padding come out False.
        valid_blocks = self._extract_block_context(~mask)
        # Defensive reshape in case the window axis arrives split as [B, U, C / k, k].
        if valid_blocks.ndim == 4 and valid_blocks.shape[2] * valid_blocks.shape[3] == self.context_size:
            valid_blocks = valid_blocks.reshape(batch_size, num_query_blocks, self.context_size)
        if valid_blocks.shape != (batch_size, num_query_blocks, self.context_size):
            raise ValueError(
                "Shape of extracted_valid_mask_blocks"
                f" {valid_blocks.shape} is not ({batch_size},"
                f" {num_query_blocks}, {self.context_size}) after potential reshape."
            )

        # Input validity [B, U, C] -> [B, 1, U, 1, C]; local window [W, C] -> [1, 1, 1, W, C].
        # Their conjunction broadcasts to [B, 1, U, W, C].
        input_ok = valid_blocks.unsqueeze(1).unsqueeze(-2)
        window_ok = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        keep = torch.logical_and(input_ok, window_ok.to(input_ok.device))

        logits = self.relative_position_embedding(query_blocks, key_blocks)  # [B, N, U, W, C]
        # Soft-cap: cap * tanh(logits / cap).
        cap = self.softcap.to(logits.device)
        logits = torch.tanh(logits / cap) * cap
        logits = torch.where(keep, logits, torch.finfo(logits.dtype).min)
        probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)

        # Weighted sum over the context, equivalent to einsum("BNuwc,BucNH->BuwNH"), as a bmm
        # batched over (B, U, N).
        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
        h_dim = value_blocks.shape[-1]
        probs_flat = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
        values_flat = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
        context = (
            torch.bmm(probs_flat, values_flat).reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
        )

        # Merge blocks back into a time axis and drop the block padding past q_time.
        context = context.reshape((batch_size, num_query_blocks * self.chunk_size, self.num_heads, self.head_dim))
        return context[:, :q_time]
class Gemma3nAudioCumulativeGroupNorm(nn.Module):
    """Applies Group Normalization cumulatively over the time dimension.

    The statistics used at time step ``t`` are accumulated over time steps
    ``0..t`` (dim 1) and pooled over every feature dimension (``feature_dims``
    followed by ``num_channels``), i.e. a single normalization group. This is
    similar to JAX's `GroupNormalization` with `num_groups=1` and
    `cumulative=True`.

    This module takes no mask: every time step contributes to the statistics
    and every output position is kept.

    After normalization a learned per-channel scale (``weight``, applied on the
    last dimension) is multiplied in; there is no bias term.
    """

    def __init__(
        self,
        num_channels: int,  # Number of channels (size of the last dimension)
        feature_dims: Sequence[int],  # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
        eps: float = 1e-3,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.feature_dims = tuple(feature_dims)
        self.eps = eps

        # Scale parameter depends only on the channel dimension.
        self.weight = nn.Parameter(torch.ones(num_channels))

        # Axes for normalization: all dimensions except Batch (0) and Time (1).
        # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
        self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Applies cumulative group norm.

        Args:
            hidden_states: Input tensor, shape [B, T, *feature_dims, C].

        Returns:
            Normalized tensor with the same shape and dtype as `hidden_states`.

        Raises:
            ValueError: If the trailing dimensions of `hidden_states` are not
                `(*feature_dims, num_channels)`.
        """
        expected_input_suffix = self.feature_dims + (self.num_channels,)
        if hidden_states.shape[2:] != expected_input_suffix:
            raise ValueError(
                f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
                f" suffix (feature_dims + num_channels) {expected_input_suffix}"
            )

        input_dtype = hidden_states.dtype
        # Calculations are performed in float32 for numerical stability.
        calc_dtype = torch.float32
        x_calc = hidden_states.to(calc_dtype)

        # 1. Sum of values over the reduction axes at each time step -> [B, T, 1, ..., 1].
        sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
        # 2. Cumulative sum of values over time.
        cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)

        # 3. Number of elements in the group at each time step. All elements count, so this is
        #    the constant prod(feature_dims) * num_channels; build it at the reduced shape rather
        #    than summing a full-size tensor of ones.
        group_size = math.prod(expected_input_suffix)
        elements_in_group_at_t = torch.full_like(sum_values_at_t, float(group_size))
        # 4. Cumulative count of elements over time.
        cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
        # Guard against division by zero.
        safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)

        # 5. Cumulative mean.
        cum_mean = cum_sum_values / safe_cum_count_elements

        # 6. Sum of squared differences from the cumulative mean at each step.
        # NOTE: each step's deviation is taken from the cumulative mean *at that step* and those
        # sums are accumulated over time (step 7). The result is a running approximation, not the
        # exact variance of steps 0..t about the current mean. Presumably this mirrors the reference
        # implementation, so it is kept as-is.
        squared_diff_from_mean = (x_calc - cum_mean).pow(2)
        sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
        # 7. Cumulative sum of squared differences over time.
        cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
        # 8. Cumulative variance.
        cum_variance = cum_sum_sq_diff / safe_cum_count_elements

        # Normalize the input using the cumulative statistics: (x - E[x]) / sqrt(Var[x] + eps)
        normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)

        # Per-channel scale, reshaped for broadcasting: [C] -> [1, ..., 1, C].
        scale = self.weight.to(calc_dtype)
        scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
        normalized_x = normalized_x * scale.view(scale_view_shape)

        return normalized_x.to(input_dtype)
class Gemma3nAudioSSCPConvBlock(nn.Module):
    """One convolution stage of the SubSampleConvProjection.

    Pipeline: explicit zero padding -> bias-free Conv2d -> CumulativeGroupNorm -> ReLU.
    The Conv2d is built with ``padding=(0, 0)``; all padding comes from
    ``manual_padding`` and is applied with ``F.pad`` in ``forward``.
    """

    def __init__(
        self,
        config: Gemma3nAudioConfig,
        idx: int,
        input_freq_dim: int,  # Unpadded size of the frequency (last) dimension of the input
        manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
    ):
        super().__init__()
        self.config = config
        # Layout: (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom), as consumed by F.pad.
        self.manual_padding = manual_padding

        channel_sizes = config.sscp_conv_channel_size
        # Stage 0 reads the single-channel input; later stages read the previous stage's channels.
        in_channels = channel_sizes[idx - 1] if idx != 0 else 1
        out_channels = channel_sizes[idx]
        kernel_t, kernel_f = config.sscp_conv_kernel_size[idx]
        stride_t, stride_f = config.sscp_conv_stride_size[idx]

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(kernel_t, kernel_f),  # (Time, Freq)
            stride=(stride_t, stride_f),
            padding=(0, 0),  # padding is applied manually in forward()
            bias=False,
        )

        # Frequency size after the conv, from the padded input width.
        pad_f_left, pad_f_right = manual_padding[0], manual_padding[1]
        freq_out = (input_freq_dim + pad_f_left + pad_f_right - kernel_f) // stride_f + 1

        self.norm = Gemma3nAudioCumulativeGroupNorm(
            num_channels=out_channels,
            feature_dims=(freq_out,),
            eps=config.sscp_conv_group_norm_eps,
        )
        self.activation = nn.ReLU()

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        # Input is [B, C_in, T_in, F_in]. F.pad pads the last dim first (freq left/right), then time (top/bottom).
        padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
        conv_out = self.conv(padded.to(self.conv.weight.dtype))  # [B, C_out, T_out, F_out]
        # The norm works channels-last: [B, T_out, F_out, C_out].
        normed = self.norm(conv_out.permute(0, 2, 3, 1).contiguous())
        # Back to channels-first [B, C_out, T_out, F_out] before the activation.
        return self.activation(normed.permute(0, 3, 1, 2).contiguous())
class Gemma3nAudioSubSampleConvProjection(nn.Module):
    """Sub-samples a [B, T, F_in] feature sequence with two conv blocks, then projects to ``hidden_size``."""

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config

        block_paddings = []
        freq_out_dims = []
        freq_in = config.input_feat_size
        for stage in range(2):  # Two conv blocks: conv_0 and conv_1.
            kernel_t, kernel_f = config.sscp_conv_kernel_size[stage]
            _, stride_f = config.sscp_conv_stride_size[stage]

            # Time (Conv2d height): reverse-causal padding (0, kernel - 1), as in JAX 'reverse_causal'.
            # Frequency (Conv2d width): fixed (1, 1), matching the JAX effective padding for
            # F_in=10, K_w=3, S_w=2. Revisit if the frequency kernel/stride/input size changes,
            # since generic JAX 'SAME' padding may then differ.
            padding = (1, 1, 0, kernel_t - 1)  # (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
            block_paddings.append(padding)

            # Frequency size after this conv (dilation 1); it feeds the next stage.
            freq_in = (freq_in + padding[0] + padding[1] - kernel_f) // stride_f + 1
            freq_out_dims.append(freq_in)

        self.conv_0 = Gemma3nAudioSSCPConvBlock(
            idx=0,
            input_freq_dim=config.input_feat_size,
            config=config,
            manual_padding=block_paddings[0],
        )
        self.conv_1 = Gemma3nAudioSSCPConvBlock(
            idx=1,
            input_freq_dim=freq_out_dims[0],
            config=config,
            manual_padding=block_paddings[1],
        )

        # The projection consumes the flattened (F_out, C_out) features of the last block.
        self.input_proj_in_features = config.sscp_conv_channel_size[-1] * freq_out_dims[-1]
        self.input_proj_linear = nn.Linear(self.input_proj_in_features, config.hidden_size, bias=False)

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        # [B, T, F_in] -> [B, 1, T, F_in]: one input channel, time as height, frequency as width.
        x = self.conv_1(self.conv_0(audio_encodings.unsqueeze(1)))
        batch, channels, time, freq = x.shape  # [B, C_out, T_out, F_out]
        # Channels-last, then merge (F_out, C_out) into a single feature axis.
        flat = x.permute(0, 2, 3, 1).contiguous().view(batch, time, freq * channels)
        return self.input_proj_linear(flat)
class Gemma3nAudioConformerAttention(nn.Module):
    """Conformer self-attention sub-block.

    Pipeline: clip -> RMSNorm -> attention -> output Linear -> clip -> RMSNorm, added onto the unclipped input.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        self.post_in_features = hidden_size
        self.register_buffer("gradient_clipping", torch.tensor(config.gradient_clipping), persistent=False)
        self.pre_attn_norm = Gemma3nRMSNorm(hidden_size)
        self.attn = Gemma3nAudioAttention(config)
        self.post = nn.Linear(self.post_in_features, hidden_size, bias=False)
        self.post_norm = Gemma3nRMSNorm(hidden_size)

    def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
        residual = audio_encodings
        limit = self.gradient_clipping

        clipped = torch.clamp(audio_encodings, -limit, limit)
        # The attention output is per head: [B, T, NumHeads, HeadDim].
        attn_out = self.attn(self.pre_attn_norm(clipped), audio_mel_mask)

        # Merge the heads back into a single feature axis: [B, T, NumHeads * HeadDim].
        batch, time, num_heads, head_dim = attn_out.shape
        projected = self.post(attn_out.reshape(batch, time, num_heads * head_dim))
        projected = torch.clamp(projected, -limit, limit)
        return residual + self.post_norm(projected)
class Gemma3nAudioConformerFeedForward(nn.Module):
    """Conformer feed-forward sub-block with a scaled residual.

    Pipeline: clip -> RMSNorm -> Linear(4x) -> SiLU -> Linear -> clip -> RMSNorm.
    The result is multiplied by ``conf_residual_weight`` and added to the input.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        self.register_buffer("gradient_clipping", torch.tensor(config.gradient_clipping), persistent=False)
        self.pre_layer_norm = Gemma3nRMSNorm(hidden_size)
        self.ffw_layer_1 = nn.Linear(hidden_size, hidden_size * 4, bias=False)
        self.ffw_layer_2 = nn.Linear(hidden_size * 4, hidden_size, bias=False)
        self.post_layer_norm = Gemma3nRMSNorm(hidden_size)
        self.post_layer_scale = config.conf_residual_weight

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        limit = self.gradient_clipping
        hidden = self.pre_layer_norm(torch.clamp(audio_encodings, -limit, limit))
        hidden = nn.functional.silu(self.ffw_layer_1(hidden))
        hidden = torch.clamp(self.ffw_layer_2(hidden), -limit, limit)
        hidden = self.post_layer_norm(hidden)
        return audio_encodings + (hidden * self.post_layer_scale)
class Gemma3nAudioConformerLightConv1d(nn.Module):
    """Conformer light-convolution sub-block: a GLU-gated, causal depthwise Conv1d over time, with a residual."""

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        kernel_size = config.conf_conv_kernel_size
        eps = config.rms_norm_eps

        self.pre_layer_norm = Gemma3nRMSNorm(hidden_size, eps=eps)
        self.linear_start = nn.Linear(hidden_size, hidden_size * 2, bias=False)
        self.depthwise_conv1d = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            stride=1,
            padding=0,  # causal padding is applied manually in forward()
            groups=hidden_size,  # one filter per channel (depthwise)
            bias=False,
        )
        self.register_buffer("gradient_clipping", torch.tensor(config.gradient_clipping), persistent=False)
        self.conv_norm = Gemma3nRMSNorm(hidden_size, eps=eps)
        self.linear_end = nn.Linear(hidden_size, hidden_size, bias=False)
        # Left-pad the time axis by kernel_size - 1 so output step t only sees inputs at steps <= t.
        self.causal_padding = kernel_size - 1

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        hidden = self.linear_start(self.pre_layer_norm(audio_encodings))
        hidden = torch.nn.functional.glu(hidden, dim=-1)  # halves the last dim: 2 * hidden_size -> hidden_size

        # Conv1d expects [B, D, T]; pad on the left of the time axis only.
        channels_first = F.pad(hidden.permute(0, 2, 1), (self.causal_padding, 0))
        hidden = self.depthwise_conv1d(channels_first).permute(0, 2, 1)  # back to [B, T_out, D]

        hidden = torch.clamp(hidden, -self.gradient_clipping, self.gradient_clipping)
        hidden = nn.functional.silu(self.conv_norm(hidden))
        return self.linear_end(hidden) + audio_encodings
class Gemma3nAudioConformerBlock(nn.Module):
    """One audio Conformer layer: feed-forward -> self-attention -> light conv -> feed-forward, then clip and RMSNorm."""

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        self.ffw_layer_start = Gemma3nAudioConformerFeedForward(config)
        self.attention = Gemma3nAudioConformerAttention(config)
        self.lconv1d = Gemma3nAudioConformerLightConv1d(config)
        self.ffw_layer_end = Gemma3nAudioConformerFeedForward(config)
        self.register_buffer("gradient_clipping", torch.tensor(config.gradient_clipping), persistent=False)
        self.norm = Gemma3nRMSNorm(config.hidden_size)

    def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
        hidden = self.attention(self.ffw_layer_start(audio_encodings), audio_mel_mask)

        # audio_mel_mask is True at invalid positions; zero those frames before the convolution.
        keep = (~audio_mel_mask).unsqueeze(-1).to(hidden.dtype)
        hidden = self.lconv1d(hidden * keep)

        hidden = self.ffw_layer_end(hidden)
        hidden = torch.clamp(hidden, -self.gradient_clipping, self.gradient_clipping)
        return self.norm(hidden)
class Gemma3nTextScaledWordEmbedding(nn.Embedding):
    """``nn.Embedding`` whose lookups are multiplied by a constant ``embed_scale``.

    The scale is stored twice: as a plain float (``scalar_embed_scale``) and as a
    non-persistent buffer (``embed_scale``), which is excluded from the state dict.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.scalar_embed_scale = embed_scale
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids: torch.Tensor):
        embeddings = super().forward(input_ids)
        # Cast the scale to the weight dtype so the product stays in the embedding dtype.
        scale = self.embed_scale.to(self.weight.dtype)
        return embeddings * scale
class Gemma3nTextLaurelBlock(nn.Module):
    """Learned Augmented Residual Layer (LAuReL).

    Adds an RMS-normalized low-rank projection of the input back onto the input:
    ``x + norm(right(left(x)))``, where the bottleneck width is ``config.laurel_rank``.
    """

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__()
        self.config = config
        hidden_size, rank = config.hidden_size, config.laurel_rank
        self.linear_left = nn.Linear(hidden_size, rank, bias=False)
        self.linear_right = nn.Linear(rank, hidden_size, bias=False)
        self.post_laurel_norm = Gemma3nRMSNorm(hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        low_rank = self.linear_right(self.linear_left(hidden_states))
        return hidden_states + self.post_laurel_norm(low_rank)
class Gemma3nTextMLP(nn.Module):
    """Gated MLP, ``down(act(gate(x)) * up(x))``, with optional Gaussian top-k sparsity on the gate.

    Both the intermediate size and the sparsity level are read per layer (``layer_idx``) from the config.
    """

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size[layer_idx]
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]
        self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(hidden_states)
        if self.activation_sparsity > 0.0:
            gate = self._gaussian_topk(gate)
        return self.down_proj(self.act_fn(gate) * self.up_proj(hidden_states))

    def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
        """Returns ``relu(inputs - cutoff)`` with a per-row Gaussian-quantile cutoff.

        The cutoff is ``mean + std * Phi^-1(activation_sparsity)`` over the last dimension.
        If a row were Gaussian, roughly an ``activation_sparsity`` fraction of its entries
        would fall below the cutoff (and become zero).
        """
        # Phi^-1 (standard normal inverse CDF); normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
        #
        # References:
        #   * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
        #   * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
        #   * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
        sparsity = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
        standard_normal = torch.distributions.normal.Normal(0, 1)
        std_multiplier: torch.Tensor = standard_normal.icdf(sparsity).type(inputs.dtype)

        row_mean = torch.mean(inputs, dim=-1, keepdim=True)
        row_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
        cutoff = row_mean + row_std * std_multiplier
        return nn.functional.relu(inputs - cutoff)
class Gemma3nTextAltUp(nn.Module):
    """Alternating Updates (AltUp)

    The AltUp module wraps transformer layers. The `predict` step modifies the
    input to the transformer layer, and the `correct` step propagates the output
    of the transformer layer to the sparsely updated dimensions.

    See more in the research paper:

    https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
    """

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__()
        self.config = config
        # Per-feature scale used by `scale_corrected_output` (zero-initialized here).
        self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
        # Routing weights [..., num_inputs] -> one correction coefficient per AltUp input.
        self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
        # Routing weights [..., num_inputs] -> a flattened [num_inputs, num_inputs] mixing matrix per token.
        self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
        self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
        self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        # Constant 1 / hidden_size applied to the normalized router input; not saved in the state dict.
        self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)

    def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
        """Returns per-token routing weights of shape [..., altup_num_inputs], bounded by tanh to [-1, 1]."""
        router_inputs = self.router_norm(x) * self.router_input_scale
        routed = self.modality_router(router_inputs)
        # tanh is evaluated in float32, then cast back to the input dtype.
        return torch.tanh(routed.float()).type_as(x)

    def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Predicts the output of a layer using a trainable map.

        Args:
            hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
                stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.

        Returns:
            A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
        """
        modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])

        if self.training and self.config.altup_coef_clip is not None:
            # NOTE(review): this clamps the stored weights in place, so the clipping persists across steps,
            # whereas `correct` clamps a copy functionally. Confirm the asymmetry is intended.
            self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)

        # Project and then transpose all 2D matrices contained so that mulmat gives the correct result
        all_coefs: torch.Tensor = (
            self.prediction_coefs(modalities)
            .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
            .permute(0, 1, 3, 2)
        )

        # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
        predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
        predictions = predictions.permute(3, 0, 1, 2)  # undo the permute
        predictions += hidden_states  # add the original input
        return predictions.contiguous().type_as(hidden_states)

    def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
        """Corrects the predictions relative to the activated (layer output) embeddings.

        Args:
            predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
                stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
            activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.

        Returns:
            A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
            predictions relative to the activated input embeddings.
        """
        modalities = self.compute_router_modalities(activated)
        # (batch, num_tokens, hidden_size). It is broadcast against the per-input coefficients below
        # instead of being materialized altup_num_inputs times with `repeat`.
        innovation = activated - predictions[self.config.altup_active_idx]

        if self.training and self.config.altup_coef_clip is not None:
            weight = self.correction_coefs.weight.clamp(-self.config.altup_coef_clip, self.config.altup_coef_clip)
            all_coefs = torch.nn.functional.linear(modalities, weight, bias=None) + 1.0
        else:
            all_coefs = self.correction_coefs(modalities) + 1.0

        # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
        # Permute to (altup_num_inputs, batch_size, num_tokens), since each last-dim entry is a scalar for one
        # altup input, and add a trailing axis so it broadcasts over hidden_size.
        all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)

        # (batch, num_tokens, hidden) * (num_inputs, batch, num_tokens, 1) -> (num_inputs, batch, num_tokens, hidden)
        corrected = torch.mul(innovation, all_coefs)
        corrected += predictions  # add the original input
        return corrected.contiguous().type_as(activated)

    def forward(self, corrected: torch.Tensor) -> torch.Tensor:
        """
        This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
        (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
        `scale_corrected_output`
        """
        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
        return self.forward(corrected)
def rotate_half(x):
    """Splits the last dim into halves (a, b) and returns concat(-b, a).

    For an odd size, the first half holds the smaller part (``size // 2`` elements).
    """
    size = x.shape[-1]
    first, second = torch.split(x, [size // 2, size - size // 2], dim=-1)
    return torch.cat((-second, first), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeats each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input must be 4D; when ``n_rep == 1`` it is returned unchanged (same object).
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, expand it (a view), then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float | int = 0.0,
    scaling: float | None = None,
    softcap: float | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference (non-fused) dot-product attention with grouped KV heads and optional logit soft-capping.

    Args:
        module: The calling attention module; its `head_dim`, `num_key_value_groups` and `training` are read.
        query: `[batch, num_heads, q_len, head_dim]`.
        key: `[batch, num_key_value_heads, kv_len, head_dim]`; repeated `num_key_value_groups` times along dim 1.
        value: Same layout as `key`.
        attention_mask: Added to the attention logits when not None.
        dropout: Dropout probability on the attention probabilities, applied only while `module.training`.
        scaling: Logit multiplier; defaults to `module.head_dim ** -0.5`.
        softcap: If set, logits become `softcap * tanh(logits / softcap)` before the mask is added.
        **kwargs: Accepted and ignored (e.g. `sliding_window`, which some callers pass).

    Returns:
        `(attn_output, attn_weights)`: `attn_output` transposed to `[batch, q_len, num_heads, head_dim]`, and
        `attn_weights` the post-dropout probabilities `[batch, num_heads, q_len, kv_len]`.
    """
    if scaling is None:
        scaling = module.head_dim**-0.5

    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if softcap is not None:
        # Smoothly bound the logits to (-softcap, softcap).
        scores = torch.tanh(scores / softcap) * softcap
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax in float32 for stability, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    output = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return output, probs
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
    """Applies Rotary Position Embedding to a single tensor (a query or a key).

    Args:
        x (`torch.Tensor`): The tensor to rotate.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `x`. For
            example, if `cos` and `sin` have shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `x`
            has shape [batch_size, heads, seq_len, head_dim], and `unsqueeze_dim=2` when `x` has shape
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `torch.Tensor`: `x * cos + rotate_half(x) * sin`, with the broadcast shape of `x` and the
        unsqueezed `cos`/`sin` (the shape of `x` when they broadcast to it).
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma3nTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with Gemma3n-specific details.

    * Queries and keys are RMS-normalized per head (with a learned scale), and values are RMS-normalized
      without a scale. RoPE is applied to queries and keys after normalization.
    * The logit scale passed to the attention backend is 1.0 (`self.scaling`) rather than `head_dim ** -0.5`.
    * KV sharing: when `num_kv_shared_layers` leaves at least one non-shared layer, the last
      `num_kv_shared_layers` layers skip their own K/V projections whenever a cache is passed. They reuse the
      K/V that the last non-shared layer of the same layer type stores on the cache's `shared_layers`.
    """

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # Passed explicitly so the eager path does not fall back to its head_dim ** -0.5 default.
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
        self.is_sliding = self.layer_type == "sliding_attention"

        # Use self.head_dim (same value as config.head_dim when it is set) so the norms always agree
        # with the projection sizes above.
        self.q_norm = Gemma3nRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3nRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = Gemma3nRMSNorm(dim=self.head_dim, eps=config.rms_norm_eps, with_scale=False)

        # Layers [first_kv_shared_layer_idx, num_hidden_layers) share K/V. Sharing is only enabled when
        # at least one layer is left to compute K/V (first_kv_shared_layer_idx > 0).
        first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0

        # Index of the last non-shared layer with this layer's type, or None if there is none.
        own_type = config.layer_types[layer_idx]
        prev_layers = config.layer_types[:first_kv_shared_layer_idx]
        last_same_type_idx = max(
            (idx for idx, prev_type in enumerate(prev_layers) if prev_type == own_type),
            default=None,
        )

        if self.is_kv_shared_layer:
            if last_same_type_idx is None:
                raise ValueError(
                    f"Layer {layer_idx} of type {own_type!r} is a KV-shared layer, but no non-shared layer of "
                    "the same type precedes it to share key/value states from."
                )
            # Reuse the K/V of the last non-shared layer of the same type.
            self.kv_shared_layer_index = last_same_type_idx
            self.store_full_length_kv = False
        else:
            self.kv_shared_layer_index = None
            # The last non-shared layer of each type stores its full-length K/V for the shared layers.
            self.store_full_length_kv = layer_idx == last_same_type_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Runs attention for this layer.

        Args:
            hidden_states: `[batch, seq_len, hidden_size]`.
            position_embeddings: `(cos, sin)` for RoPE; required despite the `None` default.
            attention_mask: Passed unchanged to the attention backend.
            past_key_values: Optional cache. Non-shared layers update it with their K/V. KV-shared layers read
                the K/V stored in its `shared_layers` by the layer at `kv_shared_layer_index`.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` is `[batch, seq_len, hidden_size]`, and `attn_weights` is
            whatever the attention backend returns (possibly None).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        cos, sin = position_embeddings

        # Queries: project -> per-head norm -> RoPE (rotated before the head/seq transpose, hence unsqueeze_dim=2).
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
        if self.is_kv_shared_layer and past_key_values is not None:
            key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
            # Device of past layer may be different from current one
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
            key_states = key_states.transpose(1, 2)

            value_states = self.v_proj(hidden_states).view(hidden_shape)
            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)

        if past_key_values is not None:
            if not self.is_kv_shared_layer:
                key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
            if self.store_full_length_kv:
                if not hasattr(past_key_values, "shared_layers"):
                    past_key_values.shared_layers = {}
                past_key_values.shared_layers[self.layer_idx] = key_states, value_states

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
    """One Gemma 3n decoder layer.

    Combines AltUp prediction/correction, a LAuReL side branch, self-attention, an MLP and a gated
    per-layer-input projection. It consumes and returns the stacked AltUp streams rather than a single
    hidden-state tensor.
    """

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Gemma3nTextAttention(config, layer_idx)
        self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
        self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.act_fn = ACT2FN[config.hidden_activation]
        self.altup = Gemma3nTextAltUp(config)
        self.laurel = Gemma3nTextLaurelBlock(config)
        # Gate (hidden_size -> hidden_size_per_layer_input) and projection back (-> hidden_size) for the
        # per-layer input branch.
        self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
        self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
        self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        per_layer_input: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run the layer over the stacked AltUp streams.

        Args:
            hidden_states: Stacked AltUp streams indexed along dim 0, as built by `Gemma3nTextModel.forward`
                (`[num_altup_inputs, batch, seq_len, hidden_size]`).
            position_embeddings: The `(cos, sin)` rotary tables for this layer's attention type.
            per_layer_input: This layer's slice of the per-layer inputs
                (`[batch, seq_len, hidden_size_per_layer_input]` when called from `Gemma3nTextModel`).
            attention_mask, position_ids, past_key_values, **kwargs: Forwarded to `self_attn`.

        Returns:
            The AltUp-corrected streams, indexed along dim 0 like `hidden_states`. Attention weights are
            discarded.
        """
        predictions = self.altup.predict(hidden_states)
        active_prediction = predictions[self.config.altup_active_idx]

        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)

        attn, _ = self.self_attn(
            hidden_states=active_prediction_normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            past_key_values=past_key_values,
            **kwargs,
        )
        attn = self.post_attention_layernorm(attn)

        # Residual around attention, averaged with the LAuReL branch (scaled by 1/sqrt(2)).
        attn_gated = active_prediction + attn
        attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)

        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
        corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)

        # Presumably cloned so the in-place `+=` below does not modify a view that autograd saved.
        first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
        if self.config.altup_correct_scale:
            first_prediction = self.altup.scale_corrected_output(first_prediction)

        # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.multiply(first_prediction, per_layer_input)

        # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        # The per-layer signal is added to every stream except stream 0.
        corrected_predictions[1:] += first_prediction

        return corrected_predictions
@auto_docstring
class Gemma3nPreTrainedModel(PreTrainedModel):
    config: Gemma3nConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma3nTextDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Submodule classes whose outputs are recorded as `hidden_states` / `attentions`.
    _can_record_outputs = {
        "hidden_states": Gemma3nTextDecoderLayer,
        "attentions": Gemma3nTextAttention,
    }
    input_modalities = ("image", "text", "audio")

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialise `module`, then set the Gemma 3n-specific parameters and buffers it owns.

        The generic `PreTrainedModel._init_weights` runs first. The branches below then write the constants,
        scales and precomputed tables these modules expect, derived from the module or its configuration.
        """
        super()._init_weights(module)
        if isinstance(module, Gemma3nAudioCumulativeGroupNorm):
            init.ones_(module.weight)
        elif isinstance(module, Gemma3nAudioAttention):
            init.zeros_(module.per_dim_scale)
            # q_scale = head_dim**-0.5 / softplus(0), i.e. head_dim**-0.5 / ln(2).
            q_scale = module.head_dim**-0.5
            r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
            init.copy_(module.q_scale, q_scale * r_softplus_0)
            init.constant_(module.softcap, module.attention_logits_soft_cap)
            init.copy_(module.local_causal_valid_mask, module.create_local_causal_valid_mask())
        elif isinstance(module, Gemma3nTextScaledWordEmbedding):
            init.constant_(module.embed_scale, module.scalar_embed_scale)
        elif isinstance(module, Gemma3nTextAltUp):
            init.zeros_(module.correct_output_scale)
            init.constant_(module.router_input_scale, self.config.hidden_size**-1.0)
        elif isinstance(module, Gemma3nAudioRelativePositionEmbedding):
            # Geometric progression of inverse timescales between min_timescale and max_timescale.
            min_timescale, max_timescale = 1.0, 1.0e4
            num_timescales = module.channels // 2
            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
                num_timescales - 1, 1
            )
            inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
            init.copy_(module.inv_timescales, inv_timescales.float().unsqueeze(0).unsqueeze(0))
        elif isinstance(module, Gemma3nTextModel):
            # Use the hidden size of the module being initialised, as its own `__init__` does. `self` is not
            # guaranteed to be that text model: wrappers such as `Gemma3nForCausalLM` define no `hidden_size`.
            init.constant_(module.per_layer_projection_scale, module.hidden_size**-0.5)
            init.constant_(module.per_layer_input_scale, 1 / math.sqrt(2.0))
        elif isinstance(module, Gemma3nRotaryEmbedding):
            for layer_type in module.layer_types:
                # `Gemma3nRotaryEmbedding.__init__` skips layer types without RoPE parameters. Those types have
                # no `rope_type` entry and no buffers, so there is nothing to initialise for them.
                if layer_type not in module.rope_type:
                    continue
                rope_init_fn = module.compute_default_rope_parameters
                if module.rope_type[layer_type] != "default":
                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
        # Applies to any module (audio blocks in particular) that exposes a `gradient_clipping` buffer.
        if hasattr(module, "gradient_clipping"):
            init.constant_(module.gradient_clipping, self.config.gradient_clipping)
class Gemma3nAudioEncoder(Gemma3nPreTrainedModel):
    """
    An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.
    """

    config: Gemma3nAudioConfig

    main_input_name = "audio_mel"
    input_modalities = "audio"

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__(config)
        self.config = config

        self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
        self.conformer = nn.ModuleList(
            [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
        )

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | Gemma3nAudioEncoderModelOutput:
        """Encodes a batch of MELs.

        Args:
            audio_mel: Batch of MEL features, with batch in dim 0 and frames in dim 1. The exact layout is
                defined by `Gemma3nAudioSubSampleConvProjection` (assumed; confirm there).
            audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]. `True` entries are treated as
                padding: their encodings are zeroed in the output.

        Returns:
            `Gemma3nAudioEncoderModelOutput` with:
                last_hidden_state: `[batch, T_out, hidden]`, where T_out is the subsampled length further
                    strided by `config.conf_reduction_factor`.
                audio_mel_mask: the input mask subsampled to `[batch, T_out]` (same polarity as the input).
        """
        audio_encodings = self.subsample_conv_projection(audio_mel)  # audio_encodings: [B, T_sub, D]

        # Subsample the input mask to the time resolution of `audio_encodings` (T_sub). Subsampled step t maps
        # to original frame t * (product of the time strides), i.e. the start of its receptive field.
        t_sub = audio_encodings.shape[1]
        time_stride_product = math.prod(stride_pair[0] for stride_pair in self.config.sscp_conv_stride_size)
        indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
        indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1)  # Ensure indices are valid

        # `torch.gather` along dim 1 needs an index with the same rank as the [B, T] mask, so the 1-D indices
        # are broadcast over the batch.
        indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1)  # [B, T_sub]
        current_mask = torch.gather(audio_mel_mask, 1, indices)  # [B, T_sub]

        for block in self.conformer:
            audio_encodings = block(audio_encodings, current_mask)  # Pass the processed mask

        if self.config.conf_reduction_factor > 1:
            audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
            # Reduce the mask as well so it stays aligned with the encodings.
            current_mask = current_mask[:, :: self.config.conf_reduction_factor]

        # Zero out positions flagged in the mask.
        audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
        return Gemma3nAudioEncoderModelOutput(
            last_hidden_state=audio_encodings,
            audio_mel_mask=current_mask,
        )
class Gemma3nRotaryEmbedding(nn.Module):
    """Rotary position embeddings with one inverse-frequency table per attention layer type.

    For every layer type that has RoPE parameters in `config.rope_parameters`, this registers non-persistent
    buffers `{layer_type}_inv_freq` and `{layer_type}_original_inv_freq`, plus an attribute
    `{layer_type}_attention_scaling`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Gemma3nTextConfig, device=None, layer_type=None):
        # NOTE(review): the `layer_type` argument is accepted but unused; every layer type in the config is set up.
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Distinct layer types in first-appearance order. A `set` iterates in a hash-seed-dependent order, so
        # buffer registration order (and therefore `buffers()` order) could differ between processes.
        self.layer_types = list(dict.fromkeys(config.layer_types))
        self.rope_type = {}
        for layer_type in self.layer_types:
            rope_params = self.config.rope_parameters[layer_type]
            if rope_params is None:
                # No RoPE for this layer type: no `rope_type` entry and no buffers are registered.
                continue

            self.rope_type[layer_type] = rope_params["rope_type"]
            rope_init_fn: Callable = self.compute_default_rope_parameters
            if self.rope_type[layer_type] != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
            self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
            setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: Gemma3nTextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The current layer type if the model has different RoPE parameters per type.
                Should not be used unless `config.layer_types is not None`
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        # Read the per-layer-type base frequency from the config.
        base = config.rope_parameters[layer_type]["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: base ** (-2i / dim) for i in [0, dim / 2).
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids, layer_type=None):
        """Return `(cos, sin)` tables for `position_ids` using the inverse frequencies of `layer_type`.

        `x` only supplies the output device and dtype. For `position_ids` of shape `[batch, seq]`, both outputs
        have shape `[batch, seq, 2 * len(inv_freq)]` and are multiplied by the layer type's attention scaling.
        """
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
class Gemma3nTextModel(Gemma3nPreTrainedModel):
    config: Gemma3nTextConfig
    input_modalities = ("text",)

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma3n downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
        self.embed_tokens = Gemma3nTextScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
        )
        self.layers = nn.ModuleList(
            [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3nRotaryEmbedding(config)
        self.gradient_checkpointing = False

        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input

        # Per-layer token embeddings: one `hidden_size_per_layer_input` vector per decoder layer, stored flat.
        self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            self.padding_idx,
            embed_scale=config.hidden_size_per_layer_input**0.5,
        )

        self.per_layer_model_projection = nn.Linear(
            self.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            bias=False,
        )

        self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)

        # Projections into (and back out of) the extra AltUp streams 1..altup_num_inputs-1.
        self.altup_projections = nn.ModuleList(
            [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
        )

        self.altup_unembed_projections = nn.ModuleList(
            [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
        )

        self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
        self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _match_magnitude(
        hidden_state: torch.Tensor,
        target_magnitude: torch.Tensor,
        dtype: torch.dtype,
        epsilon: torch.Tensor,
    ) -> torch.Tensor:
        """Rescale `hidden_state` so its RMS over the last dim equals `target_magnitude`.

        The state is first cast to `dtype` on `target_magnitude`'s device. The mean of squares is floored at
        `epsilon` before the square root to avoid division by zero.
        """
        hidden_state = hidden_state.to(dtype=dtype, device=target_magnitude.device)
        new_magnitude = torch.mean(hidden_state**2, dim=-1, keepdim=True)
        new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon.to(target_magnitude.device)))
        return hidden_state * target_magnitude / new_magnitude

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        r"""
        per_layer_inputs (torch.Tensor, *optional*, defaults to None):
            Pre-computed per-layer embeddings of shape `(batch_size, sequence_length, num_layers,
            hidden_size_per_layer_input)`. Entries beyond `config.num_hidden_layers` along the layer axis are
            ignored. If None, they are derived from `input_ids` when provided.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
            # Honour caller-supplied per-layer inputs; only derive them from the token ids when absent.
            if per_layer_inputs is None:
                per_layer_inputs = self.get_per_layer_inputs(input_ids)

        per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # Expand the embeddings into `altup_num_inputs` streams. Stream 0 is the embedding itself; every other
        # stream is a learned projection rescaled to the embedding's RMS magnitude.
        hidden_states_0 = inputs_embeds
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
        epsilon_tensor = torch.tensor(1e-5)

        temp_hidden_states = [hidden_states_0]
        for i in range(1, self.config.altup_num_inputs):
            # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
            altup_proj = self.altup_projections[i - 1](hidden_states_0)
            temp_hidden_states.append(
                self._match_magnitude(altup_proj, target_magnitude, hidden_states_0.dtype, epsilon_tensor)
            )

        hidden_states = torch.stack(temp_hidden_states, dim=0)  # [num_altup_inputs, batch, seq_len, hidden_size]

        # Rotary tables depend only on the layer type, so compute them once per distinct type.
        position_embeddings = {}
        for layer_type in dict.fromkeys(self.config.layer_types):
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            causal_mask = causal_mask_mapping[self.config.layer_types[i]]
            per_layer_input = per_layer_inputs[:, :, i, :]

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings[self.config.layer_types[i]],
                per_layer_input,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                **kwargs,
            )

        # Collapse the streams back into one. Streams 1.. are projected back and rescaled to the RMS magnitude
        # of stream 0, then all streams are averaged.
        target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
        temp_hidden_states = [hidden_states[0]]
        for i in range(1, self.config.altup_num_inputs):
            # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
            altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
            temp_hidden_states.append(
                self._match_magnitude(altup_unemb_proj, target_magnitude, hidden_states_0.dtype, epsilon_tensor)
            )

        hidden_states = torch.stack(temp_hidden_states)
        hidden_states = torch.mean(hidden_states, dim=0)
        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Look up per-layer token embeddings, reshaped to `(*input_ids.shape, num_hidden_layers,
        hidden_size_per_layer_input)`."""
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Project `inputs_embeds` into per-layer inputs and optionally combine them with `per_layer_inputs`.

        The projection is scaled by `hidden_size**-0.5`, reshaped to `(*inputs_embeds.shape[:-1],
        num_hidden_layers, hidden_size_per_layer_input)` and RMS-normalised. If `per_layer_inputs` is given, the
        result is `(projection + per_layer_inputs) / sqrt(2)`; otherwise the projection alone is returned.
        """
        per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection *= self.per_layer_projection_scale.to(
            dtype=inputs_embeds.dtype, device=per_layer_projection.device
        )
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)

        if per_layer_inputs is None:
            return per_layer_projection

        if per_layer_projection.shape != per_layer_inputs.shape:
            # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
            per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]

        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
            dtype=inputs_embeds.dtype, device=per_layer_projection.device
        )
@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Gemma3nTextConfig

    def __init__(self, config: Gemma3nTextConfig):
        super().__init__(config)
        # Backbone first, then the vocabulary head (tied to `model.embed_tokens` via `_tied_weights_keys`).
        self.model = Gemma3nTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma3nForCausalLM

        >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        # Only project the positions that are needed. An int keeps the last N positions (0 keeps all of them,
        # because slice(-0, None) == slice(0, None)); a tensor selects explicit positions.
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(outputs.last_hidden_state[:, kept_positions, :])

        softcap = self.config.final_logit_softcapping
        if softcap is not None:
            # Squash logits into (-softcap, softcap): softcap * tanh(logits / softcap).
            logits = torch.tanh(logits / softcap) * softcap

        loss = None if labels is None else self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class Gemma3nMultimodalEmbedder(nn.Module):
    """Embeds token ids or soft tokens for multimodal content into language model space."""

    def __init__(
        self,
        multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
        text_config: Gemma3nTextConfig,
    ):
        super().__init__()
        # Sizes and vocabulary window of the modality being embedded.
        self.multimodal_hidden_size = multimodal_config.hidden_size
        self.eps = multimodal_config.rms_norm_eps
        self.vocab_offset = multimodal_config.vocab_offset
        self.vocab_size = multimodal_config.vocab_size
        # Width of the language model the embeddings are projected into.
        self.text_hidden_size = text_config.hidden_size

        # Hard (token id) and soft (precomputed features) inputs get separate pre-projection norms.
        self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
        self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
        self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
        self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
        self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embeds token ids or soft tokens for multimodal content into language model space.

        Exactly one of `input_ids` / `inputs_embeds` must be given.

        Args:
            input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
                `[vocab_offset, vocab_offset + vocab_size)`.
            inputs_embeds: A torch.Tensor containing the soft tokens to embed.

        Returns:
            A torch.Tensor of embeddings whose last dimension is the hidden size of the `text_config` this
            embedder was built with.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Hard tokens: shift ids into the embedder's local range [0, vocab_size) before the lookup.
            normed = self.hard_embedding_norm(self.embedding(input_ids - self.vocab_offset))
        else:
            normed = self.soft_embedding_norm(inputs_embeds)

        return self.embedding_post_projection_norm(self.embedding_projection(normed))
- @auto_docstring(
- custom_intro="""
- The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
- language modeling head.
- """
- )
- class Gemma3nModel(Gemma3nPreTrainedModel):
- # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
- accepts_loss_kwargs = False
- def __init__(self, config: Gemma3nConfig):
- super().__init__(config)
- self.vision_tower = AutoModel.from_config(config=config.vision_config)
- self.vocab_size = config.text_config.vocab_size
- language_model = AutoModel.from_config(config=config.text_config)
- self.language_model = language_model
- self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
- self.audio_tower = AutoModel.from_config(config.audio_config)
- self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
- self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
- self.post_init()
- def get_input_embeddings(self):
- return self.language_model.get_input_embeddings()
- def set_input_embeddings(self, value):
- self.language_model.set_input_embeddings(value)
- @can_return_tuple
- @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
- def get_image_features(
- self,
- pixel_values: torch.FloatTensor,
- **kwargs: Unpack[TransformersKwargs],
- ) -> tuple | BaseModelOutputWithPooling:
- vision_outputs = self.vision_tower(pixel_values=pixel_values, do_pooling=False, return_dict=True, **kwargs)
- last_hidden_state = vision_outputs.last_hidden_state
- # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
- # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
- last_hidden_state = last_hidden_state.reshape(
- last_hidden_state.shape[0],
- self.config.vision_config.hidden_size,
- self.config.vision_soft_tokens_per_image,
- ).permute(0, 2, 1)
- # Normalize and embed the soft tokens into language model space.
- last_hidden_state *= self.config.vision_config.hidden_size**0.5
- vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state)
- return vision_outputs
- def get_placeholder_mask(
- self,
- input_ids: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- image_features: torch.FloatTensor | None = None,
- audio_features: torch.FloatTensor | None = None,
- ):
- """
- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
- equal to the length of multimodal features. If the lengths are different, an error is raised.
- """
- if input_ids is None:
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- special_image_mask = special_image_mask.all(-1)
- special_audio_mask = (
- inputs_embeds
- == self.get_input_embeddings()(
- torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- ).all(-1)
- else:
- special_image_mask = input_ids == self.config.image_token_id
- special_audio_mask = input_ids == self.config.audio_token_id
- n_image_tokens = special_image_mask.sum()
- special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- if image_features is not None:
- torch_compilable_check(
- inputs_embeds[special_image_mask].numel() == image_features.numel(),
- f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0] * image_features.shape[1]}",
- )
- n_audio_tokens = special_audio_mask.sum()
- special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- if audio_features is not None:
- torch_compilable_check(
- inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
- f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {audio_features.shape[0] * audio_features.shape[1]}",
- )
- return special_image_mask, special_audio_mask
- @can_return_tuple
- def forward(
- self,
- input_ids: torch.LongTensor | None = None, # text inputs
- pixel_values: torch.FloatTensor | None = None, # vision inputs
- input_features: torch.FloatTensor | None = None, # audio inputs
- attention_mask: torch.Tensor | None = None,
- input_features_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- token_type_ids: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
- use_cache: bool | None = None,
- **lm_kwargs: Unpack[TransformersKwargs],
- ) -> Gemma3nModelOutputWithPast:
- r"""
- input_features_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Attention mask for `input_features` where non-zero values mark valid audio frames.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
- Example:
- ```python
- >>> from PIL import Image
- >>> import httpx
- >>> from io import BytesIO
- >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
- >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
- >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
- >>> prompt = "Where is the cat standing?"
- >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
- >>> with httpx.stream("GET", url) as response:
- ... image = Image.open(BytesIO(response.read()))
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
- >>> # Generate
- >>> generate_ids = model.generate(**inputs,)
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Where is the cat standing?\nsnow"
- ```
- """
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- if input_ids is not None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
- # Prepare per-layer inputs from inputs_ids
- per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
- per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
- per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
- # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
- vision_mask = torch.logical_and(
- input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
- )
- dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
- vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
- vision_embeds = self.embed_vision(input_ids=vision_input_ids)
- vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
- inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
- # Handle audio tokens (>= embed_audio.vocab_offset)
- audio_mask = input_ids >= self.embed_audio.vocab_offset
- dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
- audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
- audio_embeds = self.embed_audio(input_ids=audio_input_ids)
- audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
- inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
- else:
- per_layer_inputs = None
- # Merge text and images
- if pixel_values is not None:
- image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- special_image_mask, _ = self.get_placeholder_mask(
- input_ids, inputs_embeds=inputs_embeds, image_features=image_features
- )
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
- # Merge text and audio
- if input_features is not None and input_features_mask is not None:
- audio_outputs = self.get_audio_features(input_features, ~input_features_mask, return_dict=True)
- audio_features = audio_outputs.pooler_output
- audio_mask = audio_outputs.audio_mel_mask
- # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
- # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
- # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
- # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
- # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
- audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
- audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
- audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
- audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
- extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
- extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
- audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
- audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
- _, special_audio_mask = self.get_placeholder_mask(
- input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
- )
- inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
- outputs = self.language_model(
- input_ids=None,
- per_layer_inputs=per_layer_inputs,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- return_dict=True,
- **lm_kwargs,
- )
- return Gemma3nModelOutputWithPast(
- last_hidden_state=outputs.last_hidden_state,
- past_key_values=outputs.past_key_values if use_cache else None,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
- audio_hidden_states=audio_features if input_features is not None else None,
- )
@can_return_tuple
@auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
def get_audio_features(
    self,
    input_features: torch.Tensor,
    input_features_mask: torch.Tensor,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Gemma3nAudioEncoderModelOutput:
    r"""
    input_features (`torch.FloatTensor` of shape `(num_audios, seq_length, num_features)`):
        The tensors corresponding to the input audio.
    input_features_mask (`torch.Tensor` of shape `(num_audios, seq_length)`):
        Mask over the audio frames, forwarded unchanged to the audio encoder. `Gemma3nModel.forward` calls this
        method with the *inverted* user-facing mask (`~input_features_mask`), so this must be a boolean or integer
        tensor, not a float tensor.
    """
    # Run the audio encoder; return_dict=True so the output fields can be read by attribute below.
    audio_outputs: Gemma3nAudioEncoderModelOutput = self.audio_tower(
        input_features, input_features_mask, return_dict=True, **kwargs
    )
    # Project the encoder's last hidden state through `embed_audio` and store the result as `pooler_output`,
    # which is the field `Gemma3nModel.forward` reads as the audio features to merge into the text embeddings.
    audio_embeds = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
    audio_outputs.pooler_output = audio_embeds
    return audio_outputs
@auto_docstring(
    custom_intro="""
    The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
    head.
    """
)
class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
    # Declares `lm_head.weight` as tied to the language model's text embedding table.
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config: Gemma3nConfig):
        super().__init__(config)
        # Multimodal backbone (vision + audio + language model) producing hidden states.
        self.model = Gemma3nModel(config)
        # Maps language-model hidden states to logits over the text vocabulary (no bias).
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped `Gemma3nModel`."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the wrapped `Gemma3nModel`."""
        self.model.set_input_embeddings(value)

    @auto_docstring
    def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
        # Thin delegation to the wrapped Gemma3nModel.
        return self.model.get_image_features(pixel_values, **kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,  # text inputs
        pixel_values: torch.FloatTensor | None = None,  # vision inputs
        input_features: torch.FloatTensor | None = None,  # audio inputs
        attention_mask: torch.Tensor | None = None,
        input_features_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **lm_kwargs: Unpack[TransformersKwargs],
    ) -> Gemma3nCausalLMOutputWithPast:
        r"""
        input_features_mask (`torch.Tensor`, *optional*):
            The attention mask for the input audio. Audio features are only merged into the text embeddings when
            both `input_features` and `input_features_mask` are provided. The mask is inverted
            (`~input_features_mask`) before being passed to the audio encoder, so it must be a boolean or integer
            tensor.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in
            `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration

        >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E4B-it")
        >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-E4B-it")

        >>> messages = [
        ...     {
        ...         "role": "system",
        ...         "content": [
        ...             {"type": "text", "text": "You are a helpful assistant."}
        ...         ]
        ...     },
        ...     {
        ...         "role": "user", "content": [
        ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
        ...             {"type": "text", "text": "Where is the cat standing?"},
        ...         ]
        ...     },
        ... ]

        >>> inputs = processor.apply_chat_template(
        ...     messages,
        ...     tokenize=True,
        ...     return_dict=True,
        ...     return_tensors="pt",
        ...     add_generation_prompt=True
        ... )
        >>> # Generate
        >>> generate_ids = model.generate(**inputs)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
        ```
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            input_features=input_features,
            attention_mask=attention_mask,
            input_features_mask=input_features_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            return_dict=True,
            **lm_kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
            # tanh soft-capping: cap * tanh(logits / cap) keeps logits inside (-cap, cap).
            logits = logits / final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * final_logit_softcapping

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Causal shift: the logit at position t is scored against the label at position t + 1.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens; nn.CrossEntropyLoss skips label -100 via its default ignore_index.
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        return Gemma3nCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
            audio_hidden_states=outputs.audio_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        position_ids=None,
        pixel_values=None,
        input_features=None,
        attention_mask=None,
        input_features_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """Build the per-step model inputs for `generate`, adding multimodal inputs only when they are needed.

        `labels` is accepted but intentionally not forwarded to the model inputs.
        """
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
        # tokens anymore. Otherwise multimodal inputs should be passed to model.
        # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
        if is_first_iteration or not use_cache:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["input_features"] = input_features
            model_inputs["input_features_mask"] = input_features_mask

        return model_inputs
# Public API of this module: the names exported by `from ... import *`.
# NOTE(review): keep this as a plain list literal; Transformers' import-structure tooling appears to parse it
# statically — confirm before restyling.
__all__ = [
    "Gemma3nAudioEncoder",
    "Gemma3nForCausalLM",
    "Gemma3nForConditionalGeneration",
    "Gemma3nModel",
    "Gemma3nPreTrainedModel",
    "Gemma3nTextModel",
]
|