| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151 |
- # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
- #
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Callable
- from typing import Any
- import torch
- import torch.nn as nn
- from huggingface_hub.dataclasses import strict
- from ... import initialization as init
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...configuration_utils import PreTrainedConfig
- from ...generation import GenerationMixin
- from ...masking_utils import (
- create_bidirectional_mask,
- create_bidirectional_sliding_window_mask,
- create_causal_mask,
- create_sliding_window_causal_mask,
- )
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import (
- TransformersKwargs,
- auto_docstring,
- can_return_tuple,
- logging,
- )
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from ..gemma2.configuration_gemma2 import Gemma2Config
- from ..gemma2.modeling_gemma2 import (
- Gemma2Attention,
- Gemma2MLP,
- Gemma2PreTrainedModel,
- Gemma2RMSNorm,
- Gemma2RotaryEmbedding,
- eager_attention_forward,
- )
# Module-level logger created through the `...utils` logging wrapper.
logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="google/t5_gemma_module-7b")
@strict
class T5GemmaModuleConfig(Gemma2Config):
    r"""
    query_pre_attn_scalar (`float`, *optional*, defaults to 256):
        Scaling factor used on the attention scores.
    final_logit_softcapping (`float`, *optional*, defaults to 30.0):
        Scaling factor when applying tanh softcapping on the logits.
    attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
        Scaling factor when applying tanh softcapping on the attention scores.

    ```python
    >>> from transformers import T5GemmaModuleModel, T5GemmaModuleConfig

    >>> # Initializing a T5GemmaModule t5_gemma_module-7b style configuration
    >>> configuration = T5GemmaModuleConfig()

    >>> # Initializing a model from the t5_gemma_module-7b style configuration
    >>> model = T5GemmaModuleModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # Marks whether this sub-config describes the decoder stack. `T5GemmaConfig.__post_init__`
    # overwrites it on both sub-configs, and `T5GemmaSelfAttention` copies it into `is_causal`.
    is_decoder: bool = False
    # NOTE(review): an `AttributeError()` instance as the value appears to be the modular-converter
    # convention for dropping the inherited `use_bidirectional_attention` field — confirm.
    use_bidirectional_attention = AttributeError()
@auto_docstring(checkpoint="google/t5_gemma_module-7b")
@strict
class T5GemmaConfig(PreTrainedConfig):
    r"""
    encoder (`Union[T5GemmaModuleConfig, dict]`, *optional*):
        Configuration for the encoder.
    decoder (`Union[T5GemmaModuleConfig, dict]`, *optional*):
        Configuration for the decoder.

    Example:

    ```python
    >>> from transformers import T5GemmaConfig, T5GemmaModel

    >>> t5gemma_config = T5GemmaConfig.from_pretrained("google/t5gemma-2b-2b-prefixlm-it")
    >>> model = T5GemmaModel(t5gemma_config)
    ```"""

    model_type = "t5gemma"
    keys_to_ignore_at_inference = ["past_key_values"]
    sub_configs = {"encoder": T5GemmaModuleConfig, "decoder": T5GemmaModuleConfig}

    encoder: T5GemmaModuleConfig | dict[Any, Any] | None = None
    decoder: T5GemmaModuleConfig | dict[Any, Any] | None = None
    is_encoder_decoder: bool = True
    dropout_rate: int | float = 0.0
    classifier_dropout_rate: int | float = 0.0
    attention_dropout: float | int = 0.0
    tie_word_embeddings: bool = True
    vocab_size: int = 256000

    def __post_init__(self, **kwargs):
        """Normalize the encoder/decoder sub-configs and propagate the shared settings.

        ``encoder`` / ``decoder`` may each be a dict (converted to ``T5GemmaModuleConfig``),
        ``None`` (a default ``T5GemmaModuleConfig`` is created) or a config instance.
        Afterwards:

        * the top-level ``dropout_rate`` / ``attention_dropout`` are copied into both sub-configs;
        * the encoder is marked ``is_decoder=False``; the decoder is marked ``is_decoder=True``,
          ``use_cache=True`` and gets ``cross_attention_hidden_size = encoder.hidden_size``;
        * ``initializer_range`` and any special token id not given in ``kwargs`` default to the
          decoder's values before delegating to the parent ``__post_init__``.
        """
        import copy

        if isinstance(self.encoder, dict):
            self.encoder = T5GemmaModuleConfig(**self.encoder)
        elif self.encoder is None:
            self.encoder = T5GemmaModuleConfig()

        if isinstance(self.decoder, dict):
            self.decoder = T5GemmaModuleConfig(**self.decoder)
        elif self.decoder is None:
            self.decoder = T5GemmaModuleConfig()

        if self.decoder is self.encoder:
            # The same instance was passed for both stacks. Without a copy, the decoder-specific
            # overrides below (notably `is_decoder = True`, which `T5GemmaSelfAttention` turns
            # into `is_causal`) would silently make the encoder causal too.
            self.decoder = copy.deepcopy(self.encoder)

        self.encoder.is_decoder = False
        self.encoder.dropout_rate = self.dropout_rate
        self.encoder.attention_dropout = self.attention_dropout

        self.decoder.is_decoder = True
        self.decoder.use_cache = True
        self.decoder.dropout_rate = self.dropout_rate
        self.decoder.attention_dropout = self.attention_dropout
        # Cross-attention key/value projections consume encoder outputs.
        self.decoder.cross_attention_hidden_size = self.encoder.hidden_size

        self.initializer_range = kwargs.pop("initializer_range", self.decoder.initializer_range)

        for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id"]:
            if special_token_key not in kwargs:
                kwargs[special_token_key] = getattr(self.decoder, special_token_key)

        super().__post_init__(**kwargs)
class T5GemmaRMSNorm(Gemma2RMSNorm):
    """RMS normalization layer; all behaviour is inherited unchanged from ``Gemma2RMSNorm``."""
class T5GemmaMLP(Gemma2MLP):
    """Gemma2 gated MLP with dropout applied to the gated activations before ``down_proj``."""

    def __init__(self, config):
        super().__init__(config)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        # Gate branch first, then the up branch, exactly as the parent projections are laid out.
        gated = self.act_fn(self.gate_proj(x))
        projected_up = self.up_proj(x)
        return self.down_proj(self.dropout(gated * projected_up))
class T5GemmaRotaryEmbedding(Gemma2RotaryEmbedding):
    """Rotary position embedding; all behaviour is inherited unchanged from ``Gemma2RotaryEmbedding``."""
class T5GemmaSelfAttention(Gemma2Attention):
    """Gemma2 attention used as self-attention in both stacks; causality follows ``config.is_decoder``."""

    def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Flash attention reads `is_causal`: decoder self-attention is causal, encoder self-attention is not.
        decoder_side = config.is_decoder
        self.is_causal = decoder_side
class T5GemmaCrossAttention(Gemma2Attention):
    """Decoder-to-encoder attention.

    Queries are projected from the decoder hidden states; keys and values are projected from
    the encoder hidden states, whose feature size is ``config.cross_attention_hidden_size``.
    Attention is non-causal and is always run without a sliding window.
    """

    def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Attributes set by Gemma2Attention that do not apply to cross-attention.
        del self.sliding_window
        del self.layer_type
        # Every decoder position may attend to every (unmasked) encoder position.
        self.is_causal = False

        if config.cross_attention_hidden_size is None:
            raise ValueError("Cross-attention needs cross_attention_hidden_size to be specified.")

        # Override the key/value projections so they take encoder-sized inputs.
        self.k_proj = nn.Linear(
            config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None,
        encoder_hidden_states: torch.Tensor | None,
        past_key_values: EncoderDecoderCache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attend from decoder states to encoder states.

        Args:
            hidden_states: decoder hidden states. Queries are reshaped to
                ``(*hidden_states.shape[:-1], -1, head_dim)`` and then transposed on dims 1 and 2.
            attention_mask: mask over encoder positions, passed unchanged to the attention backend.
            encoder_hidden_states: encoder outputs used to build keys and values. Required.
            past_key_values: encoder-decoder cache. The projected encoder keys/values are written
                to its ``cross_attention_cache`` once per layer (tracked through ``is_updated``)
                and read back on later calls instead of being re-projected.
            **kwargs: forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``; ``attn_weights`` may be ``None`` depending on the
            attention backend.

        Raises:
            ValueError: if ``encoder_hidden_states`` is ``None``.
        """
        if encoder_hidden_states is None:
            raise ValueError("Encoder hidden state is required for cross attention.")

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if past_key_values is not None:
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            curr_past_key_values = past_key_values.cross_attention_cache

        # `is_updated` is only read when a cache exists (short-circuit on the first operand).
        if past_key_values is None or not is_updated:
            encoder_input_shape = encoder_hidden_states.shape[:-1]
            encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim)
            key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
            value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)

            if past_key_values is not None:
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                past_key_values.is_updated[self.layer_idx] = True
        else:
            # Encoder keys/values were cached on an earlier call for this layer.
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=None,
            softcap=self.attn_logit_softcapping,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class T5GemmaEncoderLayer(GradientCheckpointingLayer):
    """Encoder sub-layer: self-attention followed by the MLP.

    Each of the two sub-blocks is applied as
    ``x = x + dropout(post_norm(sublayer(pre_norm(x))))``.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]

        self.self_attn = T5GemmaSelfAttention(
            config=config,
            layer_idx=layer_idx,
        )
        self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.mlp = T5GemmaMLP(config)
        self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Apply the self-attention and MLP residual sub-blocks.

        Args:
            hidden_states: encoder hidden states.
            position_embeddings: rotary embeddings forwarded to self-attention.
            attention_mask: mask forwarded to self-attention.
            position_ids: forwarded to self-attention.
            **kwargs: forwarded to self-attention.

        Returns:
            The updated hidden states (a tensor, not a tuple).
        """
        residual = hidden_states
        hidden_states = self.pre_self_attn_layernorm(hidden_states)
        # The encoder never caches keys/values.
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            **kwargs,
        )
        hidden_states = self.post_self_attn_layernorm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)
        return hidden_states
