modeling_t5gemma2.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/t5gemma2/modular_t5gemma2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_t5gemma2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import copy
  22. from collections.abc import Callable
  23. from typing import Optional
  24. import torch
  25. import torch.nn as nn
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
  29. from ...generation import GenerationConfig, GenerationMixin, GenerationMode
  30. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  31. from ...masking_utils import create_bidirectional_mask, create_causal_mask, create_sliding_window_causal_mask
  32. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  33. from ...modeling_layers import GradientCheckpointingLayer
  34. from ...modeling_outputs import (
  35. BaseModelOutput,
  36. BaseModelOutputWithPastAndCrossAttentions,
  37. BaseModelOutputWithPooling,
  38. Seq2SeqLMOutput,
  39. Seq2SeqModelOutput,
  40. SequenceClassifierOutput,
  41. TokenClassifierOutput,
  42. )
  43. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  44. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  45. from ...processing_utils import Unpack
  46. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
  47. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  48. from ...utils.output_capturing import OutputRecorder, capture_outputs
  49. from ..auto import AutoModel
  50. from .configuration_t5gemma2 import T5Gemma2Config, T5Gemma2DecoderConfig, T5Gemma2EncoderConfig, T5Gemma2TextConfig
  51. class T5Gemma2RMSNorm(nn.Module):
  52. def __init__(self, dim: int, eps: float = 1e-6):
  53. super().__init__()
  54. self.eps = eps
  55. self.weight = nn.Parameter(torch.zeros(dim))
  56. def _norm(self, x):
  57. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  58. def forward(self, x):
  59. output = self._norm(x.float())
  60. # Llama does x.to(float16) * w whilst T5Gemma2 is (x * w).to(float16)
  61. # See https://github.com/huggingface/transformers/pull/29402
  62. output = output * (1.0 + self.weight.float())
  63. return output.type_as(x)
  64. def extra_repr(self):
  65. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  66. class T5Gemma2MLP(nn.Module):
  67. def __init__(self, config: T5Gemma2TextConfig):
  68. super().__init__()
  69. self.config = config
  70. self.hidden_size = config.hidden_size
  71. self.intermediate_size = config.intermediate_size
  72. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  73. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  74. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  75. self.act_fn = ACT2FN[config.hidden_activation]
  76. self.dropout = nn.Dropout(config.dropout_rate)
  77. def forward(self, x):
  78. hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
  79. hidden_states = self.dropout(hidden_states)
  80. down_proj = self.down_proj(hidden_states)
  81. return down_proj
