| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/t5gemma2/modular_t5gemma2.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_t5gemma2.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
- #
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- from collections.abc import Callable
- from typing import Optional
- import torch
- import torch.nn as nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
- from ...generation import GenerationConfig, GenerationMixin, GenerationMode
- from ...integrations import use_kernel_func_from_hub, use_kernelized_func
- from ...masking_utils import create_bidirectional_mask, create_causal_mask, create_sliding_window_causal_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPooling,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
- from ...utils.generic import maybe_autocast, merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from ..auto import AutoModel
- from .configuration_t5gemma2 import T5Gemma2Config, T5Gemma2DecoderConfig, T5Gemma2EncoderConfig, T5Gemma2TextConfig
class T5Gemma2RMSNorm(nn.Module):
    """RMS normalization over the last dimension with a gain stored as an offset from one.

    The scale applied after normalization is ``1 + weight``; ``weight`` is created as
    zeros, so a freshly constructed module applies a unit gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal root of its mean square (plus ``eps``) along the last dim."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Both the normalization and the gain are applied in float32; the cast back to the
        # input dtype happens once, at the end.
        # Llama does x.to(float16) * w whilst T5Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = self._norm(x.float())
        gain = 1.0 + self.weight.float()
        return (normed * gain).type_as(x)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"
class T5Gemma2MLP(nn.Module):
    """Gated feed-forward block.

    Computes ``down_proj(dropout(act_fn(gate_proj(x)) * up_proj(x)))`` with bias-free
    linear projections.
    """

    def __init__(self, config: T5Gemma2TextConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        # Dropout is applied to the gated activations, before the output projection.
        return self.down_proj(self.dropout(gated))
class T5Gemma2RotaryEmbedding(nn.Module):
    """Rotary position embedding holding one set of inverse frequencies per layer type.

    For every distinct entry of ``config.layer_types`` that has RoPE parameters, the
    module registers ``{layer_type}_inv_freq`` and ``{layer_type}_original_inv_freq``
    as non-persistent buffers and stores ``{layer_type}_attention_scaling`` as a plain
    attribute.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: T5Gemma2TextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.layer_types = list(set(config.layer_types))
        self.rope_type = {}
        for layer_type in self.layer_types:
            rope_params = self.config.rope_parameters[layer_type]
            if rope_params is None:
                # Nothing is registered for a layer type without RoPE parameters.
                continue

            kind = rope_params["rope_type"]
            self.rope_type[layer_type] = kind
            if kind == "default":
                rope_init_fn: Callable = self.compute_default_rope_parameters
            else:
                rope_init_fn = ROPE_INIT_FUNCTIONS[kind]

            inv_freq, attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
            self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False)
            self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False)
            setattr(self, f"{layer_type}_attention_scaling", attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: T5Gemma2TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The current layer type if the model has different RoPE parameters per type.
                Should not be used unless `config.layer_types is not None`
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters[layer_type]["rope_theta"]

        # Prefer an explicit head_dim; otherwise derive it from the hidden size.
        dim = getattr(config, "head_dim", None)
        if not dim:
            dim = config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base**exponents)
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids, layer_type=None):
        """Return ``(cos, sin)`` for ``position_ids`` using the buffers of ``layer_type``, in ``x.dtype``."""
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        batch_size = position_ids.shape[0]
        inv_freq_expanded = inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled under the "cpu" device type for MPS (or a non-string device type).
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at ``x.shape[-1] // 2``; the result is the negated
    second part followed by the first part.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when
            `q` and `k` are [batch_size, heads, seq_len, head_dim], and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation for queries and keys: x * cos + rotate_half(x) * sin.
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim). With `n_rep == 1`
    the input tensor itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
        expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    return hidden_states
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float | int = 0.0,
    scaling: float | None = None,
    softcap: float | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference (non-fused) scaled dot-product attention.

    Args:
        module: Attention module; `num_key_value_groups` and `training` are read from it, and
            `head_dim` as well when `scaling` is not given.
        query, key, value: Attention inputs; `key`/`value` heads are repeated to match the query heads.
        attention_mask: Optional additive mask, added to the scores before the softmax.
        dropout: Dropout probability applied to the attention probabilities.
        scaling: Multiplier for the raw scores; defaults to `module.head_dim ** -0.5`.
        softcap: If given, scores are squashed as `tanh(scores / softcap) * softcap`.
        **kwargs: Accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped relative to the
        matmul result and is contiguous.
    """
    scale = module.head_dim**-0.5 if scaling is None else scaling

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scale
    if softcap is not None:
        scores = scores / softcap
        scores = torch.tanh(scores) * softcap
    if attention_mask is not None:
        scores = scores + attention_mask

    # upcast attention to fp32
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
@use_kernelized_func(apply_rotary_pos_emb)
class T5Gemma2SelfAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Queries and keys are RMS-normalized and rotated with RoPE before the configured
    attention backend is called.
    """

    def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = False  # Only used by the encoder

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
        self.is_sliding = self.layer_type == "sliding_attention"

        # Size the per-head norms with the resolved `self.head_dim` rather than `config.head_dim`,
        # so they match the projections when the config relies on the
        # `hidden_size // num_attention_heads` fallback above.
        self.q_norm = T5Gemma2RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = T5Gemma2RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input activations; the last dimension is projected to heads.
            position_embeddings: `(cos, sin)` pair for RoPE. It is unpacked unconditionally, so it
                must be provided despite the `None` default.
            attention_mask: Mask forwarded to the attention backend.
            past_key_values: Optional cache; when given it is updated with this layer's keys/values
                and the returned states are used for attention.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`: the output after `o_proj`, and the weights as returned by
            the attention backend.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
@use_kernelized_func(apply_rotary_pos_emb)
class T5Gemma2MergedAttention(nn.Module):
    """Merged self-attention and cross-attention for decoder.

    Decoder keys/values and encoder keys/values (computed with the same `k_proj`/`v_proj`)
    are concatenated along the key axis and attended to in a single attention call.
    """

    def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = False  # Fused causal and encoder mask

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
        self.is_sliding = self.layer_type == "sliding_attention"

        # Size the per-head norms with the resolved `self.head_dim` rather than `config.head_dim`,
        # so they match the projections when the config relies on the
        # `hidden_size // num_attention_heads` fallback above.
        self.q_norm = T5Gemma2RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = T5Gemma2RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        # decoder self-attention inputs
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        merged_attention_mask: torch.Tensor | None,
        # cross-attention inputs
        encoder_hidden_states: torch.Tensor,
        # cache inputs
        past_key_values: EncoderDecoderCache | None = None,
        # others
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Attend jointly over decoder states and encoder states.

        Args:
            hidden_states: Decoder activations (queries, and self-attention keys/values).
            position_embeddings: `(cos, sin)` pair applied to the decoder queries/keys only.
            merged_attention_mask: Mask covering the concatenated self + cross key axis.
            encoder_hidden_states: Encoder output. Its shape is always read (to size the cross part),
                even when the cross keys/values come from the cache.
            past_key_values: Optional encoder-decoder cache. The self-attention cache is updated on
                every call; the cross-attention cache is filled once per layer and reused afterwards.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, self_attn_weights, cross_attn_weights)`; both weight entries are `None`
            when the backend returns no weights.
        """
        # attention shapes.
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        cross_input_shape = encoder_hidden_states.shape[:-1]
        cross_hidden_shape = (*cross_input_shape, -1, self.head_dim)

        # self-attention.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # self-attention.
            self_attention_cache = past_key_values.self_attention_cache
            key_states, value_states = self_attention_cache.update(key_states, value_states, self.layer_idx)

            # cross-attention.
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            cross_attention_cache = past_key_values.cross_attention_cache

        # `is_updated` is only defined when a cache exists; the `or` short-circuits otherwise.
        if past_key_values is None or not is_updated:
            # Encoder keys/values reuse the decoder projections; no RoPE is applied to them.
            cross_key_states = self.k_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2)
            cross_value_states = self.v_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2)

            cross_key_states = self.k_norm(cross_key_states)

            if past_key_values is not None:
                cross_key_states, cross_value_states = cross_attention_cache.update(
                    cross_key_states, cross_value_states, self.layer_idx
                )
                past_key_values.is_updated[self.layer_idx] = True
        else:
            cross_key_states = cross_attention_cache.layers[self.layer_idx].keys
            cross_value_states = cross_attention_cache.layers[self.layer_idx].values

        # merged attention.
        cross_key_size = cross_input_shape[1]
        key_states = torch.cat([key_states, cross_key_states], dim=2)
        value_states = torch.cat([value_states, cross_value_states], dim=2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            merged_attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        # decompose merged attention weights into self & cross attention weights
        if attn_weights is not None:
            # Split at an explicit non-negative index: slicing with `-cross_key_size` would hand
            # back the wrong halves when the cross part is empty (`-0 == 0`).
            self_key_size = attn_weights.shape[-1] - cross_key_size
            self_attn_weights = attn_weights[..., :self_key_size]
            cross_attn_weights = attn_weights[..., self_key_size:]
        else:
            self_attn_weights, cross_attn_weights = None, None
        return attn_output, self_attn_weights, cross_attn_weights
class T5Gemma2EncoderLayer(GradientCheckpointingLayer):
    """Encoder sub-layer.

    Self-attention followed by an MLP; each sub-block is wrapped in a pre- and a
    post-RMSNorm and added back to its input through dropout.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]

        self.self_attn = T5Gemma2SelfAttention(
            config=config,
            layer_idx=layer_idx,
        )
        self.pre_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.mlp = T5Gemma2MLP(config)
        self.pre_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Apply the encoder layer and return the updated hidden states.

        Args:
            hidden_states: Input activations.
            position_embeddings: `(cos, sin)` pair forwarded to the attention module.
            attention_mask: Mask forwarded to the attention module.
            position_ids: Forwarded to the attention module as a keyword argument.
            **kwargs: Forwarded to the attention module.

        Returns:
            The hidden states tensor (a bare tensor, not a tuple).
        """
        # Attention sub-block. The encoder never uses a KV cache.
        residual = hidden_states
        hidden_states = self.pre_self_attn_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            **kwargs,
        )
        hidden_states = self.post_self_attn_layernorm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)
        return hidden_states
class T5Gemma2DecoderLayer(GradientCheckpointingLayer):
    """Decoder sub-layer: merged attention instead of vanilla self-attention.

    The merged attention block and the MLP are each wrapped in a pre- and a post-RMSNorm
    and added back to their input through dropout.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]

        norm_eps = config.rms_norm_eps

        # replace vanilla self-attention with merged attention to support joint cross-attention.
        self.self_attn = T5Gemma2MergedAttention(config=config, layer_idx=layer_idx)
        self.pre_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=norm_eps)

        self.mlp = T5Gemma2MLP(config)
        self.pre_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=norm_eps)

        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        merged_attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        use_cache: bool | None = False,
        encoder_hidden_states: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Apply merged attention and the MLP; returns the updated hidden states tensor."""
        # Merged attention sub-block; the attention weights it returns are not used here.
        attn_input = self.pre_self_attn_layernorm(hidden_states)
        attn_output, _, _ = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            merged_attention_mask=merged_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            encoder_hidden_states=encoder_hidden_states,
            **kwargs,
        )
        attn_output = self.post_self_attn_layernorm(attn_output)
        hidden_states = hidden_states + self.dropout(attn_output)

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        mlp_output = self.post_feedforward_layernorm(mlp_output)
        return hidden_states + self.dropout(mlp_output)
class T5Gemma2LMHead(nn.Module):
    """Head for language modeling (generation) tasks.

    A single linear projection from `hidden_size` to `vocab_size`, bias-free by default.
    """

    def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False):
        super().__init__()
        self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits."""
        return self.out_proj(hidden_states)
class T5Gemma2ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies dropout and then a linear projection from `hidden_size` to `num_labels`.
    """

    def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(p=classifier_dropout_rate)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the classification logits for `hidden_states`."""
        dropped = self.dropout(hidden_states)
        return self.out_proj(dropped)
class T5Gemma2MultiModalProjector(nn.Module):
    """Pools vision encoder outputs on a square patch grid and projects them to the text hidden size."""

    def __init__(self, config: T5Gemma2EncoderConfig):
        super().__init__()
        vision_config = config.vision_config

        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(vision_config.hidden_size, config.text_config.hidden_size)
        )
        self.mm_soft_emb_norm = T5Gemma2RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)

        # The average-pool kernel shrinks the patch grid to `tokens_per_side` tokens per side.
        self.patches_per_image = int(vision_config.image_size // vision_config.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        """Pool, normalize and project `vision_outputs`; the result has the dtype of the input."""
        batch_size, _, hidden_size = vision_outputs.shape
        side = self.patches_per_image

        # Move the feature axis forward and lay the patch axis out as a (side, side) grid.
        grid = vision_outputs.transpose(1, 2).reshape(batch_size, hidden_size, side, side).contiguous()

        # Pool spatially, then flatten the grid back into a token axis with features last.
        pooled = self.avg_pool(grid).flatten(2).transpose(1, 2)

        normed = self.mm_soft_emb_norm(pooled)
        projected = torch.matmul(normed, self.mm_input_projection_weight)
        return projected.type_as(vision_outputs)
class T5Gemma2TextScaledWordEmbedding(nn.Embedding):
    """T5Gemma2 Embedding: override to add eoi token embedding separately.

    Looked-up embeddings are multiplied by `embed_scale`; positions whose id equals
    `eoi_token_index` are then overwritten with the dedicated `eoi_embedding` parameter.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        embed_scale: float = 1.0,
        eoi_token_index: int = 256_000,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The scale is kept both as a Python float and as a non-persistent buffer.
        self.scalar_embed_scale = embed_scale
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
        self.eoi_token_index = eoi_token_index
        self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim))

    def forward(self, input_ids: torch.Tensor):
        scale = self.embed_scale.to(self.weight.dtype)
        embeddings = super().forward(input_ids) * scale

        # The EOI rows are written after scaling, so `eoi_embedding` itself is not scaled.
        is_eoi = input_ids == self.eoi_token_index
        embeddings[is_eoi] = self.eoi_embedding.to(embeddings.dtype)
        return embeddings
# Common base class for the T5Gemma2 models in this file. It declares the supported
# attention backends, the modules whose outputs are recorded, the custom weight
# initialization, and the label -> decoder-input shifting helper.
@auto_docstring
class T5Gemma2PreTrainedModel(PreTrainedModel):
    config: T5Gemma2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Names of module classes listed as non-splittable (presumably for device
    # placement — consumed outside this file, TODO confirm).
    _no_split_modules = [
        "T5Gemma2EncoderLayer",
        "T5Gemma2DecoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipEncoderLayer",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _skip_keys_device_placement = ["past_key_values"]

    # Mask creation is incompatible
    # FA due to non-default creation / SWA
    _supports_flash_attn = False
    _supports_sdpa = True
    # Flex due to custom masks not compatible to be merged after creation
    _supports_flex_attn = False

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Hidden states are recorded from the encoder/decoder layers. Attention weights are
    # taken from output index 1 of the self/merged attention modules, and index 2 of the
    # merged attention module is additionally recorded under the "cross_attn" layer name.
    _can_record_outputs = {
        "hidden_states": [T5Gemma2EncoderLayer, T5Gemma2DecoderLayer],
        "attentions": [
            OutputRecorder(T5Gemma2SelfAttention, index=1, layer_name="self_attn"),
            OutputRecorder(T5Gemma2MergedAttention, index=1, layer_name="self_attn"),
            OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"),
        ],
    }
    input_modalities = ("image", "text")

    @torch.no_grad()
    def _init_weights(self, module) -> None:
        """Initialize `module`, applying T5Gemma2-specific rules after the parent's defaults.

        The branches are mutually exclusive and checked in order, so the name-based
        RMSNorm test only runs for modules not matched by the earlier `isinstance` checks.
        """
        super()._init_weights(module)
        if isinstance(module, T5Gemma2MultiModalProjector):
            # The vision -> text projection starts at zero.
            init.zeros_(module.mm_input_projection_weight)
        elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
            init.zeros_(module.eoi_embedding)
            # Refill the scale buffer from the Python scalar stored on the module.
            init.constant_(module.embed_scale, module.scalar_embed_scale)
        elif isinstance(module, T5Gemma2ClassificationHead):
            # The std is scaled by the inverse square root of the first weight dimension.
            scale = module.out_proj.weight.shape[0] ** -0.5
            init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)
            if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
                init.zeros_(module.out_proj.bias)
        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
        elif "RMSNorm" in module.__class__.__name__:
            init.zeros_(module.weight)
        elif isinstance(module, T5Gemma2RotaryEmbedding):
            # Recompute the inverse frequencies for every layer type and copy them into
            # both the working and the "original" buffers of the rotary module.
            for layer_type in module.layer_types:
                rope_init_fn = module.compute_default_rope_parameters
                if module.rope_type[layer_type] != "default":
                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)

    def prepare_decoder_input_ids_from_labels(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Build decoder inputs by shifting `input_ids` one position to the right.

        The first position of the result is set to the decoder's `bos_token_id`, the last
        input position is dropped, and any remaining `-100` value is replaced by the
        decoder's `pad_token_id`. The input tensor itself is not modified.

        Args:
            input_ids: Token ids (or labels) to shift; the shift is applied along the last dimension.

        Returns:
            A new tensor with the same shape as `input_ids`.

        Raises:
            ValueError: If the decoder config has no `bos_token_id` or no `pad_token_id`.
        """
        decoder_config = self.config.decoder
        decoder_start_token_id = decoder_config.bos_token_id
        pad_token_id = decoder_config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ")

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.decoder.pad_token_id has to be defined.")

        # Is this T5 specific?
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
def sliding_window_mask_function(sliding_window: int, is_causal=True) -> Callable:
    """
    Build a mask predicate that restricts attention to a sliding window.

    With `is_causal=True` a query sees itself and the `sliding_window - 1` positions
    before it. Otherwise it sees itself, up to `(sliding_window + 1) // 2 - 1` earlier
    positions and up to `sliding_window // 2` later positions.

    The returned callable takes `(batch_idx, head_idx, q_idx, kv_idx)` and only uses the
    last two; they may be plain integers or broadcastable tensors.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # How far the window reaches backwards (inclusive of the query) and forwards.
        lookback = sliding_window if is_causal else (sliding_window + 1) // 2
        lookahead = 0 if is_causal else sliding_window // 2 + 1

        offset = q_idx - kv_idx
        sees_past = (offset >= 0) & (offset < lookback)
        sees_future = (offset < 0) & (-offset < lookahead)
        return sees_past | sees_future

    return inner_mask