class T5GemmaDecoderLayer(GradientCheckpointingLayer):
    """Decoder sub-layer: self-attention, then cross-attention over the encoder outputs, then the MLP.

    Each sub-block is applied as ``x = x + dropout(post_norm(sublayer(pre_norm(x))))``.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]

        self.self_attn = T5GemmaSelfAttention(
            config=config,
            layer_idx=layer_idx,
        )
        self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.mlp = T5GemmaMLP(config)
        self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dropout = nn.Dropout(config.dropout_rate)

        # Extra cross-attention sub-block compared with the encoder layer.
        self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx)
        self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        use_cache: bool | None = False,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Apply the self-attention, cross-attention and MLP residual sub-blocks.

        Args:
            hidden_states: decoder hidden states.
            position_embeddings: rotary embeddings forwarded to self-attention.
            attention_mask: mask forwarded to self-attention.
            position_ids: forwarded to self-attention.
            past_key_values: encoder-decoder cache. Its ``self_attention_cache`` is given to
                self-attention, while the whole cache is given to cross-attention.
            use_cache: forwarded to both attention modules.
            encoder_hidden_states: encoder outputs attended to by cross-attention.
            encoder_attention_mask: mask used by cross-attention.
            **kwargs: forwarded to both attention modules.

        Returns:
            The updated hidden states.
        """
        residual = hidden_states
        hidden_states = self.pre_self_attn_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = self.post_self_attn_layernorm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        residual = hidden_states
        hidden_states = self.pre_cross_attn_layernorm(hidden_states)
        # Cross-attention needs the full EncoderDecoderCache (it reads `is_updated` and
        # `cross_attention_cache`).
        hidden_states, _ = self.cross_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = self.post_cross_attn_layernorm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)
        return hidden_states