class T5Gemma2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) tables, one set per attention layer type.

    For every distinct entry of ``config.layer_types`` whose
    ``config.rope_parameters[layer_type]`` is not ``None``, this module registers:
    - a non-persistent buffer ``{layer_type}_inv_freq``;
    - a non-persistent clone ``{layer_type}_original_inv_freq``;
    - an attribute ``{layer_type}_attention_scaling``.

    ``forward`` selects which set to use through its ``layer_type`` argument.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: T5Gemma2TextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # De-duplicated layer types; one RoPE table set is built per distinct type.
        self.layer_types = list(set(config.layer_types))
        # Maps layer_type -> rope_type string (only for types that have RoPE parameters).
        self.rope_type = {}
        for layer_type in self.layer_types:
            rope_params = self.config.rope_parameters[layer_type]
            if rope_params is None:
                # No RoPE configured for this layer type: nothing is registered for it.
                continue
            self.rope_type[layer_type] = rope_params["rope_type"]
            rope_init_fn: Callable = self.compute_default_rope_parameters
            if self.rope_type[layer_type] != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
            self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
            # Untouched copy kept next to the working buffer. Presumably this is what
            # `dynamic_rope_update` restores from when it rescales frequencies — TODO confirm.
            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
            setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(
        config: T5Gemma2TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
            layer_type (`str`, *optional*):
                The current layer type if the model has different RoPE parameters per type.
                Should not be used unless `config.layer_types is not None`
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        # Each layer type carries its own `rope_theta` base in `config.rope_parameters`.
        base = config.rope_parameters[layer_type]["rope_theta"]
        # Use `head_dim` when present and truthy, else hidden_size // num_attention_heads.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: 1 / base^(2i / dim) for i in [0, dim/2).
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids, layer_type=None):
        """Return ``(cos, sin)`` RoPE tables for ``position_ids`` using the ``layer_type`` buffers.

        Args:
            x: Only its device and dtype are used (the output is cast to ``x.dtype``).
            position_ids: 2-D integer positions, indexed as ``position_ids[:, None, :]`` —
                presumably ``(batch, seq_len)``.
            layer_type: Key selecting ``{layer_type}_inv_freq`` / ``{layer_type}_attention_scaling``.

        Returns:
            ``(cos, sin)``, each of shape ``(batch, seq_len, 2 * len(inv_freq))``. Both are
            multiplied by the layer type's attention scaling and cast to ``x.dtype``.
        """
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled below; "mps" (and non-string device types) fall back to "cpu".
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the frequencies so they line up with the two halves used by rotate_half.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  149. def rotate_half(x):
  150. """Rotates half the hidden dims of the input."""
  151. x1 = x[..., : x.shape[-1] // 2]
  152. x2 = x[..., x.shape[-1] // 2 :]
  153. return torch.cat((-x2, x1), dim=-1)
  154. @use_kernel_func_from_hub("rotary_pos_emb")
  155. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  156. """Applies Rotary Position Embedding to the query and key tensors.
  157. Args:
  158. q (`torch.Tensor`): The query tensor.
  159. k (`torch.Tensor`): The key tensor.
  160. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  161. sin (`torch.Tensor`): The sine part of the rotary embedding.
  162. unsqueeze_dim (`int`, *optional*, defaults to 1):
  163. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  164. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  165. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  166. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  167. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  168. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  169. Returns:
  170. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  171. """
  172. cos = cos.unsqueeze(unsqueeze_dim)
  173. sin = sin.unsqueeze(unsqueeze_dim)
  174. q_embed = (q * cos) + (rotate_half(q) * sin)
  175. k_embed = (k * cos) + (rotate_half(k) * sin)
  176. return q_embed, k_embed
  177. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  178. """
  179. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  180. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  181. """
  182. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  183. if n_rep == 1:
  184. return hidden_states
  185. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  186. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  187. def eager_attention_forward(
  188. module: nn.Module,
  189. query: torch.Tensor,
  190. key: torch.Tensor,
  191. value: torch.Tensor,
  192. attention_mask: torch.Tensor | None,
  193. dropout: float | int = 0.0,
  194. scaling: float | None = None,
  195. softcap: float | None = None,
  196. **kwargs,
  197. ) -> tuple[torch.Tensor, torch.Tensor]:
  198. if scaling is None:
  199. scaling = module.head_dim**-0.5
  200. key_states = repeat_kv(key, module.num_key_value_groups)
  201. value_states = repeat_kv(value, module.num_key_value_groups)
  202. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  203. if softcap is not None:
  204. attn_weights = attn_weights / softcap
  205. attn_weights = torch.tanh(attn_weights)
  206. attn_weights = attn_weights * softcap
  207. if attention_mask is not None:
  208. attn_weights = attn_weights + attention_mask
  209. # upcast attention to fp32
  210. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  211. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  212. attn_output = torch.matmul(attn_weights, value_states)
  213. attn_output = attn_output.transpose(1, 2).contiguous()
  214. return attn_output, attn_weights
  215. @use_kernelized_func(apply_rotary_pos_emb)
  216. class T5Gemma2SelfAttention(nn.Module):
  217. """Multi-headed attention from 'Attention Is All You Need' paper"""
  218. def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
  219. super().__init__()
  220. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  221. self.config = config
  222. self.layer_idx = layer_idx
  223. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  224. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  225. self.scaling = config.query_pre_attn_scalar**-0.5
  226. self.attention_dropout = self.config.attention_dropout
  227. self.is_causal = False # Only used by the encoder
  228. self.q_proj = nn.Linear(
  229. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  230. )
  231. self.k_proj = nn.Linear(
  232. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  233. )
  234. self.v_proj = nn.Linear(
  235. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  236. )
  237. self.o_proj = nn.Linear(
  238. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  239. )
  240. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  241. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  242. self.is_sliding = self.layer_type == "sliding_attention"
  243. self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  244. self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  245. def forward(
  246. self,
  247. hidden_states: torch.Tensor,
  248. position_embeddings: torch.Tensor = None,
  249. attention_mask: torch.Tensor | None = None,
  250. past_key_values: Cache | None = None,
  251. **kwargs: Unpack[TransformersKwargs],
  252. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  253. input_shape = hidden_states.shape[:-1]
  254. hidden_shape = (*input_shape, -1, self.head_dim)
  255. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  256. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  257. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  258. query_states = self.q_norm(query_states)
  259. key_states = self.k_norm(key_states)
  260. cos, sin = position_embeddings
  261. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  262. if past_key_values is not None:
  263. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  264. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  265. self.config._attn_implementation, eager_attention_forward
  266. )
  267. attn_output, attn_weights = attention_interface(
  268. self,
  269. query_states,
  270. key_states,
  271. value_states,
  272. attention_mask,
  273. dropout=self.attention_dropout if self.training else 0.0,
  274. scaling=self.scaling,
  275. sliding_window=self.sliding_window,
  276. **kwargs,
  277. )
  278. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  279. attn_output = self.o_proj(attn_output)
  280. return attn_output, attn_weights
  281. @use_kernelized_func(apply_rotary_pos_emb)
  282. class T5Gemma2MergedAttention(nn.Module):
  283. """Merged self-attention and cross-attention for decoder."""
  284. def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
  285. super().__init__()
  286. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  287. self.config = config
  288. self.layer_idx = layer_idx
  289. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  290. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  291. self.scaling = config.query_pre_attn_scalar**-0.5
  292. self.attention_dropout = self.config.attention_dropout
  293. self.is_causal = False # Fused causal and encoder mask
  294. self.q_proj = nn.Linear(
  295. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  296. )
  297. self.k_proj = nn.Linear(
  298. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  299. )
  300. self.v_proj = nn.Linear(
  301. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  302. )
  303. self.o_proj = nn.Linear(
  304. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  305. )
  306. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  307. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  308. self.is_sliding = self.layer_type == "sliding_attention"
  309. self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  310. self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  311. def forward(
  312. self,
  313. # decoder self-attention inputs
  314. hidden_states: torch.Tensor,
  315. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  316. merged_attention_mask: torch.Tensor | None,
  317. # cross-attention inputs
  318. encoder_hidden_states: torch.Tensor,
  319. # cache inputs
  320. past_key_values: EncoderDecoderCache | None = None,
  321. # others
  322. **kwargs: Unpack[FlashAttentionKwargs],
  323. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  324. # attention shapes.
  325. input_shape = hidden_states.shape[:-1]
  326. hidden_shape = (*input_shape, -1, self.head_dim)
  327. cross_input_shape = encoder_hidden_states.shape[:-1]
  328. cross_hidden_shape = (*cross_input_shape, -1, self.head_dim)
  329. # self-attention.
  330. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  331. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  332. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  333. query_states = self.q_norm(query_states)
  334. key_states = self.k_norm(key_states)
  335. cos, sin = position_embeddings
  336. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  337. if past_key_values is not None:
  338. # self-attention.
  339. self_attention_cache = past_key_values.self_attention_cache
  340. key_states, value_states = self_attention_cache.update(key_states, value_states, self.layer_idx)
  341. # cross-attention.
  342. is_updated = past_key_values.is_updated.get(self.layer_idx)
  343. cross_attention_cache = past_key_values.cross_attention_cache
  344. if past_key_values is None or not is_updated:
  345. cross_key_states = self.k_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2)
  346. cross_value_states = self.v_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2)
  347. cross_key_states = self.k_norm(cross_key_states)
  348. if past_key_values is not None:
  349. cross_key_states, cross_value_states = cross_attention_cache.update(
  350. cross_key_states, cross_value_states, self.layer_idx
  351. )
  352. past_key_values.is_updated[self.layer_idx] = True
  353. else:
  354. cross_key_states = cross_attention_cache.layers[self.layer_idx].keys
  355. cross_value_states = cross_attention_cache.layers[self.layer_idx].values
  356. # merged attention.
  357. query_states = query_states
  358. cross_key_size = cross_input_shape[1]
  359. key_states = torch.cat([key_states, cross_key_states], dim=2)
  360. value_states = torch.cat([value_states, cross_value_states], dim=2)
  361. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  362. self.config._attn_implementation, eager_attention_forward
  363. )
  364. attn_output, attn_weights = attention_interface(
  365. self,
  366. query_states,
  367. key_states,
  368. value_states,
  369. merged_attention_mask,
  370. dropout=self.attention_dropout if self.training else 0.0,
  371. scaling=self.scaling,
  372. **kwargs,
  373. )
  374. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  375. attn_output = self.o_proj(attn_output)
  376. # decompose merged attention weights into self & cross attention weights
  377. if attn_weights is not None:
  378. self_attn_weights = attn_weights[..., :-cross_key_size]
  379. cross_attn_weights = attn_weights[..., -cross_key_size:]
  380. else:
  381. self_attn_weights, cross_attn_weights = None, None
  382. return attn_output, self_attn_weights, cross_attn_weights
  383. class T5Gemma2EncoderLayer(GradientCheckpointingLayer):
  384. """Encoder sub-layer."""
  385. def __init__(self, config, layer_idx: int):
  386. super().__init__()
  387. self.hidden_size = config.hidden_size
  388. self.config = config
  389. self.layer_idx = layer_idx
  390. self.attention_type = config.layer_types[layer_idx]
  391. self.self_attn = T5Gemma2SelfAttention(
  392. config=config,
  393. layer_idx=layer_idx,
  394. )
  395. self.pre_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  396. self.post_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  397. self.mlp = T5Gemma2MLP(config)
  398. self.pre_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  399. self.post_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  400. self.dropout = nn.Dropout(config.dropout_rate)
  401. def forward(
  402. self,
  403. hidden_states: torch.Tensor,
  404. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  405. attention_mask: torch.Tensor | None = None,
  406. position_ids: torch.LongTensor | None = None,
  407. **kwargs,
  408. ) -> tuple[torch.FloatTensor,]:
  409. residual = hidden_states
  410. hidden_states = self.pre_self_attn_layernorm(hidden_states)
  411. hidden_states, _ = self.self_attn(
  412. hidden_states=hidden_states,
  413. position_embeddings=position_embeddings,
  414. attention_mask=attention_mask,
  415. position_ids=position_ids,
  416. past_key_values=None,
  417. **kwargs,
  418. )
  419. hidden_states = self.post_self_attn_layernorm(hidden_states)
  420. hidden_states = residual + self.dropout(hidden_states)
  421. residual = hidden_states
  422. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  423. hidden_states = self.mlp(hidden_states)
  424. hidden_states = self.post_feedforward_layernorm(hidden_states)
  425. hidden_states = residual + self.dropout(hidden_states)
  426. return hidden_states
  427. class T5Gemma2DecoderLayer(GradientCheckpointingLayer):
  428. """Decoder sub-layer: merged attention instead of vanilla self-attention."""
  429. def __init__(self, config, layer_idx: int):
  430. super().__init__()
  431. self.hidden_size = config.hidden_size
  432. self.config = config
  433. self.layer_idx = layer_idx
  434. self.attention_type = config.layer_types[layer_idx]
  435. # replace vanilla self-attention with merged attention to support joint cross-attention.
  436. self.self_attn = T5Gemma2MergedAttention(
  437. config=config,
  438. layer_idx=layer_idx,
  439. )
  440. self.pre_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  441. self.post_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  442. self.mlp = T5Gemma2MLP(config)
  443. self.pre_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  444. self.post_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  445. self.dropout = nn.Dropout(config.dropout_rate)
  446. def forward(
  447. self,
  448. hidden_states: torch.Tensor,
  449. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  450. merged_attention_mask: torch.Tensor | None = None,
  451. position_ids: torch.LongTensor | None = None,
  452. past_key_values: EncoderDecoderCache | None = None,
  453. use_cache: bool | None = False,
  454. encoder_hidden_states: torch.Tensor | None = None,
  455. **kwargs,
  456. ) -> torch.FloatTensor:
  457. residual = hidden_states
  458. hidden_states = self.pre_self_attn_layernorm(hidden_states)
  459. hidden_states, _, _ = self.self_attn(
  460. hidden_states=hidden_states,
  461. position_embeddings=position_embeddings,
  462. merged_attention_mask=merged_attention_mask,
  463. position_ids=position_ids,
  464. past_key_values=past_key_values,
  465. use_cache=use_cache,
  466. encoder_hidden_states=encoder_hidden_states,
  467. **kwargs,
  468. )
  469. hidden_states = self.post_self_attn_layernorm(hidden_states)
  470. hidden_states = residual + self.dropout(hidden_states)
  471. residual = hidden_states
  472. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  473. hidden_states = self.mlp(hidden_states)
  474. hidden_states = self.post_feedforward_layernorm(hidden_states)
  475. hidden_states = residual + self.dropout(hidden_states)
  476. return hidden_states
  477. class T5Gemma2LMHead(nn.Module):
  478. """Head for language modeling (generation) tasks."""
  479. def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False):
  480. super().__init__()
  481. self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias)
  482. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  483. logits = self.out_proj(hidden_states)
  484. return logits
  485. class T5Gemma2ClassificationHead(nn.Module):
  486. """Head for sentence-level classification tasks."""
  487. def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0):
  488. super().__init__()
  489. self.dropout = nn.Dropout(p=classifier_dropout_rate)
  490. self.out_proj = nn.Linear(hidden_size, num_labels)
  491. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  492. hidden_states = self.dropout(hidden_states)
  493. hidden_states = self.out_proj(hidden_states)
  494. return hidden_states
  495. class T5Gemma2MultiModalProjector(nn.Module):
  496. def __init__(self, config: T5Gemma2EncoderConfig):
  497. super().__init__()
  498. self.mm_input_projection_weight = nn.Parameter(
  499. torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
  500. )
  501. self.mm_soft_emb_norm = T5Gemma2RMSNorm(
  502. config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
  503. )
  504. self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
  505. self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
  506. self.kernel_size = self.patches_per_image // self.tokens_per_side
  507. self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
  508. def forward(self, vision_outputs: torch.Tensor):
  509. batch_size, _, hidden_size = vision_outputs.shape
  510. reshaped_vision_outputs = vision_outputs.transpose(1, 2)
  511. reshaped_vision_outputs = reshaped_vision_outputs.reshape(
  512. batch_size, hidden_size, self.patches_per_image, self.patches_per_image
  513. )
  514. reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
  515. pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
  516. pooled_vision_outputs = pooled_vision_outputs.flatten(2)
  517. pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
  518. normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
  519. projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
  520. return projected_vision_outputs.type_as(vision_outputs)
  521. class T5Gemma2TextScaledWordEmbedding(nn.Embedding):
  522. """T5Gemma2 Embedding: override to add eoi token embedding separately."""
  523. def __init__(
  524. self,
  525. num_embeddings: int,
  526. embedding_dim: int,
  527. padding_idx: int,
  528. embed_scale: float = 1.0,
  529. eoi_token_index: int = 256_000,
  530. ):
  531. super().__init__(num_embeddings, embedding_dim, padding_idx)
  532. self.scalar_embed_scale = embed_scale
  533. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  534. self.eoi_token_index = eoi_token_index
  535. self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim))
  536. def forward(self, input_ids: torch.Tensor):
  537. input_embeddings = super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  538. input_embeddings[input_ids == self.eoi_token_index] = self.eoi_embedding.to(input_embeddings.dtype)
  539. return input_embeddings