class T5Gemma2TextEncoder(T5Gemma2PreTrainedModel):
    """Text encoder stack: scaled token embedding, encoder layers and a final RMS norm.

    Attention is bidirectional. Two masks are prepared, one for "full_attention" layers
    and one for "sliding_attention" layers, and each layer picks the one matching its
    entry in `config.layer_types`.
    """

    config: T5Gemma2TextConfig
    _can_record_outputs = {
        "attentions": T5Gemma2SelfAttention,
        "hidden_states": T5Gemma2EncoderLayer,
    }

    def __init__(
        self,
        config: T5Gemma2TextConfig,
        eoi_token_index: int = 256_000,
    ):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Embeddings are multiplied by sqrt(hidden_size).
        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
            eoi_token_index=eoi_token_index,
        )
        self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        encoder_layers = [T5Gemma2EncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(encoder_layers)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.rotary_emb = T5Gemma2RotaryEmbedding(config)

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        # Unused for processor compatibility kept in signature.
        token_type_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Exactly one of the two input forms must be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # The encoder never uses a cache: drop any `past_key_values` so it is not forwarded to the layers.
        kwargs.pop("past_key_values", None)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

        # A dict is taken as an already prepared per-layer-type mask mapping; anything else is
        # turned into bidirectional masks here.
        self_attn_mask_mapping = attention_mask
        if not isinstance(self_attn_mask_mapping, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
            }
            windowed = sliding_window_mask_function(self.config.sliding_window, is_causal=False)
            self_attn_mask_mapping = {
                "full_attention": create_bidirectional_mask(**mask_kwargs),
                "sliding_attention": create_bidirectional_mask(**mask_kwargs, and_mask_function=windowed),
            }

        # Rotary embeddings are computed per layer type, from the embeddings before dropout.
        position_embeddings = {
            layer_type: self.rotary_emb(inputs_embeds, position_ids, layer_type)
            for layer_type in self.config.layer_types
        }

        hidden_states = self.dropout(inputs_embeds)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer_idx, encoder_layer in enumerate(active_layers):
            layer_type = self.config.layer_types[layer_idx]
            hidden_states = encoder_layer(
                hidden_states,
                position_embeddings[layer_type],
                self_attn_mask_mapping[layer_type],
                position_ids,
                **kwargs,
            )

        hidden_states = self.dropout(self.norm(hidden_states))
        return BaseModelOutput(last_hidden_state=hidden_states)
class T5Gemma2Encoder(T5Gemma2PreTrainedModel):
    """Multimodal encoder: a text encoder plus a vision tower and a projector.

    When `pixel_values` are given, projected image features are scattered into the
    positions of the text embeddings that hold the image placeholder token.
    """

    config: T5Gemma2EncoderConfig

    def __init__(
        self,
        config: T5Gemma2EncoderConfig,
        eoi_token_index: int = 256_000,
    ):
        super().__init__(config)

        self.text_model = T5Gemma2TextEncoder._from_config(config.text_config, eoi_token_index=eoi_token_index)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = T5Gemma2MultiModalProjector(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the text model's input embedding module."""
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Replace the text model's input embedding module."""
        return self.text_model.set_input_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        # pixel_values: (batch_size, channels, height, width)
        # The projected features, shape (num_images, image_length, embed_dim), are stored
        # on the vision output as `pooler_output`.
        vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
        vision_outputs.pooler_output = self.multi_modal_projector(vision_outputs.last_hidden_state)
        return vision_outputs

    def get_image_placeholder_mask(
        self,
        input_ids: torch.LongTensor | None,
        inputs_embeds: torch.FloatTensor | None,
        image_features: torch.FloatTensor,
    ):
        """
        Locate the image placeholder positions and verify they match `image_features`.

        Placeholders are found by comparing `input_ids` with `config.image_token_id`, or,
        when only `inputs_embeds` is available, by comparing each embedding with the
        embedding of that token. The returned boolean mask is expanded to the shape of
        `inputs_embeds`. A compile-friendly check fails when the number of masked elements
        differs from the number of image feature elements.
        """
        image_token_id = self.config.image_token_id

        if input_ids is not None:
            special_image_mask = input_ids == image_token_id
        else:
            if inputs_embeds is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` has to be provided.")
            token_as_tensor = torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
            image_token_embedding = self.get_input_embeddings()(token_as_tensor)
            # A position counts as a placeholder only if every feature matches.
            special_image_mask = (inputs_embeds == image_token_embedding).all(-1)

        # Counts are only used for the error message below.
        n_image_tokens = special_image_mask.sum()
        n_image_features = image_features.shape[0] * image_features.shape[1]

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        torch_compilable_check(
            inputs_embeds[special_image_mask].numel() == image_features.numel(),
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}",
        )
        return special_image_mask

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        # Unused for processor compatibility kept in signature.
        token_type_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Exactly one of the two input forms must be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.embed_tokens(input_ids)

        if pixel_values is not None:
            # Encode the images, align device/dtype with the text embeddings, then write
            # the image features over the placeholder positions.
            image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            placeholder_mask = self.get_image_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(placeholder_mask, image_features)

        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
class T5Gemma2Decoder(T5Gemma2PreTrainedModel):
    """Decoder stack whose layers receive one mask covering self- and cross-attention.

    For every layer type the self-attention mask and the cross-attention mask are
    concatenated along the last dimension, and the result is handed to the layers
    together with `encoder_hidden_states`.
    """

    config: T5Gemma2DecoderConfig
    # Output index 1 of the merged attention module is recorded as self-attention
    # weights and index 2 as cross-attention weights.
    _can_record_outputs = {
        "attentions": OutputRecorder(T5Gemma2MergedAttention, index=1),
        "cross_attentions": OutputRecorder(T5Gemma2MergedAttention, index=2),
        "hidden_states": T5Gemma2DecoderLayer,
    }

    def __init__(self, config: T5Gemma2DecoderConfig, eoi_token_index: int = 256_000):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Embeddings are multiplied by sqrt(hidden_size).
        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            config.pad_token_id,
            embed_scale=config.hidden_size**0.5,
            eoi_token_index=eoi_token_index,
        )
        self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.layers = nn.ModuleList(
            [T5Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.dropout = nn.Dropout(config.dropout_rate)
        self.rotary_emb = T5Gemma2RotaryEmbedding(config)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states` must be given in decoder")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Outside of training, create a fresh self-/cross-attention cache pair on demand.
        if not self.training and use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())

        if position_ids is None:
            # Continue counting positions after the tokens already held in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # this masking function does nothing to masking but forces `allow_is_causal_skip` to be False
        # as we always need a mask during decoding for merged attention.
        # It is bound before both branches below because each of them may need it: binding it
        # inside the first branch only would leave it undefined when `attention_mask` is already
        # a dict while `encoder_attention_mask` is not.
        dummy_and_mask_function = lambda *args: torch.tensor(True, dtype=torch.bool)  # noqa

        # A dict is taken as an already prepared per-layer-type mask mapping.
        if not isinstance(self_attn_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
                "position_ids": position_ids,
                "and_mask_function": dummy_and_mask_function,
            }
            self_attn_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
            # Cross-attention always uses a full (non-sliding) bidirectional mask.
            cross_attn_mask_mapping = {
                "full_attention": create_bidirectional_mask(
                    config=self.config,
                    inputs_embeds=inputs_embeds,
                    attention_mask=encoder_attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    and_mask_function=dummy_and_mask_function,
                )
            }

        # Self-attention and cross-attention masks are joined along the last dimension; both
        # layer types share the same full cross-attention mask.
        merged_attn_mask_mapping = {
            "full_attention": torch.cat(
                [self_attn_mask_mapping["full_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1
            ),
            "sliding_attention": torch.cat(
                [self_attn_mask_mapping["sliding_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1
            ),
        }

        # input layer
        hidden_states = inputs_embeds

        # global and local position embeddings
        position_embeddings = {}
        for layer_type in self.config.layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

        # dropout
        hidden_states = self.dropout(hidden_states)

        for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
            hidden_states = layer_module(
                hidden_states,
                position_embeddings[self.config.layer_types[i]],
                merged_attn_mask_mapping[self.config.layer_types[i]],
                position_ids,
                past_key_values,
                use_cache,
                encoder_hidden_states,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
# Bare encoder-decoder model: runs the multimodal encoder (unless encoder outputs are
# supplied) and then the decoder on top of the encoder's last hidden state.
@auto_docstring
class T5Gemma2Model(T5Gemma2PreTrainedModel):
    # The decoder's token table and EOI vector are tied to the encoder text embedding.
    _tied_weights_keys = {
        "decoder.embed_tokens.weight": "encoder.text_model.embed_tokens.weight",
        "decoder.embed_tokens.eoi_embedding": "encoder.text_model.embed_tokens.eoi_embedding",
    }

    def __init__(self, config: T5Gemma2Config):
        super().__init__(config)

        # Encoder and decoder share the same EOI token index.
        eoi_token_index = config.eoi_token_index
        self.encoder = T5Gemma2Encoder(config.encoder, eoi_token_index)
        self.decoder = T5Gemma2Decoder(config.decoder, eoi_token_index)

        self.post_init()

    def get_encoder(self):
        """Return the encoder sub-model."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder sub-model."""
        return self.decoder

    def get_input_embeddings(self):
        """Return the encoder's input embedding module."""
        return self.encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Replace the encoder's input embedding module."""
        return self.encoder.set_input_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        # encoder inputs
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        # decoder inputs
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        # others (mainly inference or cache related)
        encoder_outputs: BaseModelOutput | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.Tensor | None = None,
        decoder_inputs_embeds: torch.Tensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Seq2SeqModelOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        """
        # Run the encoder only when its outputs were not handed in by the caller.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                return_dict=True,
                **kwargs,
            )

        # The encoder-side `attention_mask` doubles as the decoder's cross-attention mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            return_dict=True,
            **kwargs,
        )

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin):
    """T5Gemma2 encoder-decoder model with a language-modeling head for generation.

    Wraps `T5Gemma2Model`, projects the decoder's last hidden state to vocabulary logits
    (with optional tanh soft-capping) and overrides cache preparation so that generation
    uses an `EncoderDecoderCache`.
    """

    # The LM head projection shares its weight with the encoder's token embedding.
    _tied_weights_keys = {
        "lm_head.out_proj.weight": "model.encoder.text_model.embed_tokens.weight",
    }
    # Parallelism plans for the LM head projection (presumably tensor- and
    # pipeline-parallel — consumed outside this file, TODO confirm).
    _tp_plan = {"lm_head.out_proj": "colwise_gather_output"}
    _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])}

    def __init__(self, config: T5Gemma2Config):
        super().__init__(config)

        self.model = T5Gemma2Model(config)
        self.vocab_size = config.decoder.vocab_size
        self.lm_head = T5Gemma2LMHead(config.decoder.hidden_size, self.vocab_size)
        self.loss_type = "ForMaskedLM"

        self.post_init()

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's output projection."""
        self.lm_head.out_proj = new_embeddings

    def get_output_embeddings(self):
        """Return the LM head's output projection."""
        return self.lm_head.out_proj

    def get_input_embeddings(self):
        """Return the wrapped model's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the wrapped model's input embedding module."""
        self.model.set_input_embeddings(value)

    def get_encoder(self):
        """Return the wrapped model's encoder."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the wrapped model's decoder."""
        return self.model.get_decoder()

    # Thin delegation to the encoder's `get_image_features`.
    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        return self.get_encoder().get_image_features(pixel_values, **kwargs)

    @property
    def vision_tower(self):
        """The encoder's vision tower."""
        return self.get_encoder().vision_tower

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        # encoder inputs
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        # decoder inputs
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        # others (mainly inference or cache related)
        encoder_outputs: BaseModelOutput | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """

        # Without explicit decoder inputs, derive them from the labels (teacher forcing).
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)

        decoder_outputs: Seq2SeqModelOutput = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = decoder_outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 keeps all, as `slice(-0, None)` is
        # `slice(0, None)`); a tensor is used directly as the index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        decoder_config = self.config.decoder
        if decoder_config.final_logit_softcapping is not None:
            # Soft-capping: cap * tanh(logits / cap).
            logits = logits / decoder_config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * decoder_config.final_logit_softcapping

        loss = None
        if labels is not None:
            # Input has right-shifted so we directly perform masked lm loss
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.decoder_hidden_states,
            decoder_attentions=decoder_outputs.decoder_attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state,
            encoder_hidden_states=decoder_outputs.encoder_hidden_states,
            encoder_attentions=decoder_outputs.encoder_attentions,
        )

    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
        model_kwargs: dict,
        generation_mode: GenerationMode,
        batch_size: int,
        max_cache_length: int,
    ) -> bool:
        """Override cache preparation to support T5Gemma2-specific EncoderDecoder Cache.

        The parent implementation runs first; afterwards the cross-attention cache is
        rebuilt from a decoder text config without `sliding_window` / `layer_types`, or a
        new `EncoderDecoderCache` is created when `model_kwargs` holds no cache. The
        result is written to `model_kwargs["past_key_values"]` in place.

        Note: despite the `bool` annotation, every `return` below is bare and the method
        otherwise falls off the end, so it returns `None` on all paths.

        Raises:
            ValueError: If an existing `past_key_values` or `self._cache` is not an
                `EncoderDecoderCache`.
        """
        # Build cache and past_key_values structure first and then override as needed.
        super()._prepare_cache_for_generation(
            generation_config,
            model_kwargs,
            generation_mode,
            batch_size,
            max_cache_length,
        )

        # If use_cache is False, do not prepare the cache.
        if generation_config.use_cache is False:
            return

        # Offloading is enabled when the configured implementation name contains "offloaded".
        cache_implementation = generation_config.cache_implementation
        if cache_implementation is None:
            offload_cache = False
        else:
            offload_cache = "offloaded" in generation_config.cache_implementation

        # Main change: use full cache for cross-attention.
        # A deep copy is used so the deletions below do not touch the model's own config.
        cross_attn_config = copy.deepcopy(self.config.get_text_config(decoder=True))

        # cross-attention does not use sliding window
        del cross_attn_config.sliding_window
        del cross_attn_config.layer_types

        cross_attn_cache_kwargs = {
            "config": cross_attn_config,
            "offloading": offload_cache,
        }
        past_key_values = model_kwargs.get("past_key_values")
        if past_key_values is not None:
            if not isinstance(past_key_values, EncoderDecoderCache):
                raise ValueError(
                    "The `past_key_values` in `model_kwargs` must be of type `EncoderDecoderCache` for T5Gemma2 model."
                )

            # Cache already established, no need to re-initialize.
            if len(past_key_values.is_updated) > 0 and past_key_values.is_updated.get(0):
                return

            # Keep the cache class that was chosen upstream and only swap its config.
            cross_attn_cls = type(past_key_values.cross_attention_cache)
            if cross_attn_cls == StaticCache:
                # A static cross-attention cache is sized from dimension 1 of the first encoder output.
                cross_attn_cache_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
            # Update cross-attention cache only (switch from sliding_window to full).
            past_key_values.cross_attention_cache = cross_attn_cls(**cross_attn_cache_kwargs)
        else:
            # Initialize new cache.
            # NOTE(review): the cross-attention cache here is a bare `DynamicCache()` and does
            # not receive `cross_attn_cache_kwargs` (config / offloading) — confirm this is intended.
            model_kwargs["past_key_values"] = EncoderDecoderCache(
                DynamicCache(
                    **{
                        "config": self.config.get_text_config(decoder=True),
                        "offloading": offload_cache,
                    }
                ),  # self-attention cache
                DynamicCache(),  # cross-attention cache
            )

        # Keep an already existing internal cache reference pointing at the cache prepared above.
        if hasattr(self, "_cache") and self._cache is not None:
            if not isinstance(self._cache, EncoderDecoderCache):
                raise ValueError("The internal cache must be of type `EncoderDecoderCache` for T5Gemma2 model.")
            self._cache = model_kwargs["past_key_values"]
# Sequence classification on top of T5Gemma2Model: per-position decoder logits are
# pooled by picking, for every sequence, the last decoder position that is not padding.
@auto_docstring
class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel):
    def __init__(self, config: T5Gemma2Config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.hidden_size = config.decoder.hidden_size

        self.model = T5Gemma2Model(config)

        # Fall back to a dropout rate of 0.1 when the config does not define one.
        head_dropout = getattr(config, "classifier_dropout_rate", 0.1)
        self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, head_dropout)

        self.post_init()

    def get_input_embeddings(self):
        """Return the wrapped model's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the wrapped model's input embedding module."""
        self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.Tensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        encoder_outputs: BaseModelOutput | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Only token ids are supported as inputs for this head.
        if not (inputs_embeds is None and decoder_inputs_embeds is None):
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}."
            )
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # By default the decoder reads the encoder input shifted one step to the right.
        if decoder_input_ids is None:
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids)

        outputs: Seq2SeqModelOutput = self.model(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=False,
            **kwargs,
        )
        logits = self.score(outputs.last_hidden_state)

        # Multiplying each position index by a 0/1 "not padding" flag and taking the argmax
        # yields the rightmost non-pad position, whichever side the padding is on.
        decoder_length = decoder_input_ids.shape[-1]
        is_not_pad = (decoder_input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
        positions = torch.arange(decoder_length, device=logits.device, dtype=torch.int32)
        last_non_pad_token = (positions * is_not_pad).argmax(-1)
        last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_length - 1)

        batch_rows = torch.arange(input_ids.shape[0], device=logits.device)
        pooled_logits = logits[batch_rows, last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=outputs.decoder_hidden_states,
            attentions=outputs.decoder_attentions,
        )
# Token classification on top of T5Gemma2Model: the classification head is applied to
# every position of the decoder's last hidden state and the logits are returned unpooled.
@auto_docstring
class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel):
    def __init__(self, config: T5Gemma2Config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.hidden_size = config.decoder.hidden_size

        self.model = T5Gemma2Model(config)

        # Fall back to a dropout rate of 0.1 when the config does not define one.
        classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
        self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout)

        self.post_init()

    def get_input_embeddings(self):
        """Return the wrapped model's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the wrapped model's input embedding module."""
        self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.Tensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        encoder_outputs: BaseModelOutput | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TokenClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        # Only token ids are supported as inputs for this head.
        if inputs_embeds is not None or decoder_inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}."
            )

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # By default the decoder reads the encoder input shifted one step to the right.
        if decoder_input_ids is None:
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids)

        outputs: Seq2SeqModelOutput = self.model(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=False,
            **kwargs,
        )
        last_hidden_state = outputs.last_hidden_state
        hidden_states = outputs.decoder_hidden_states
        attentions = outputs.decoder_attentions

        # One logit vector per decoder position.
        logits = self.score(last_hidden_state)

        loss = None
        if labels is not None:
            # NOTE(review): logits are per decoder position, so `labels` presumably has to line
            # up with the decoder sequence — confirm against the loss function.
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )
# Names exported by this module. T5Gemma2TextEncoder, T5Gemma2Decoder and
# T5Gemma2TextScaledWordEmbedding are defined above but are not part of this list.
__all__ = [
    "T5Gemma2ForConditionalGeneration",
    "T5Gemma2Model",
    "T5Gemma2Encoder",
    "T5Gemma2PreTrainedModel",
    "T5Gemma2ForSequenceClassification",
    "T5Gemma2ForTokenClassification",
]
|