class T5GemmaClassificationHead(nn.Module):
    """Sentence-level classification head: dropout followed by a linear map to ``num_labels``."""

    def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0):
        super().__init__()
        # Registration order (dropout, then out_proj) is preserved.
        self.dropout = nn.Dropout(classifier_dropout_rate)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``out_proj(dropout(hidden_states))``; the last dimension must be ``hidden_size``."""
        return self.out_proj(self.dropout(hidden_states))
class T5GemmaLMHead(nn.Module):
    """Language-modeling head: a single linear map from hidden states to vocabulary logits."""

    def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False):
        super().__init__()
        # `lm_head.out_proj` is the name referenced by the generation model's tied-weights,
        # tensor-parallel and pipeline-parallel plans.
        self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` (last dim ``hidden_size``) to logits (last dim ``vocab_size``)."""
        return self.out_proj(hidden_states)
@auto_docstring
class T5GemmaPreTrainedModel(Gemma2PreTrainedModel):
    # Shared base for all T5Gemma models: declares the config class, the attribute prefix of the
    # base model, non-splittable layer types, which submodule outputs can be recorded, and the
    # T5Gemma-specific weight initialization.
    config: T5GemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"]
    # Output index 1 of the attention modules is their attention-weights return value.
    _can_record_outputs = {
        "hidden_states": T5GemmaDecoderLayer,
        "attentions": [
            OutputRecorder(T5GemmaSelfAttention, index=1, layer_name="self_attn"),
            OutputRecorder(T5GemmaSelfAttention, index=1, layer_name="cross_attn"),
            OutputRecorder(T5GemmaCrossAttention, index=1, layer_name="cross_attn"),
        ],
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize ``module`` with the generic scheme, then apply T5Gemma-specific overrides.

        Linear heads are drawn from ``N(0, (initializer_range * out_features**-0.5)**2)``: the
        scale uses ``weight.shape[0]``, which is ``out_features`` for an ``nn.Linear``. The LM
        head is only re-initialized when word embeddings are not tied. RMSNorm weights are zeroed.
        """
        # TODO: support initialization for encoders and decoders separately(?)
        PreTrainedModel._init_weights(self, module)
        std = self.config.initializer_range
        if isinstance(module, T5GemmaClassificationHead):
            scale = module.out_proj.weight.shape[0] ** -0.5
            init.normal_(module.out_proj.weight, mean=0.0, std=std * scale)
            if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
                init.zeros_(module.out_proj.bias)
        elif isinstance(module, T5GemmaLMHead):
            # When tied, the weight is shared with the embedding table and must not be overwritten.
            if not self.config.tie_word_embeddings:
                scale = module.out_proj.weight.shape[0] ** -0.5
                init.normal_(module.out_proj.weight, mean=0.0, std=std * scale)
        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
        elif "RMSNorm" in module.__class__.__name__:
            init.zeros_(module.weight)

    def _shift_right(self, input_ids):
        """
        Shifts input_ids to the right, prepends the decoder_start_token_id, and handles
        pad_token_id replacement for labels that were -100.

        This is a common preparation step for decoder inputs in sequence-to-sequence models.
        The decoder start token is ``config.decoder.bos_token_id``.

        Raises:
            ValueError: if ``config.decoder.bos_token_id`` or ``config.decoder.pad_token_id`` is None.
        """
        decoder_start_token_id = self.config.decoder.bos_token_id
        pad_token_id = self.config.decoder.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ")

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.decoder.pad_token_id has to be defined.")

        # Is this T5 specific?
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
- def make_default_2d_attention_mask(
- token_ids: torch.LongTensor | None,
- hidden_states: torch.Tensor,
- pad_token_id: int | None,
- ) -> torch.Tensor:
- """Construct the default attention mask."""
- if token_ids is not None:
- if pad_token_id is None:
- raise ValueError("`pad_token_id` is required for padding information.")
- attention_mask = (token_ids != pad_token_id).to(hidden_states.device, torch.long)
- else:
- attention_mask = torch.ones(
- (hidden_states.shape[0], hidden_states.shape[1]), device=hidden_states.device, dtype=torch.long
- )
- return attention_mask
class T5GemmaEncoder(T5GemmaPreTrainedModel):
    """Encoder stack: token embedding (scaled by sqrt(hidden_size)), bidirectional masks,
    ``num_hidden_layers`` encoder layers, then a final RMSNorm and dropout."""

    _can_record_outputs = {
        "attentions": T5GemmaSelfAttention,
        "hidden_states": T5GemmaEncoderLayer,
    }

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.layers = nn.ModuleList(
            [T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.dropout = nn.Dropout(config.dropout_rate)
        self.rotary_emb = T5GemmaRotaryEmbedding(config=config)

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        """Encode ``input_ids`` or ``inputs_embeds`` (exactly one must be given).

        ``attention_mask`` may be a 2D padding mask (built from ``pad_token_id`` when omitted) or
        a precomputed dict keyed by layer type (``"full_attention"`` / ``"sliding_attention"``),
        which is used as-is.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present
        kwargs.pop("past_key_values", None)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            position_ids = position_ids.unsqueeze(0)

        if attention_mask is None:
            attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)

        # A dict means the caller already built the per-layer-type masks.
        if not isinstance(self_attn_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
            }
            self_attn_mask_mapping = {
                "full_attention": create_bidirectional_mask(**mask_kwargs),
                "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds
        # Embeddings are scaled by sqrt(hidden_size), cast to the activations' dtype.
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
        hidden_states = self.dropout(hidden_states)

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
            # Each layer gets the mask matching its configured attention type.
            hidden_states = layer_module(
                hidden_states,
                position_embeddings,
                self_attn_mask_mapping[self.config.layer_types[i]],
                position_ids,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
        )
class T5GemmaDecoder(T5GemmaPreTrainedModel):
    """Decoder stack: token embedding (scaled by sqrt(hidden_size)), causal self-attention masks,
    bidirectional cross-attention mask, ``num_hidden_layers`` decoder layers, then a final
    RMSNorm and dropout."""

    _can_record_outputs = {
        "attentions": OutputRecorder(T5GemmaSelfAttention, index=1),
        "cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1),
        "hidden_states": T5GemmaDecoderLayer,
    }

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.layers = nn.ModuleList(
            [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.dropout = nn.Dropout(config.dropout_rate)
        self.rotary_emb = T5GemmaRotaryEmbedding(config=config)

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
        """Decode ``input_ids`` or ``inputs_embeds`` (exactly one) while attending to ``encoder_hidden_states``.

        When not training, ``use_cache`` is truthy and no cache is passed, a fresh
        ``EncoderDecoderCache`` is created and returned in ``past_key_values``. Both
        ``attention_mask`` and ``encoder_attention_mask`` may alternatively be precomputed dicts
        of masks, which are used as-is.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given, or if
                ``encoder_hidden_states`` is missing.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states` must be given in decoder")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if not self.training and use_cache and past_key_values is None:
            # We do not pass the config to the cross attn cache to avoid initializing SWA
            # --> we use full attention between our cross attentions
            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())
        if position_ids is None:
            # Continue numbering after the tokens already held in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # A padding mask is only derived from the ids when no cache is in use.
        if attention_mask is None and past_key_values is None:
            attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)

        if not isinstance(self_attn_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
                "position_ids": position_ids,
            }
            self_attn_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # Cross-attention always uses a single full (bidirectional) mask over encoder positions.
        if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
            cross_attn_mask_mapping = {
                "full_attention": create_bidirectional_mask(
                    config=self.config,
                    inputs_embeds=inputs_embeds,
                    attention_mask=encoder_attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                )
            }

        hidden_states = inputs_embeds
        # Embeddings are scaled by sqrt(hidden_size), cast to the activations' dtype.
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
        hidden_states = self.dropout(hidden_states)

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
            hidden_states = layer_module(
                hidden_states,
                position_embeddings,
                self_attn_mask_mapping[self.config.layer_types[i]],
                position_ids,
                past_key_values,
                use_cache,
                encoder_hidden_states,
                cross_attn_mask_mapping["full_attention"],
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
@auto_docstring
class T5GemmaModel(T5GemmaPreTrainedModel):
    # Encoder-decoder backbone: `T5GemmaEncoder(config.encoder)` + `T5GemmaDecoder(config.decoder)`.
    # Requires `config.is_encoder_decoder` to be True.
    def __init__(self, config: T5GemmaConfig):
        super().__init__(config)

        if not config.is_encoder_decoder:
            raise ValueError("T5GemmaModel only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.")

        self.encoder = T5GemmaEncoder(config.encoder)
        self.decoder = T5GemmaDecoder(config.decoder)

        self.post_init()

    # NOTE(review): both accessors only touch the encoder's embedding table; the decoder keeps its
    # own `embed_tokens` — confirm this is intended (e.g. for `resize_token_embeddings`).
    def get_input_embeddings(self):
        return self.encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.encoder.set_input_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        encoder_outputs: BaseModelOutput | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.Tensor | None = None,
        decoder_inputs_embeds: torch.Tensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Seq2SeqModelOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        """
        # NOTE(review): `n_positions` above is likely `max_position_embeddings` for these
        # Gemma2-based sub-configs — confirm and update the docstring.

        # The encoder is skipped when precomputed `encoder_outputs` are supplied (e.g. during generation).
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                **kwargs,
            )

        encoder_hidden_states = encoder_outputs.last_hidden_state

        # The encoder padding mask doubles as the decoder's cross-attention mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            **kwargs,
        )

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states
            if kwargs.get("output_hidden_states", False)
            else (decoder_outputs.last_hidden_state,),
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
@auto_docstring
class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
    # Encoder-only wrapper around `T5GemmaEncoder(config.encoder)`; the config must have
    # `is_encoder_decoder=False`.

    def __init__(self, config: T5GemmaConfig):
        super().__init__(config)

        wants_decoder = config.is_encoder_decoder
        if wants_decoder:
            raise ValueError("T5GemmaEncoderModel only supports encoder-only model. Use `T5GemmaModel` instead.")

        self.encoder = T5GemmaEncoder(config.encoder)
        self.post_init()

    def get_input_embeddings(self):
        # The embedding table lives on the wrapped encoder.
        encoder = self.encoder
        return encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        encoder = self.encoder
        return encoder.set_input_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Pure pass-through: the encoder's forward already builds the `BaseModelOutput`.
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
    # Seq2seq language model: T5GemmaModel backbone + LM head projecting decoder states to the vocabulary.
    # The LM head projection is declared as tied to the decoder's token embedding.
    _tied_weights_keys = {"lm_head.out_proj.weight": "model.decoder.embed_tokens.weight"}
    _tp_plan = {"lm_head.out_proj": "colwise_gather_output"}
    _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])}

    def __init__(self, config: T5GemmaConfig):
        # Generation always needs the full encoder-decoder stack, so the flag is forced on the
        # (caller-owned) config before the base class initialises.
        config.is_encoder_decoder = True
        super().__init__(config)

        decoder_cfg = config.decoder
        self.model = T5GemmaModel(config)
        self.vocab_size = decoder_cfg.vocab_size
        self.lm_head = T5GemmaLMHead(decoder_cfg.hidden_size, self.vocab_size)
        # Decoder inputs are produced by right-shifting the labels (see `forward`), so logits and
        # labels are already aligned and the loss does not shift again.
        self.loss_type = "ForMaskedLM"

        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.out_proj

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.out_proj = new_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        encoder_outputs: BaseModelOutput | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        # Teacher forcing: when no decoder inputs were given, build them by right-shifting the labels.
        no_decoder_inputs = decoder_input_ids is None and decoder_inputs_embeds is None
        if labels is not None and no_decoder_inputs:
            decoder_input_ids = self._shift_right(labels)

        seq2seq_outputs: Seq2SeqModelOutput = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        # Project only the requested positions; an int keeps the trailing `logits_to_keep` steps
        # (0 keeps all), anything else is used directly as an index.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(seq2seq_outputs.last_hidden_state[:, slice_indices, :])

        # Optional tanh soft-capping: logits -> cap * tanh(logits / cap).
        softcap = self.get_decoder().config.final_logit_softcapping
        if softcap is not None:
            logits = torch.tanh(logits / softcap) * softcap

        # Inputs were already right-shifted, so the loss compares logits and labels position by position.
        loss = None if labels is None else self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=seq2seq_outputs.past_key_values,
            decoder_hidden_states=seq2seq_outputs.decoder_hidden_states,
            decoder_attentions=seq2seq_outputs.decoder_attentions,
            cross_attentions=seq2seq_outputs.cross_attentions,
            encoder_last_hidden_state=seq2seq_outputs.encoder_last_hidden_state,
            encoder_hidden_states=seq2seq_outputs.encoder_hidden_states,
            encoder_attentions=seq2seq_outputs.encoder_attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Same right-shift used for teacher forcing in `forward`.
        return self._shift_right(labels)