# Base class shared by the T5Gemma2 model classes in this file. It declares framework capability flags
# (attention backends, compilation, device-placement and output-recording hooks) and implements the
# model-specific weight initialization plus the label -> decoder-input shifting helper.
@auto_docstring
class T5Gemma2PreTrainedModel(PreTrainedModel):
    config: T5Gemma2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Module classes that must not be split across devices when the model is dispatched.
    _no_split_modules = [
        "T5Gemma2EncoderLayer",
        "T5Gemma2DecoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipEncoderLayer",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _skip_keys_device_placement = ["past_key_values"]

    # FlashAttention is not supported: mask creation here is incompatible with it
    # (non-default mask construction / sliding-window attention).
    _supports_flash_attn = False
    _supports_sdpa = True
    # Flex attention is not supported: the custom masks cannot be merged after creation.
    _supports_flex_attn = False

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Recordable submodule outputs: hidden states from every encoder/decoder layer; attention weights
    # from output index 1 of encoder self-attention, and from output indices 1 (self) and 2 (cross) of
    # the decoder's merged attention.
    _can_record_outputs = {
        "hidden_states": [T5Gemma2EncoderLayer, T5Gemma2DecoderLayer],
        "attentions": [
            OutputRecorder(T5Gemma2SelfAttention, index=1, layer_name="self_attn"),
            OutputRecorder(T5Gemma2MergedAttention, index=1, layer_name="self_attn"),
            OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"),
        ],
    }
    input_modalities = ("image", "text")

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize `module` in place.

        The parent initializer runs first. T5Gemma2-specific overrides are then applied for the module
        types handled below; the first matching branch wins, so their order matters.
        """
        super()._init_weights(module)
        if isinstance(module, T5Gemma2MultiModalProjector):
            init.zeros_(module.mm_input_projection_weight)
        elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
            init.zeros_(module.eoi_embedding)
            # `embed_scale` is a non-persistent buffer, so it is refilled from the Python float kept on the module.
            init.constant_(module.embed_scale, module.scalar_embed_scale)
        elif isinstance(module, T5Gemma2ClassificationHead):
            # Normal init whose std is shrunk by 1/sqrt(weight.shape[0]), i.e. the number of output labels.
            scale = module.out_proj.weight.shape[0] ** -0.5
            init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)
            if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
                init.zeros_(module.out_proj.bias)
        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
        elif "RMSNorm" in module.__class__.__name__:
            init.zeros_(module.weight)
        elif isinstance(module, T5Gemma2RotaryEmbedding):
            # Recompute the inverse frequencies for every layer type (default computation, or the function
            # registered in ROPE_INIT_FUNCTIONS for that rope type) and copy them into both the working
            # `<layer_type>_inv_freq` and the `<layer_type>_original_inv_freq` buffers.
            for layer_type in module.layer_types:
                rope_init_fn = module.compute_default_rope_parameters
                if module.rope_type[layer_type] != "default":
                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)

    def prepare_decoder_input_ids_from_labels(self, input_ids):
        """
        Build decoder inputs by shifting `input_ids` one position to the right along the last dimension.

        Position 0 receives `config.decoder.bos_token_id` (used as the decoder start token) and the last
        input token is dropped. Any `-100` (the label ignore index) that ends up in the shifted tensor is
        replaced by `config.decoder.pad_token_id`.

        Args:
            input_ids: Integer tensor of labels or token ids; the sequence runs along the last dimension.

        Returns:
            A new tensor with the same shape, dtype and device as `input_ids`.

        Raises:
            ValueError: If `config.decoder.bos_token_id` or `config.decoder.pad_token_id` is None.
        """
        decoder_config = self.config.decoder
        decoder_start_token_id = decoder_config.bos_token_id
        pad_token_id = decoder_config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ")

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        # NOTE(review): pad_token_id is only validated after the shift; the result is the same either way.
        if pad_token_id is None:
            raise ValueError("self.model.config.decoder.pad_token_id has to be defined.")

        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
  615. def sliding_window_mask_function(sliding_window: int, is_causal=True) -> Callable:
  616. """
  617. This creates uni/bidirectional attention mask with sliding window.
  618. """
  619. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  620. if is_causal:
  621. left_window_size, right_window_size = sliding_window, 0
  622. else:
  623. left_window_size, right_window_size = ((sliding_window + 1) // 2, (sliding_window) // 2 + 1)
  624. dist = q_idx - kv_idx
  625. left_mask = (dist >= 0) & (dist < left_window_size)
  626. right_mask = (dist < 0) & (-dist < right_window_size)
  627. return left_mask | right_mask
  628. return inner_mask
class T5Gemma2TextEncoder(T5Gemma2PreTrainedModel):
    """Bidirectional text encoder stack.

    Token ids are embedded with `T5Gemma2TextScaledWordEmbedding` (scaled by sqrt(hidden_size), with a
    dedicated EOI vector). The embeddings go through dropout, the `T5Gemma2EncoderLayer` stack, a final
    RMSNorm and a final dropout. Each layer receives the rotary embeddings and the bidirectional attention
    mask that match its entry in `config.layer_types`.
    """

    config: T5Gemma2TextConfig
    # Module types whose outputs are recorded as `attentions` / `hidden_states` for this stack.
    _can_record_outputs = {
        "attentions": T5Gemma2SelfAttention,
        "hidden_states": T5Gemma2EncoderLayer,
    }

    def __init__(
        self,
        config: T5Gemma2TextConfig,
        eoi_token_index: int = 256_000,
    ):
        # `eoi_token_index` is forwarded to the embedding layer, which substitutes its learned EOI vector
        # at positions holding this id.
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
            eoi_token_index=eoi_token_index,
        )
        self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.layers = nn.ModuleList(
            [T5Gemma2EncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.dropout = nn.Dropout(config.dropout_rate)
        self.rotary_emb = T5Gemma2RotaryEmbedding(config)

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        # Accepted but ignored, so that processor outputs can be passed in unchanged.
        token_type_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present
        kwargs.pop("past_key_values", None)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Default positions 0..seq_len-1 with shape (1, seq_len), shared across the batch.
        if position_ids is None:
            position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

        # A dict is used as a precomputed {layer_type: mask} mapping. Otherwise build a bidirectional mask for
        # full-attention layers and a windowed (non-causal) bidirectional mask for sliding-attention layers.
        if not isinstance(self_attn_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
            }
            self_attn_mask_mapping = {
                "full_attention": create_bidirectional_mask(**mask_kwargs),
                "sliding_attention": create_bidirectional_mask(
                    **mask_kwargs,
                    and_mask_function=sliding_window_mask_function(self.config.sliding_window, is_causal=False),
                ),
            }

        # input layer
        hidden_states = inputs_embeds

        # Rotary (cos/sin) embeddings, computed once per layer type and shared by all layers of that type.
        position_embeddings = {}
        for layer_type in self.config.layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

        # dropout
        hidden_states = self.dropout(hidden_states)

        # Each layer gets the rotary embeddings and mask that match its configured layer type.
        for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
            hidden_states = layer_module(
                hidden_states,
                position_embeddings[self.config.layer_types[i]],
                self_attn_mask_mapping[self.config.layer_types[i]],
                position_ids,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
        )
class T5Gemma2Encoder(T5Gemma2PreTrainedModel):
    """Multimodal encoder: a `T5Gemma2TextEncoder` plus a vision tower and a multimodal projector.

    When `pixel_values` are given, vision-tower features are projected to the text width and scattered
    into the input embeddings at the image placeholder positions before the text encoder runs.
    """

    config: T5Gemma2EncoderConfig

    def __init__(
        self,
        config: T5Gemma2EncoderConfig,
        eoi_token_index: int = 256_000,
    ):
        super().__init__(config)
        self.text_model = T5Gemma2TextEncoder._from_config(config.text_config, eoi_token_index=eoi_token_index)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = T5Gemma2MultiModalProjector(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Delegate to the text model's `get_input_embeddings`."""
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Delegate to the text model's `set_input_embeddings`."""
        return self.text_model.set_input_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        # pixel_values: (batch_size, channels, height, width)
        # image_features: Image feature tensor of shape (num_images, image_length, embed_dim).
        vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
        last_hidden_state = vision_outputs.last_hidden_state
        # The pooled, projected image tokens are returned in place of the vision tower's own pooler output.
        image_features = self.multi_modal_projector(last_hidden_state)
        vision_outputs.pooler_output = image_features

        return vision_outputs

    def get_image_placeholder_mask(
        self,
        input_ids: torch.LongTensor | None,
        inputs_embeds: torch.FloatTensor | None,
        image_features: torch.FloatTensor,
    ):
        """
        Build a boolean mask of image placeholder positions, expanded to the shape of `inputs_embeds`.

        Placeholders are found with `input_ids == config.image_token_id` when `input_ids` is given. Otherwise
        each vector of `inputs_embeds` is compared against the embedding of `image_token_id`. The number of
        embedding elements selected by the mask must equal `image_features.numel()`; a mismatch fails the
        `torch_compilable_check`.

        Args:
            input_ids: Token ids, or None to detect placeholders from `inputs_embeds`.
            inputs_embeds: Input embeddings. Required in both cases, because they set the mask's shape and device.
            image_features: Projected image features that will be scattered into the placeholders.

        Returns:
            Boolean mask with the same shape as `inputs_embeds`.

        Raises:
            ValueError: If both `input_ids` and `inputs_embeds` are None.
        """
        image_token_id = self.config.image_token_id
        if input_ids is None:
            if inputs_embeds is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` has to be provided.")
            # A position is a placeholder only if its whole embedding vector equals the image-token embedding.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        # Only used in the error message; assumes image_features is (num_images, tokens_per_image, dim).
        n_image_features = image_features.shape[0] * image_features.shape[1]
        torch_compilable_check(
            inputs_embeds[special_image_mask].numel() == image_features.numel(),
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}",
        )
        return special_image_mask

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        # Accepted but ignored, so that processor outputs can be passed in unchanged.
        token_type_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.embed_tokens(input_ids)

        # Replace the image placeholder embeddings with the projected image features, in order.
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask = self.get_image_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features)

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        return outputs
  801. class T5Gemma2Decoder(T5Gemma2PreTrainedModel):
  802. config: T5Gemma2DecoderConfig
  803. _can_record_outputs = {
  804. "attentions": OutputRecorder(T5Gemma2MergedAttention, index=1),
  805. "cross_attentions": OutputRecorder(T5Gemma2MergedAttention, index=2),
  806. "hidden_states": T5Gemma2DecoderLayer,
  807. }
  808. def __init__(self, config: T5Gemma2DecoderConfig, eoi_token_index: int = 256_000):
  809. super().__init__(config)
  810. self.padding_idx = config.pad_token_id
  811. self.vocab_size = config.vocab_size
  812. self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
  813. config.vocab_size,
  814. config.hidden_size,
  815. config.pad_token_id,
  816. embed_scale=config.hidden_size**0.5,
  817. eoi_token_index=eoi_token_index,
  818. )
  819. self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  820. self.gradient_checkpointing = False
  821. self.layers = nn.ModuleList(
  822. [T5Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  823. )
  824. self.dropout = nn.Dropout(config.dropout_rate)
  825. self.rotary_emb = T5Gemma2RotaryEmbedding(config)
  826. self.post_init()
  827. @merge_with_config_defaults
  828. @capture_outputs
  829. @auto_docstring
  830. def forward(
  831. self,
  832. input_ids: torch.LongTensor | None = None,
  833. attention_mask: torch.Tensor | None = None,
  834. position_ids: torch.LongTensor | None = None,
  835. past_key_values: EncoderDecoderCache | None = None,
  836. inputs_embeds: torch.FloatTensor | None = None,
  837. use_cache: bool | None = None,
  838. encoder_hidden_states: torch.Tensor | None = None,
  839. encoder_attention_mask: torch.Tensor | None = None,
  840. **kwargs: Unpack[TransformersKwargs],
  841. ) -> BaseModelOutputWithPastAndCrossAttentions:
  842. if (input_ids is None) ^ (inputs_embeds is not None):
  843. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  844. if encoder_hidden_states is None:
  845. raise ValueError("`encoder_hidden_states` must be given in decoder")
  846. if inputs_embeds is None:
  847. inputs_embeds = self.embed_tokens(input_ids)
  848. if not self.training and use_cache and past_key_values is None:
  849. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())
  850. if position_ids is None:
  851. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  852. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  853. position_ids = position_ids.unsqueeze(0)
  854. if not isinstance(self_attn_mask_mapping := attention_mask, dict):
  855. # this masking function does nothing to masking but forces `allow_is_causal_skip` to be False
  856. # as we always need a mask during decoding for merged attention.
  857. dummy_and_mask_function = lambda *args: torch.tensor(True, dtype=torch.bool) # noqa
  858. mask_kwargs = {
  859. "config": self.config,
  860. "inputs_embeds": inputs_embeds,
  861. "attention_mask": attention_mask,
  862. "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
  863. "position_ids": position_ids,
  864. "and_mask_function": dummy_and_mask_function,
  865. }
  866. self_attn_mask_mapping = {
  867. "full_attention": create_causal_mask(**mask_kwargs),
  868. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  869. }
  870. if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
  871. cross_attn_mask_mapping = {
  872. "full_attention": create_bidirectional_mask(
  873. config=self.config,
  874. inputs_embeds=inputs_embeds,
  875. attention_mask=encoder_attention_mask,
  876. encoder_hidden_states=encoder_hidden_states,
  877. and_mask_function=dummy_and_mask_function,
  878. )
  879. }
  880. merged_attn_mask_mapping = {
  881. "full_attention": torch.cat(
  882. [self_attn_mask_mapping["full_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1
  883. ),
  884. "sliding_attention": torch.cat(
  885. [self_attn_mask_mapping["sliding_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1
  886. ),
  887. }
  888. # input layer
  889. hidden_states = inputs_embeds
  890. # global and local position embeddings
  891. position_embeddings = {}
  892. for layer_type in self.config.layer_types:
  893. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  894. # dropout
  895. hidden_states = self.dropout(hidden_states)
  896. for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
  897. hidden_states = layer_module(
  898. hidden_states,
  899. position_embeddings[self.config.layer_types[i]],
  900. merged_attn_mask_mapping[self.config.layer_types[i]],
  901. position_ids,
  902. past_key_values,
  903. use_cache,
  904. encoder_hidden_states,
  905. **kwargs,
  906. )
  907. hidden_states = self.norm(hidden_states)
  908. hidden_states = self.dropout(hidden_states)
  909. return BaseModelOutputWithPastAndCrossAttentions(
  910. last_hidden_state=hidden_states,
  911. past_key_values=past_key_values,
  912. )
# Bare encoder-decoder backbone. It runs the multimodal encoder (unless `encoder_outputs` is supplied)
# and then the decoder, and returns decoder/encoder states without a language-modeling head.
@auto_docstring
class T5Gemma2Model(T5Gemma2PreTrainedModel):
    # The decoder shares its token embedding table and EOI vector with the encoder's text model.
    _tied_weights_keys = {
        "decoder.embed_tokens.weight": "encoder.text_model.embed_tokens.weight",
        "decoder.embed_tokens.eoi_embedding": "encoder.text_model.embed_tokens.eoi_embedding",
    }

    def __init__(self, config: T5Gemma2Config):
        super().__init__(config)

        # setup encoder and decoder
        self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index)
        self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index)

        self.post_init()

    def get_encoder(self):
        """Return the multimodal encoder."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder."""
        return self.decoder

    def get_input_embeddings(self):
        """Delegate to the encoder's `get_input_embeddings` (the decoder's table is tied to it)."""
        return self.encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Delegate to the encoder's `set_input_embeddings`."""
        return self.encoder.set_input_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        # encoder inputs
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        # decoder inputs
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        # others (mainly inference or cache related)
        encoder_outputs: BaseModelOutput | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.Tensor | None = None,
        decoder_inputs_embeds: torch.Tensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Seq2SeqModelOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        """
        # encoder: skipped when precomputed `encoder_outputs` are passed in
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                return_dict=True,
                **kwargs,
            )

        # NOTE(review): `encoder_outputs` must expose `.last_hidden_state`; a plain tuple is not converted here.
        encoder_hidden_states = encoder_outputs.last_hidden_state

        # decoder: cross-attention is masked with the encoder-side `attention_mask`
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            return_dict=True,
            **kwargs,
        )

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin):
    """T5Gemma2 encoder-decoder with a language-modeling head on top of the decoder.

    Logits come from `lm_head` applied to the decoder's last hidden state. They are optionally soft-capped
    with `config.decoder.final_logit_softcapping`. For generation, the cache is an `EncoderDecoderCache`
    whose cross-attention part is prepared without sliding-window settings (see
    `_prepare_cache_for_generation`).
    """

    # The LM head projection weight is tied to the encoder text model's token embedding table.
    _tied_weights_keys = {
        "lm_head.out_proj.weight": "model.encoder.text_model.embed_tokens.weight",
    }
    # Tensor-parallel and pipeline-parallel plans for the LM head projection.
    _tp_plan = {"lm_head.out_proj": "colwise_gather_output"}
    _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])}

    def __init__(self, config: T5Gemma2Config):
        super().__init__(config)

        self.model = T5Gemma2Model(config)
        self.vocab_size = config.decoder.vocab_size
        self.lm_head = T5Gemma2LMHead(config.decoder.hidden_size, self.vocab_size)
        # Labels are already aligned with the right-shifted decoder inputs, so the loss applies no extra shift.
        self.loss_type = "ForMaskedLM"

        self.post_init()

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's output projection module."""
        self.lm_head.out_proj = new_embeddings

    def get_output_embeddings(self):
        """Return the LM head's output projection module."""
        return self.lm_head.out_proj

    def get_input_embeddings(self):
        """Delegate to the backbone's `get_input_embeddings`."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Delegate to the backbone's `set_input_embeddings`."""
        self.model.set_input_embeddings(value)

    def get_encoder(self):
        """Return the backbone's multimodal encoder."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the backbone's decoder."""
        return self.model.get_decoder()

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        # Delegates to the encoder, which owns the vision tower and the multimodal projector.
        return self.get_encoder().get_image_features(pixel_values, **kwargs)

    @property
    def vision_tower(self):
        """Shortcut to the encoder's vision tower."""
        return self.get_encoder().vision_tower

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        # encoder inputs
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        # decoder inputs
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        # others (mainly inference or cache related)
        encoder_outputs: BaseModelOutput | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)

        decoder_outputs: Seq2SeqModelOutput = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = decoder_outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int N keeps the last N positions (0 keeps all, since slice(-0, None) == slice(0, None));
        # a tensor is used directly as an index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # Optional soft-capping: logits = cap * tanh(logits / cap).
        decoder_config = self.config.decoder
        if decoder_config.final_logit_softcapping is not None:
            logits = logits / decoder_config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * decoder_config.final_logit_softcapping

        loss = None
        if labels is not None:
            # Input has right-shifted so we directly perform masked lm loss
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.decoder_hidden_states,
            decoder_attentions=decoder_outputs.decoder_attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state,
            encoder_hidden_states=decoder_outputs.encoder_hidden_states,
            encoder_attentions=decoder_outputs.encoder_attentions,
        )

    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
        model_kwargs: dict,
        generation_mode: GenerationMode,
        batch_size: int,
        max_cache_length: int,
    ) -> None:
        """Override cache preparation to support T5Gemma2-specific EncoderDecoder Cache.

        The parent preparation runs first. If `generation_config.use_cache` is False, the method then returns.
        Otherwise it makes sure `model_kwargs["past_key_values"]` is an `EncoderDecoderCache`:

        - If a cache already exists, it must be an `EncoderDecoderCache` (else `ValueError`). It is left
          untouched if its `is_updated` entry for layer 0 is already set. Otherwise its cross-attention cache
          is rebuilt, with the same class, from a copy of the decoder config that has `sliding_window` and
          `layer_types` removed.
        - Otherwise a new `EncoderDecoderCache` made of two `DynamicCache`s is stored in `model_kwargs`.

        Finally, if `self._cache` is set, it must be an `EncoderDecoderCache` (else `ValueError`) and is
        replaced by the cache in `model_kwargs`. `model_kwargs` is modified in place; the return value is None.
        """
        # Build cache and past_key_values structure first and then override as needed.
        super()._prepare_cache_for_generation(
            generation_config,
            model_kwargs,
            generation_mode,
            batch_size,
            max_cache_length,
        )

        # If use_cache is False, do not prepare the cache.
        if generation_config.use_cache is False:
            return

        # Offloading is requested by any cache implementation whose name contains "offloaded".
        cache_implementation = generation_config.cache_implementation
        if cache_implementation is None:
            offload_cache = False
        else:
            offload_cache = "offloaded" in generation_config.cache_implementation

        # Main change: use full cache for cross-attention.
        cross_attn_config = copy.deepcopy(self.config.get_text_config(decoder=True))

        # cross-attention does not use sliding window
        del cross_attn_config.sliding_window
        del cross_attn_config.layer_types

        cross_attn_cache_kwargs = {
            "config": cross_attn_config,
            "offloading": offload_cache,
        }
        past_key_values = model_kwargs.get("past_key_values")
        if past_key_values is not None:
            if not isinstance(past_key_values, EncoderDecoderCache):
                raise ValueError(
                    "The `past_key_values` in `model_kwargs` must be of type `EncoderDecoderCache` for T5Gemma2 model."
                )

            # Cache already established, no need to re-initialize.
            if len(past_key_values.is_updated) > 0 and past_key_values.is_updated.get(0):
                return

            cross_attn_cls = type(past_key_values.cross_attention_cache)
            if cross_attn_cls == StaticCache:
                # A static cross-attention cache is sized to the encoder sequence length
                # (assumes encoder_outputs[0] is (batch, enc_len, hidden) — TODO confirm).
                cross_attn_cache_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
            # Update cross-attention cache only (switch from sliding_window to full).
            past_key_values.cross_attention_cache = cross_attn_cls(**cross_attn_cache_kwargs)
        else:
            # Initialize new cache.
            # NOTE(review): the fresh cross-attention DynamicCache() does not use `cross_attn_cache_kwargs`
            # (so no offloading); confirm this is intended.
            model_kwargs["past_key_values"] = EncoderDecoderCache(
                DynamicCache(
                    **{
                        "config": self.config.get_text_config(decoder=True),
                        "offloading": offload_cache,
                    }
                ),  # self-attention cache
                DynamicCache(),  # cross-attention cache
            )

        # Keep an existing internal cache reference in sync with the cache just prepared.
        if hasattr(self, "_cache") and self._cache is not None:
            if not isinstance(self._cache, EncoderDecoderCache):
                raise ValueError("The internal cache must be of type `EncoderDecoderCache` for T5Gemma2 model.")

            self._cache = model_kwargs["past_key_values"]
  @auto_docstring
  class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel):
      # Sequence classifier on top of the full T5Gemma2 encoder-decoder: `self.score` is applied to every
      # decoder position and one position per example is then pooled into the sequence logits (see `forward`).

      def __init__(self, config: T5Gemma2Config):
          super().__init__(config)
          self.num_labels = config.num_labels
          # The head consumes decoder hidden states, so it is sized with the decoder's hidden size.
          self.hidden_size = config.decoder.hidden_size
          self.model = T5Gemma2Model(config)
          # Falls back to 0.1 when the config does not define `classifier_dropout_rate`.
          classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
          self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout)
          self.post_init()

      def get_input_embeddings(self):
          # Delegates to the wrapped T5Gemma2Model.
          return self.model.get_input_embeddings()

      def set_input_embeddings(self, value):
          # Delegates to the wrapped T5Gemma2Model.
          self.model.set_input_embeddings(value)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          decoder_input_ids: torch.LongTensor | None = None,
          decoder_attention_mask: torch.Tensor | None = None,
          decoder_position_ids: torch.LongTensor | None = None,
          encoder_outputs: BaseModelOutput | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          decoder_inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> SequenceClassifierOutput:
          r"""
          decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
              Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
              config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          # Token ids are required: `decoder_input_ids` is derived from `input_ids` when absent, and the
          # pooling step below inspects `decoder_input_ids` to locate padding.
          if inputs_embeds is not None or decoder_inputs_embeds is not None:
              raise NotImplementedError(
                  f"Passing input embeddings is currently not supported for {self.__class__.__name__}."
              )
          if input_ids is None:
              raise ValueError("You have to specify input_ids")
          if decoder_input_ids is None:
              decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids)

          outputs: Seq2SeqModelOutput = self.model(
              input_ids,
              pixel_values=pixel_values,
              attention_mask=attention_mask,
              position_ids=position_ids,
              decoder_input_ids=decoder_input_ids,
              decoder_attention_mask=decoder_attention_mask,
              decoder_position_ids=decoder_position_ids,
              encoder_outputs=encoder_outputs,
              inputs_embeds=inputs_embeds,
              decoder_inputs_embeds=decoder_inputs_embeds,
              use_cache=False,
              **kwargs,
          )
          # Only decoder-side hidden states / attentions are surfaced in the classifier output.
          last_hidden_state = outputs.last_hidden_state
          hidden_states = outputs.decoder_hidden_states
          attentions = outputs.decoder_attentions
          logits = self.score(last_hidden_state)

          batch_size = input_ids.shape[0]
          pad_token_id = self.config.pad_token_id
          if pad_token_id is None:
              # Without a padding token, padding cannot be located; that is only unambiguous for a single
              # example, whose final decoder position is pooled.
              if batch_size != 1:
                  raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
              last_non_pad_token = -1
          else:
              # To handle both left- and right-padding, take the rightmost decoder token that is not padding.
              non_pad_mask = (decoder_input_ids != pad_token_id).to(logits.device, torch.int32)
              token_indices = torch.arange(decoder_input_ids.shape[-1], device=logits.device, dtype=torch.int32)
              # argmax already yields an index in `[0, decoder_seq_len - 1]` (no clamp needed); a row made
              # entirely of padding yields 0 because argmax returns the first maximal index.
              last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
          pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

          loss = None
          if labels is not None:
              loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

          return SequenceClassifierOutput(
              loss=loss,
              logits=pooled_logits,
              hidden_states=hidden_states,
              attentions=attentions,
          )
  @auto_docstring
  class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel):
      # Token classifier on top of the full T5Gemma2 encoder-decoder: `self.score` is applied to every
      # decoder position and the per-position logits are returned as-is (no pooling).

      def __init__(self, config: T5Gemma2Config):
          super().__init__(config)
          self.num_labels = config.num_labels
          # The head consumes decoder hidden states, so it is sized with the decoder's hidden size.
          self.hidden_size = config.decoder.hidden_size
          self.model = T5Gemma2Model(config)
          # Falls back to 0.1 when the config does not define `classifier_dropout_rate`.
          classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
          self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout)
          self.post_init()

      def get_input_embeddings(self):
          # Delegates to the wrapped T5Gemma2Model.
          return self.model.get_input_embeddings()

      def set_input_embeddings(self, value):
          # Delegates to the wrapped T5Gemma2Model.
          self.model.set_input_embeddings(value)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          decoder_input_ids: torch.LongTensor | None = None,
          decoder_attention_mask: torch.Tensor | None = None,
          decoder_position_ids: torch.LongTensor | None = None,
          encoder_outputs: BaseModelOutput | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          decoder_inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> TokenClassifierOutput:
          r"""
          decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
              Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
              config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss, one label per decoder position. Indices should be in
              `[0, ..., config.num_labels - 1]`.
          """
          # Token ids are required: `decoder_input_ids` is derived from `input_ids` when it is not given.
          if inputs_embeds is not None or decoder_inputs_embeds is not None:
              raise NotImplementedError(
                  f"Passing input embeddings is currently not supported for {self.__class__.__name__}."
              )
          if input_ids is None:
              raise ValueError("You have to specify input_ids")
          if decoder_input_ids is None:
              decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids)

          outputs: Seq2SeqModelOutput = self.model(
              input_ids,
              pixel_values=pixel_values,
              attention_mask=attention_mask,
              position_ids=position_ids,
              decoder_input_ids=decoder_input_ids,
              decoder_attention_mask=decoder_attention_mask,
              decoder_position_ids=decoder_position_ids,
              encoder_outputs=encoder_outputs,
              inputs_embeds=inputs_embeds,
              decoder_inputs_embeds=decoder_inputs_embeds,
              use_cache=False,
              **kwargs,
          )
          # Only decoder-side hidden states / attentions are surfaced in the classifier output.
          last_hidden_state = outputs.last_hidden_state
          hidden_states = outputs.decoder_hidden_states
          attentions = outputs.decoder_attentions
          # One set of class scores per decoder position.
          logits = self.score(last_hidden_state)

          loss = None
          if labels is not None:
              loss = self.loss_function(logits, labels, self.config)

          return TokenClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=hidden_states,
              attentions=attentions,
          )
  # Public API of this module (what `from ... import *` exposes); kept as a literal list so that
  # tooling which statically reads `__all__` can still parse it.
  __all__ = [
      # Generation model, backbone, encoder and shared base class.
      "T5Gemma2ForConditionalGeneration",
      "T5Gemma2Model",
      "T5Gemma2Encoder",
      "T5Gemma2PreTrainedModel",
      # Classification heads.
      "T5Gemma2ForSequenceClassification",
      "T5Gemma2ForTokenClassification",
  ]