@auto_docstring
class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
    # Sequence-level classifier on top of T5Gemma. In encoder-decoder mode the decoder's final hidden
    # states are scored; in encoder-only mode the encoder's are. One logit vector per sequence is then
    # selected at the last non-padding position.

    def __init__(self, config: T5GemmaConfig, is_encoder_decoder: bool | None = None):
        r"""
        is_encoder_decoder (`bool`, *optional*):
            Whether to use the encoder-decoder model for sequence classification. When set to False, only the
            encoder is used. Defaults to `config.is_encoder_decoder`.
        """
        # An explicit argument overrides (and mutates) the flag stored on `config`.
        if is_encoder_decoder is not None:
            config.is_encoder_decoder = is_encoder_decoder
        super().__init__(config)
        self.num_labels = config.num_labels

        if config.is_encoder_decoder:
            self.model = T5GemmaModel(config)
        else:
            self.model = T5GemmaEncoderModel(config)

        # The head consumes whichever stack produces the final hidden states.
        hidden_size = config.encoder.hidden_size
        if config.is_encoder_decoder:
            hidden_size = config.decoder.hidden_size

        classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
        self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.Tensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        encoder_outputs: BaseModelOutput | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequenceClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
            )

        # Following T5, we automatically create decoder_input_ids from input_ids if no decoder inputs are provided.
        if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        if self.config.is_encoder_decoder:
            outputs: Seq2SeqModelOutput = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_position_ids=decoder_position_ids,
                encoder_outputs=encoder_outputs,
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=False,
                **kwargs,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.decoder_hidden_states
            attentions = outputs.decoder_attentions
        else:
            outputs: BaseModelOutput = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                **kwargs,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.hidden_states
            attentions = outputs.attentions

        # Per-position scores, indexed below as [batch, position].
        logits = self.score(last_hidden_state)

        # Take the batch size from the logits rather than from `input_ids` / `inputs_embeds`: both can be
        # None when the caller supplies precomputed `encoder_outputs` together with decoder inputs.
        batch_size = logits.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
            if self.config.is_encoder_decoder:
                last_non_pad_token += 1  # due to the right shift.
                # Clamp to the decoder length actually scored. `decoder_input_ids` may be None here when the
                # caller passed `decoder_inputs_embeds`, so use the logits' sequence dimension instead.
                last_non_pad_token = torch.clamp(last_non_pad_token, max=logits.shape[1] - 1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )
@auto_docstring
class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
    # Per-token classifier on top of T5Gemma. In encoder-decoder mode the decoder's final hidden states
    # (computed over right-shifted `input_ids` unless decoder inputs are given) are scored; in
    # encoder-only mode the encoder's are.

    def __init__(self, config: T5GemmaConfig, is_encoder_decoder: bool | None = None):
        r"""
        is_encoder_decoder (`bool`, *optional*):
            Whether to use the encoder-decoder model for token classification. When set to False, only the
            encoder is used. Defaults to `config.is_encoder_decoder`.
        """
        # An explicit argument overrides (and mutates) the flag stored on `config`.
        if is_encoder_decoder is not None:
            config.is_encoder_decoder = is_encoder_decoder
        super().__init__(config)
        self.num_labels = config.num_labels

        if config.is_encoder_decoder:
            self.model = T5GemmaModel(config)
        else:
            self.model = T5GemmaEncoderModel(config)

        # The head consumes whichever stack produces the final hidden states.
        hidden_size = config.encoder.hidden_size
        if config.is_encoder_decoder:
            hidden_size = config.decoder.hidden_size

        classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
        self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.Tensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        encoder_outputs: BaseModelOutput | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TokenClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss, one label per position of the scored sequence.
            Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
            )

        # Following T5, decoder inputs default to the right-shifted `input_ids`.
        if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        if self.config.is_encoder_decoder:
            outputs: Seq2SeqModelOutput = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_position_ids=decoder_position_ids,
                encoder_outputs=encoder_outputs,
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=False,
                **kwargs,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.decoder_hidden_states
            attentions = outputs.decoder_attentions
        else:
            outputs: BaseModelOutput = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                **kwargs,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.hidden_states
            attentions = outputs.attentions

        # One score vector per position of the final hidden states.
        logits = self.score(last_hidden_state)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    # Configuration classes
    "T5GemmaConfig",
    "T5GemmaModuleConfig",
    # Model classes
    "T5GemmaForConditionalGeneration",
    "T5GemmaModel",
    "T5GemmaEncoderModel",
    "T5GemmaPreTrainedModel",
    "T5GemmaForSequenceClassification",
    "T5GemmaForTokenClassification",
]
|