modular_t5gemma2.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373
  1. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import copy
  16. from collections.abc import Callable
  17. from typing import Any, Optional
  18. import torch
  19. import torch.nn as nn
  20. from huggingface_hub.dataclasses import strict
  21. from ... import initialization as init
  22. from ...cache_utils import DynamicCache, EncoderDecoderCache, StaticCache
  23. from ...configuration_utils import PreTrainedConfig
  24. from ...generation import GenerationConfig, GenerationMixin, GenerationMode
  25. from ...masking_utils import create_bidirectional_mask
  26. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. BaseModelOutputWithPastAndCrossAttentions,
  30. BaseModelOutputWithPooling,
  31. Seq2SeqLMOutput,
  32. Seq2SeqModelOutput,
  33. SequenceClassifierOutput,
  34. TokenClassifierOutput,
  35. )
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import (
  40. TransformersKwargs,
  41. auto_docstring,
  42. can_return_tuple,
  43. logging,
  44. torch_compilable_check,
  45. )
  46. from ...utils.generic import merge_with_config_defaults
  47. from ...utils.output_capturing import OutputRecorder, capture_outputs
  48. from ..auto import AutoModel
  49. from ..gemma3.configuration_gemma3 import Gemma3Config, Gemma3TextConfig
  50. from ..gemma3.modeling_gemma3 import (
  51. Gemma3Attention,
  52. Gemma3MLP,
  53. Gemma3MultiModalProjector,
  54. Gemma3PreTrainedModel,
  55. Gemma3RMSNorm,
  56. Gemma3RotaryEmbedding,
  57. Gemma3TextScaledWordEmbedding,
  58. apply_rotary_pos_emb,
  59. create_causal_mask,
  60. create_sliding_window_causal_mask,
  61. eager_attention_forward,
  62. )
  63. from ..siglip import SiglipVisionConfig
  64. from ..t5gemma.modeling_t5gemma import (
  65. T5GemmaClassificationHead,
  66. T5GemmaEncoderLayer,
  67. T5GemmaLMHead,
  68. )
  69. logger = logging.get_logger(__name__)
  70. @auto_docstring(checkpoint="google/t5gemma-2-270m-270m")
  71. @strict
  72. class T5Gemma2TextConfig(Gemma3TextConfig, PreTrainedConfig):
  73. r"""
  74. query_pre_attn_scalar (`float`, *optional*, defaults to 256):
  75. Scaling factor used on the attention scores
  76. final_logit_softcapping (`float`, *optional*):
  77. Scaling factor when applying tanh softcapping on the logits.
  78. attn_logit_softcapping (`float`, *optional*):
  79. Scaling factor when applying tanh softcapping on the attention scores.
  80. """
  81. model_type = "t5gemma2_text"
  82. use_bidirectional_attention = AttributeError()
  83. def __post_init__(self, **kwargs):
  84. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  85. _sliding_window_pattern = kwargs.pop("sliding_window_pattern", 6)
  86. if self.layer_types is None:
  87. self.layer_types = [
  88. "sliding_attention" if bool((i + 1) % _sliding_window_pattern) else "full_attention"
  89. for i in range(self.num_hidden_layers)
  90. ]
  91. PreTrainedConfig.__post_init__(**kwargs)
@auto_docstring(checkpoint="google/t5gemma-2-270m-270m")
@strict
class T5Gemma2EncoderConfig(Gemma3Config):
    # Encoder-side configuration: a Gemma3Config whose text sub-config is a T5Gemma2TextConfig
    # and whose vision sub-config is a SiglipVisionConfig.
    model_type = "t5gemma2_encoder"
    # Declares the config class expected for each nested sub-config attribute.
    sub_configs = {
        "text_config": T5Gemma2TextConfig,
        "vision_config": SiglipVisionConfig,
    }
  100. @auto_docstring(checkpoint="google/t5gemma-2-270m-270m")
  101. @strict
  102. class T5Gemma2DecoderConfig(Gemma3TextConfig, PreTrainedConfig):
  103. r"""
  104. query_pre_attn_scalar (`float`, *optional*, defaults to 256):
  105. Scaling factor used on the attention scores
  106. final_logit_softcapping (`float`, *optional*):
  107. Scaling factor when applying tanh softcapping on the logits.
  108. attn_logit_softcapping (`float`, *optional*):
  109. Scaling factor when applying tanh softcapping on the attention scores.
  110. """
  111. model_type = "t5gemma2_decoder"
  112. use_bidirectional_attention = AttributeError()
  113. def __post_init__(self, **kwargs):
  114. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  115. _sliding_window_pattern = kwargs.pop("sliding_window_pattern", 6)
  116. if self.layer_types is None:
  117. self.layer_types = [
  118. "sliding_attention" if bool((i + 1) % _sliding_window_pattern) else "full_attention"
  119. for i in range(self.num_hidden_layers)
  120. ]
  121. PreTrainedConfig.__post_init__(**kwargs)
  122. @auto_docstring(checkpoint="google/t5gemma-2-270m-270m")
  123. @strict
  124. class T5Gemma2Config(PreTrainedConfig):
  125. r"""
  126. encoder (`Union[T5Gemma2EncoderConfig, dict]`, optional, *optional*):
  127. Configuration for the encoder.
  128. decoder (`Union[T5Gemma2DecoderConfig, dict]`, optional, *optional*):
  129. Configuration for the decoder.
  130. eoi_token_index (`int`, *optional*):
  131. The end-of-image token index to wrap the image prompt. Will be same as
  132. `self.encoder.eoi_token_index`
  133. ```python
  134. >>> from transformers import T5Gemma2Config, T5Gemma2Model
  135. >>> t5gemma2_config = T5Gemma2Config.from_pretrained("google/t5gemma-270m-270m")
  136. >>> model = T5Gemma2Model(t5gemma2_config)
  137. ```
  138. """
  139. model_type = "t5gemma2"
  140. keys_to_ignore_at_inference = ["past_key_values"]
  141. sub_configs = {
  142. "encoder": T5Gemma2EncoderConfig,
  143. "decoder": T5Gemma2DecoderConfig,
  144. }
  145. attribute_map = {
  146. "image_token_id": "image_token_index",
  147. "eoi_token_id": "eoi_token_index",
  148. }
  149. encoder: T5Gemma2EncoderConfig | dict[str, Any] | None = None
  150. decoder: T5Gemma2DecoderConfig | dict[str, Any] | None = None
  151. is_encoder_decoder: bool = True
  152. dropout_rate: float | int = 0.0
  153. attention_dropout: float | int = 0.0
  154. classifier_dropout_rate: float | int = 0.0
  155. initializer_range: float = 0.02
  156. image_token_index: int = 256_001
  157. eoi_token_index: int | None = None
  158. tie_word_embeddings: bool = True
  159. def __post_init__(self, **kwargs):
  160. if isinstance(self.encoder, dict):
  161. self.encoder = T5Gemma2EncoderConfig(**self.encoder)
  162. elif self.encoder is None:
  163. self.encoder = T5Gemma2EncoderConfig()
  164. logger.info("encoder is None, using default T5Gemma2EncoderConfig encoder config.")
  165. if isinstance(self.decoder, dict):
  166. self.decoder = T5Gemma2DecoderConfig(**self.decoder)
  167. elif self.decoder is None:
  168. self.decoder = T5Gemma2DecoderConfig()
  169. logger.info("decoder is None, using default T5Gemma2DecoderConfig decoder config.")
  170. self.encoder.text_config.dropout_rate = self.dropout_rate
  171. self.encoder.text_config.attention_dropout = self.attention_dropout
  172. self.encoder.vision_config.attention_dropout = self.attention_dropout
  173. self.encoder.image_token_index = self.image_token_index
  174. self.decoder.dropout_rate = self.dropout_rate
  175. self.decoder.attention_dropout = self.attention_dropout
  176. self.eoi_token_index = self.encoder.eoi_token_index
  177. for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id", "vocab_size"]:
  178. if special_token_key not in kwargs:
  179. kwargs[special_token_key] = getattr(self.decoder, special_token_key)
  180. super().__post_init__(**kwargs)
  181. def validate_architecture(self):
  182. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  183. if self.encoder.text_config.hidden_size != self.decoder.hidden_size:
  184. raise ValueError(
  185. "Imbalanced encoder-decoder is not supported in T5Gemma2: "
  186. f"encoder ({self.encoder.text_config.hidden_size}) vs decoder ({self.decoder.hidden_size})."
  187. )
  188. if not self.is_encoder_decoder:
  189. raise ValueError("T5Gemma2Model only support encoder-decoder modeling.")
  190. if self.encoder.text_config.vocab_size != self.decoder.vocab_size:
  191. raise ValueError(
  192. "Imbalanced encoder-decoder vocabulary size is not supported in T5Gemma2: "
  193. f"encoder ({self.encoder.text_config.vocab_size}) vs decoder ({self.decoder.vocab_size})."
  194. )
# Identical to Gemma3RMSNorm; subclassed only to provide the T5Gemma2-prefixed name.
class T5Gemma2RMSNorm(Gemma3RMSNorm):
    pass
  197. class T5Gemma2MLP(Gemma3MLP):
  198. def __init__(self, config: T5Gemma2TextConfig):
  199. super().__init__(config)
  200. self.dropout = nn.Dropout(config.dropout_rate)
  201. def forward(self, x):
  202. hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
  203. hidden_states = self.dropout(hidden_states)
  204. down_proj = self.down_proj(hidden_states)
  205. return down_proj
class T5Gemma2RotaryEmbedding(Gemma3RotaryEmbedding):
    # Thin subclass of Gemma3RotaryEmbedding re-declared with T5Gemma2TextConfig annotations;
    # both methods forward their arguments unchanged to the parent.
    def __init__(self, config: T5Gemma2TextConfig, device=None):
        super().__init__(config, device)

    @staticmethod
    def compute_default_rope_parameters(
        config: T5Gemma2TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
        layer_type: str | None = None,
    ) -> tuple["torch.Tensor", float]:
        # Delegates to the parent implementation with the same arguments.
        # NOTE(review): zero-argument `super()` raises RuntimeError inside a plain @staticmethod
        # when executed directly; this presumably relies on the modular converter inlining the
        # parent's body into the generated modeling file — confirm before importing this module.
        return super().compute_default_rope_parameters(config, device, seq_len, layer_type)
class T5Gemma2SelfAttention(Gemma3Attention):
    # Gemma3 attention reused as the encoder's self-attention, with `is_causal` switched off
    # (the encoder builds bidirectional masks).
    def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.is_causal = False  # Only used by the encoder
  221. class T5Gemma2MergedAttention(Gemma3Attention):
  222. """Merged self-attention and cross-attention for decoder."""
  223. def __init__(self, config: T5Gemma2TextConfig, layer_idx: int):
  224. super().__init__(config, layer_idx)
  225. self.is_causal = False # Fused causal and encoder mask
  226. def forward(
  227. self,
  228. # decoder self-attention inputs
  229. hidden_states: torch.Tensor,
  230. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  231. merged_attention_mask: torch.Tensor | None,
  232. # cross-attention inputs
  233. encoder_hidden_states: torch.Tensor,
  234. # cache inputs
  235. past_key_values: EncoderDecoderCache | None = None,
  236. # others
  237. **kwargs: Unpack[FlashAttentionKwargs],
  238. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  239. # attention shapes.
  240. input_shape = hidden_states.shape[:-1]
  241. hidden_shape = (*input_shape, -1, self.head_dim)
  242. cross_input_shape = encoder_hidden_states.shape[:-1]
  243. cross_hidden_shape = (*cross_input_shape, -1, self.head_dim)
  244. # self-attention.
  245. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  246. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  247. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  248. query_states = self.q_norm(query_states)
  249. key_states = self.k_norm(key_states)
  250. cos, sin = position_embeddings
  251. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  252. if past_key_values is not None:
  253. # self-attention.
  254. self_attention_cache = past_key_values.self_attention_cache
  255. key_states, value_states = self_attention_cache.update(key_states, value_states, self.layer_idx)
  256. # cross-attention.
  257. is_updated = past_key_values.is_updated.get(self.layer_idx)
  258. cross_attention_cache = past_key_values.cross_attention_cache
  259. if past_key_values is None or not is_updated:
  260. cross_key_states = self.k_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2)
  261. cross_value_states = self.v_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2)
  262. cross_key_states = self.k_norm(cross_key_states)
  263. if past_key_values is not None:
  264. cross_key_states, cross_value_states = cross_attention_cache.update(
  265. cross_key_states, cross_value_states, self.layer_idx
  266. )
  267. past_key_values.is_updated[self.layer_idx] = True
  268. else:
  269. cross_key_states = cross_attention_cache.layers[self.layer_idx].keys
  270. cross_value_states = cross_attention_cache.layers[self.layer_idx].values
  271. # merged attention.
  272. query_states = query_states
  273. cross_key_size = cross_input_shape[1]
  274. key_states = torch.cat([key_states, cross_key_states], dim=2)
  275. value_states = torch.cat([value_states, cross_value_states], dim=2)
  276. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  277. self.config._attn_implementation, eager_attention_forward
  278. )
  279. attn_output, attn_weights = attention_interface(
  280. self,
  281. query_states,
  282. key_states,
  283. value_states,
  284. merged_attention_mask,
  285. dropout=self.attention_dropout if self.training else 0.0,
  286. scaling=self.scaling,
  287. **kwargs,
  288. )
  289. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  290. attn_output = self.o_proj(attn_output)
  291. # decompose merged attention weights into self & cross attention weights
  292. if attn_weights is not None:
  293. self_attn_weights = attn_weights[..., :-cross_key_size]
  294. cross_attn_weights = attn_weights[..., -cross_key_size:]
  295. else:
  296. self_attn_weights, cross_attn_weights = None, None
  297. return attn_output, self_attn_weights, cross_attn_weights
  298. def sliding_window_mask_function(sliding_window: int, is_causal=True) -> Callable:
  299. """
  300. This creates uni/bidirectional attention mask with sliding window.
  301. """
  302. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  303. if is_causal:
  304. left_window_size, right_window_size = sliding_window, 0
  305. else:
  306. left_window_size, right_window_size = ((sliding_window + 1) // 2, (sliding_window) // 2 + 1)
  307. dist = q_idx - kv_idx
  308. left_mask = (dist >= 0) & (dist < left_window_size)
  309. right_mask = (dist < 0) & (-dist < right_window_size)
  310. return left_mask | right_mask
  311. return inner_mask
# Identical to the t5gemma model's T5GemmaEncoderLayer; subclassed only to provide the
# T5Gemma2-prefixed name.
class T5Gemma2EncoderLayer(T5GemmaEncoderLayer):
    pass
  314. class T5Gemma2DecoderLayer(T5GemmaEncoderLayer):
  315. """Decoder sub-layer: merged attention instead of vanilla self-attention."""
  316. def __init__(self, config, layer_idx: int):
  317. super().__init__(config, layer_idx)
  318. # replace vanilla self-attention with merged attention to support joint cross-attention.
  319. self.self_attn = T5Gemma2MergedAttention(
  320. config=config,
  321. layer_idx=layer_idx,
  322. )
  323. def forward(
  324. self,
  325. hidden_states: torch.Tensor,
  326. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  327. merged_attention_mask: torch.Tensor | None = None,
  328. position_ids: torch.LongTensor | None = None,
  329. past_key_values: EncoderDecoderCache | None = None,
  330. use_cache: bool | None = False,
  331. encoder_hidden_states: torch.Tensor | None = None,
  332. **kwargs,
  333. ) -> torch.FloatTensor:
  334. residual = hidden_states
  335. hidden_states = self.pre_self_attn_layernorm(hidden_states)
  336. hidden_states, _, _ = self.self_attn(
  337. hidden_states=hidden_states,
  338. position_embeddings=position_embeddings,
  339. merged_attention_mask=merged_attention_mask,
  340. position_ids=position_ids,
  341. past_key_values=past_key_values,
  342. use_cache=use_cache,
  343. encoder_hidden_states=encoder_hidden_states,
  344. **kwargs,
  345. )
  346. hidden_states = self.post_self_attn_layernorm(hidden_states)
  347. hidden_states = residual + self.dropout(hidden_states)
  348. residual = hidden_states
  349. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  350. hidden_states = self.mlp(hidden_states)
  351. hidden_states = self.post_feedforward_layernorm(hidden_states)
  352. hidden_states = residual + self.dropout(hidden_states)
  353. return hidden_states
# Identical to the t5gemma model's T5GemmaLMHead; subclassed only to provide the
# T5Gemma2-prefixed name.
class T5Gemma2LMHead(T5GemmaLMHead):
    pass
# Identical to the t5gemma model's T5GemmaClassificationHead; subclassed only to provide the
# T5Gemma2-prefixed name (its `out_proj` is re-initialized in T5Gemma2PreTrainedModel._init_weights).
class T5Gemma2ClassificationHead(T5GemmaClassificationHead):
    pass
class T5Gemma2MultiModalProjector(Gemma3MultiModalProjector):
    # Same behavior as Gemma3MultiModalProjector; `__init__` is re-declared only to type
    # `config` as T5Gemma2EncoderConfig.
    def __init__(self, config: T5Gemma2EncoderConfig):
        super().__init__(config)
  361. class T5Gemma2TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
  362. """T5Gemma2 Embedding: override to add eoi token embedding separately."""
  363. def __init__(
  364. self,
  365. num_embeddings: int,
  366. embedding_dim: int,
  367. padding_idx: int,
  368. embed_scale: float = 1.0,
  369. eoi_token_index: int = 256_000,
  370. ):
  371. super().__init__(num_embeddings, embedding_dim, padding_idx, embed_scale)
  372. self.eoi_token_index = eoi_token_index
  373. self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim))
  374. def forward(self, input_ids: torch.Tensor):
  375. input_embeddings = super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  376. input_embeddings[input_ids == self.eoi_token_index] = self.eoi_embedding.to(input_embeddings.dtype)
  377. return input_embeddings
  378. @auto_docstring
  379. class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel):
  380. config: T5Gemma2Config
  381. base_model_prefix = "model"
  382. supports_gradient_checkpointing = True
  383. # Mask creation is incompatible
  384. # FA due to non-default creation / SWA
  385. _supports_flash_attn = False
  386. # Flex due to custom masks not compatible to be merged after creation
  387. _supports_flex_attn = False
  388. _no_split_modules = [
  389. "T5Gemma2EncoderLayer",
  390. "T5Gemma2DecoderLayer",
  391. "SiglipVisionEmbeddings",
  392. "SiglipEncoderLayer",
  393. "SiglipMultiheadAttentionPoolingHead",
  394. ]
  395. _can_record_outputs = {
  396. "hidden_states": [T5Gemma2EncoderLayer, T5Gemma2DecoderLayer],
  397. "attentions": [
  398. OutputRecorder(T5Gemma2SelfAttention, index=1, layer_name="self_attn"),
  399. OutputRecorder(T5Gemma2MergedAttention, index=1, layer_name="self_attn"),
  400. OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"),
  401. ],
  402. }
  403. def _init_weights(self, module):
  404. PreTrainedModel._init_weights(self, module)
  405. if isinstance(module, T5Gemma2MultiModalProjector):
  406. init.zeros_(module.mm_input_projection_weight)
  407. elif isinstance(module, T5Gemma2TextScaledWordEmbedding):
  408. init.zeros_(module.eoi_embedding)
  409. init.constant_(module.embed_scale, module.scalar_embed_scale)
  410. elif isinstance(module, T5Gemma2ClassificationHead):
  411. scale = module.out_proj.weight.shape[0] ** -0.5
  412. init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale)
  413. if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
  414. init.zeros_(module.out_proj.bias)
  415. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  416. elif "RMSNorm" in module.__class__.__name__:
  417. init.zeros_(module.weight)
  418. elif isinstance(module, T5Gemma2RotaryEmbedding):
  419. for layer_type in module.layer_types:
  420. rope_init_fn = module.compute_default_rope_parameters
  421. if module.rope_type[layer_type] != "default":
  422. rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
  423. curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
  424. init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
  425. init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
  426. def prepare_decoder_input_ids_from_labels(self, input_ids):
  427. """
  428. Shifts input_ids to the right, prepends the decoder_start_token_id, and handles
  429. pad_token_id replacement for labels that were -100.
  430. This is a common preparation step for decoder inputs in sequence-to-sequence models.
  431. """
  432. decoder_config = self.config.decoder
  433. decoder_start_token_id = decoder_config.bos_token_id
  434. pad_token_id = decoder_config.pad_token_id
  435. if decoder_start_token_id is None:
  436. raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ")
  437. # shift inputs to the right
  438. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  439. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  440. shifted_input_ids[..., 0] = decoder_start_token_id
  441. if pad_token_id is None:
  442. raise ValueError("self.model.config.decoder.pad_token_id has to be defined.")
  443. # Is this T5 specific?
  444. # replace possible -100 values in labels by `pad_token_id`
  445. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  446. return shifted_input_ids
  447. class T5Gemma2TextEncoder(T5Gemma2PreTrainedModel):
  448. config: T5Gemma2TextConfig
  449. _can_record_outputs = {
  450. "attentions": T5Gemma2SelfAttention,
  451. "hidden_states": T5Gemma2EncoderLayer,
  452. }
  453. def __init__(
  454. self,
  455. config: T5Gemma2TextConfig,
  456. eoi_token_index: int = 256_000,
  457. ):
  458. super().__init__(config)
  459. self.padding_idx = config.pad_token_id
  460. self.vocab_size = config.vocab_size
  461. self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
  462. config.vocab_size,
  463. config.hidden_size,
  464. self.padding_idx,
  465. embed_scale=config.hidden_size**0.5,
  466. eoi_token_index=eoi_token_index,
  467. )
  468. self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  469. self.gradient_checkpointing = False
  470. self.layers = nn.ModuleList(
  471. [T5Gemma2EncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  472. )
  473. self.dropout = nn.Dropout(config.dropout_rate)
  474. self.rotary_emb = T5Gemma2RotaryEmbedding(config)
  475. # Initialize weights and apply final processing
  476. self.post_init()
  477. @merge_with_config_defaults
  478. @capture_outputs
  479. @auto_docstring
  480. def forward(
  481. self,
  482. input_ids: torch.LongTensor | None = None,
  483. attention_mask: torch.Tensor | None = None,
  484. position_ids: torch.LongTensor | None = None,
  485. inputs_embeds: torch.FloatTensor | None = None,
  486. # Unused for processor compatibility kept in signature.
  487. token_type_ids: torch.Tensor | None = None,
  488. **kwargs: Unpack[TransformersKwargs],
  489. ) -> BaseModelOutput:
  490. if (input_ids is None) ^ (inputs_embeds is not None):
  491. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  492. # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present
  493. kwargs.pop("past_key_values", None)
  494. if inputs_embeds is None:
  495. inputs_embeds = self.embed_tokens(input_ids)
  496. if position_ids is None:
  497. position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
  498. if not isinstance(self_attn_mask_mapping := attention_mask, dict):
  499. mask_kwargs = {
  500. "config": self.config,
  501. "inputs_embeds": inputs_embeds,
  502. "attention_mask": attention_mask,
  503. }
  504. self_attn_mask_mapping = {
  505. "full_attention": create_bidirectional_mask(**mask_kwargs),
  506. "sliding_attention": create_bidirectional_mask(
  507. **mask_kwargs,
  508. and_mask_function=sliding_window_mask_function(self.config.sliding_window, is_causal=False),
  509. ),
  510. }
  511. # input layer
  512. hidden_states = inputs_embeds
  513. # global and local position embeddings
  514. position_embeddings = {}
  515. for layer_type in self.config.layer_types:
  516. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  517. # dropout
  518. hidden_states = self.dropout(hidden_states)
  519. for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
  520. hidden_states = layer_module(
  521. hidden_states,
  522. position_embeddings[self.config.layer_types[i]],
  523. self_attn_mask_mapping[self.config.layer_types[i]],
  524. position_ids,
  525. **kwargs,
  526. )
  527. hidden_states = self.norm(hidden_states)
  528. hidden_states = self.dropout(hidden_states)
  529. return BaseModelOutput(
  530. last_hidden_state=hidden_states,
  531. )
  532. class T5Gemma2Encoder(T5Gemma2PreTrainedModel):
  533. config: T5Gemma2EncoderConfig
  534. def __init__(
  535. self,
  536. config: T5Gemma2EncoderConfig,
  537. eoi_token_index: int = 256_000,
  538. ):
  539. super().__init__(config)
  540. self.text_model = T5Gemma2TextEncoder._from_config(config.text_config, eoi_token_index=eoi_token_index)
  541. self.vision_tower = AutoModel.from_config(config=config.vision_config)
  542. self.multi_modal_projector = T5Gemma2MultiModalProjector(config)
  543. # Initialize weights and apply final processing
  544. self.post_init()
  545. def get_input_embeddings(self):
  546. return self.text_model.get_input_embeddings()
  547. def set_input_embeddings(self, new_embeddings):
  548. return self.text_model.set_input_embeddings(new_embeddings)
  549. @can_return_tuple
  550. @auto_docstring
  551. def get_image_features(
  552. self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
  553. ) -> tuple | BaseModelOutputWithPooling:
  554. # pixel_values: (batch_size, channels, height, width)
  555. # image_features: Image feature tensor of shape (num_images, image_length, embed_dim).
  556. vision_outputs = self.vision_tower(pixel_values=pixel_values, return_dict=True, **kwargs)
  557. last_hidden_state = vision_outputs.last_hidden_state
  558. image_features = self.multi_modal_projector(last_hidden_state)
  559. vision_outputs.pooler_output = image_features
  560. return vision_outputs
  561. def get_image_placeholder_mask(
  562. self,
  563. input_ids: torch.LongTensor | None,
  564. inputs_embeds: torch.FloatTensor | None,
  565. image_features: torch.FloatTensor,
  566. ):
  567. """
  568. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  569. equal to the length of multimodal features. If the lengths are different, an error is raised.
  570. """
  571. image_token_id = self.config.image_token_id
  572. if input_ids is None:
  573. if inputs_embeds is None:
  574. raise ValueError("Either `input_ids` or `inputs_embeds` has to be provided.")
  575. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  576. torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
  577. )
  578. special_image_mask = special_image_mask.all(-1)
  579. else:
  580. special_image_mask = input_ids == image_token_id
  581. n_image_tokens = special_image_mask.sum()
  582. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  583. n_image_features = image_features.shape[0] * image_features.shape[1]
  584. torch_compilable_check(
  585. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  586. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}",
  587. )
  588. return special_image_mask
  589. @auto_docstring
  590. def forward(
  591. self,
  592. input_ids: torch.LongTensor | None = None,
  593. attention_mask: torch.Tensor | None = None,
  594. position_ids: torch.LongTensor | None = None,
  595. inputs_embeds: torch.FloatTensor | None = None,
  596. pixel_values: torch.FloatTensor | None = None,
  597. # Unused for processor compatibility kept in signature.
  598. token_type_ids: torch.Tensor | None = None,
  599. **kwargs: Unpack[TransformersKwargs],
  600. ) -> BaseModelOutput:
  601. if (input_ids is None) ^ (inputs_embeds is not None):
  602. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  603. if inputs_embeds is None:
  604. inputs_embeds = self.text_model.embed_tokens(input_ids)
  605. if pixel_values is not None:
  606. image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
  607. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  608. image_mask = self.get_image_placeholder_mask(
  609. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  610. )
  611. inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features)
  612. outputs = self.text_model(
  613. inputs_embeds=inputs_embeds,
  614. attention_mask=attention_mask,
  615. position_ids=position_ids,
  616. **kwargs,
  617. )
  618. return outputs
  619. class T5Gemma2Decoder(T5Gemma2PreTrainedModel):
  620. config: T5Gemma2DecoderConfig
  621. _can_record_outputs = {
  622. "attentions": OutputRecorder(T5Gemma2MergedAttention, index=1),
  623. "cross_attentions": OutputRecorder(T5Gemma2MergedAttention, index=2),
  624. "hidden_states": T5Gemma2DecoderLayer,
  625. }
  626. def __init__(self, config: T5Gemma2DecoderConfig, eoi_token_index: int = 256_000):
  627. super().__init__(config)
  628. self.padding_idx = config.pad_token_id
  629. self.vocab_size = config.vocab_size
  630. self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
  631. config.vocab_size,
  632. config.hidden_size,
  633. config.pad_token_id,
  634. embed_scale=config.hidden_size**0.5,
  635. eoi_token_index=eoi_token_index,
  636. )
  637. self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  638. self.gradient_checkpointing = False
  639. self.layers = nn.ModuleList(
  640. [T5Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  641. )
  642. self.dropout = nn.Dropout(config.dropout_rate)
  643. self.rotary_emb = T5Gemma2RotaryEmbedding(config)
  644. self.post_init()
  645. @merge_with_config_defaults
  646. @capture_outputs
  647. @auto_docstring
  648. def forward(
  649. self,
  650. input_ids: torch.LongTensor | None = None,
  651. attention_mask: torch.Tensor | None = None,
  652. position_ids: torch.LongTensor | None = None,
  653. past_key_values: EncoderDecoderCache | None = None,
  654. inputs_embeds: torch.FloatTensor | None = None,
  655. use_cache: bool | None = None,
  656. encoder_hidden_states: torch.Tensor | None = None,
  657. encoder_attention_mask: torch.Tensor | None = None,
  658. **kwargs: Unpack[TransformersKwargs],
  659. ) -> BaseModelOutputWithPastAndCrossAttentions:
  660. if (input_ids is None) ^ (inputs_embeds is not None):
  661. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  662. if encoder_hidden_states is None:
  663. raise ValueError("`encoder_hidden_states` must be given in decoder")
  664. if inputs_embeds is None:
  665. inputs_embeds = self.embed_tokens(input_ids)
  666. if not self.training and use_cache and past_key_values is None:
  667. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())
  668. if position_ids is None:
  669. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  670. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  671. position_ids = position_ids.unsqueeze(0)
  672. if not isinstance(self_attn_mask_mapping := attention_mask, dict):
  673. # this masking function does nothing to masking but forces `allow_is_causal_skip` to be False
  674. # as we always need a mask during decoding for merged attention.
  675. dummy_and_mask_function = lambda *args: torch.tensor(True, dtype=torch.bool) # noqa
  676. mask_kwargs = {
  677. "config": self.config,
  678. "inputs_embeds": inputs_embeds,
  679. "attention_mask": attention_mask,
  680. "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
  681. "position_ids": position_ids,
  682. "and_mask_function": dummy_and_mask_function,
  683. }
  684. self_attn_mask_mapping = {
  685. "full_attention": create_causal_mask(**mask_kwargs),
  686. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  687. }
  688. if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
  689. cross_attn_mask_mapping = {
  690. "full_attention": create_bidirectional_mask(
  691. config=self.config,
  692. inputs_embeds=inputs_embeds,
  693. attention_mask=encoder_attention_mask,
  694. encoder_hidden_states=encoder_hidden_states,
  695. and_mask_function=dummy_and_mask_function,
  696. )
  697. }
  698. merged_attn_mask_mapping = {
  699. "full_attention": torch.cat(
  700. [self_attn_mask_mapping["full_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1
  701. ),
  702. "sliding_attention": torch.cat(
  703. [self_attn_mask_mapping["sliding_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1
  704. ),
  705. }
  706. # input layer
  707. hidden_states = inputs_embeds
  708. # global and local position embeddings
  709. position_embeddings = {}
  710. for layer_type in self.config.layer_types:
  711. position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
  712. # dropout
  713. hidden_states = self.dropout(hidden_states)
  714. for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
  715. hidden_states = layer_module(
  716. hidden_states,
  717. position_embeddings[self.config.layer_types[i]],
  718. merged_attn_mask_mapping[self.config.layer_types[i]],
  719. position_ids,
  720. past_key_values,
  721. use_cache,
  722. encoder_hidden_states,
  723. **kwargs,
  724. )
  725. hidden_states = self.norm(hidden_states)
  726. hidden_states = self.dropout(hidden_states)
  727. return BaseModelOutputWithPastAndCrossAttentions(
  728. last_hidden_state=hidden_states,
  729. past_key_values=past_key_values,
  730. )
  731. @auto_docstring
  732. class T5Gemma2Model(T5Gemma2PreTrainedModel):
  733. _tied_weights_keys = {
  734. "decoder.embed_tokens.weight": "encoder.text_model.embed_tokens.weight",
  735. "decoder.embed_tokens.eoi_embedding": "encoder.text_model.embed_tokens.eoi_embedding",
  736. }
  737. def __init__(self, config: T5Gemma2Config):
  738. super().__init__(config)
  739. # setup encoder and decoder
  740. self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index)
  741. self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index)
  742. self.post_init()
  743. def get_encoder(self):
  744. return self.encoder
  745. def get_decoder(self):
  746. return self.decoder
  747. def get_input_embeddings(self):
  748. return self.encoder.get_input_embeddings()
  749. def set_input_embeddings(self, new_embeddings):
  750. return self.encoder.set_input_embeddings(new_embeddings)
  751. @can_return_tuple
  752. @auto_docstring
  753. def forward(
  754. self,
  755. # encoder inputs
  756. input_ids: torch.LongTensor | None = None,
  757. pixel_values: torch.FloatTensor | None = None,
  758. attention_mask: torch.FloatTensor | None = None,
  759. position_ids: torch.LongTensor | None = None,
  760. # decoder inputs
  761. decoder_input_ids: torch.LongTensor | None = None,
  762. decoder_attention_mask: torch.BoolTensor | None = None,
  763. decoder_position_ids: torch.LongTensor | None = None,
  764. # others (mainly inference or cache related)
  765. encoder_outputs: BaseModelOutput | None = None,
  766. past_key_values: EncoderDecoderCache | None = None,
  767. inputs_embeds: torch.Tensor | None = None,
  768. decoder_inputs_embeds: torch.Tensor | None = None,
  769. use_cache: bool | None = None,
  770. **kwargs: Unpack[TransformersKwargs],
  771. ) -> Seq2SeqModelOutput:
  772. r"""
  773. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  774. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  775. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  776. """
  777. # encoder
  778. if encoder_outputs is None:
  779. encoder_outputs = self.encoder(
  780. input_ids=input_ids,
  781. attention_mask=attention_mask,
  782. position_ids=position_ids,
  783. inputs_embeds=inputs_embeds,
  784. pixel_values=pixel_values,
  785. return_dict=True,
  786. **kwargs,
  787. )
  788. encoder_hidden_states = encoder_outputs.last_hidden_state
  789. # decoder
  790. decoder_outputs = self.decoder(
  791. input_ids=decoder_input_ids,
  792. attention_mask=decoder_attention_mask,
  793. position_ids=decoder_position_ids,
  794. inputs_embeds=decoder_inputs_embeds,
  795. past_key_values=past_key_values,
  796. encoder_hidden_states=encoder_hidden_states,
  797. encoder_attention_mask=attention_mask,
  798. use_cache=use_cache,
  799. return_dict=True,
  800. **kwargs,
  801. )
  802. return Seq2SeqModelOutput(
  803. last_hidden_state=decoder_outputs.last_hidden_state,
  804. past_key_values=decoder_outputs.past_key_values,
  805. decoder_hidden_states=decoder_outputs.hidden_states,
  806. decoder_attentions=decoder_outputs.attentions,
  807. cross_attentions=decoder_outputs.cross_attentions,
  808. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  809. encoder_hidden_states=encoder_outputs.hidden_states,
  810. encoder_attentions=encoder_outputs.attentions,
  811. )
class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin):
    # Encoder-decoder LM: `T5Gemma2Model` followed by `T5Gemma2LMHead` applied to the decoder's last hidden state.
    # The LM head output projection is tied to the encoder text model's token embedding.
    _tied_weights_keys = {
        "lm_head.out_proj.weight": "model.encoder.text_model.embed_tokens.weight",
    }
    _tp_plan = {"lm_head.out_proj": "colwise_gather_output"}
    _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])}

    def __init__(self, config: T5Gemma2Config):
        super().__init__(config)
        self.model = T5Gemma2Model(config)
        self.vocab_size = config.decoder.vocab_size
        self.lm_head = T5Gemma2LMHead(config.decoder.hidden_size, self.vocab_size)
        # Decoder inputs are right-shifted labels (see `forward`), so labels already align with logits
        # and a masked-LM style loss is used.
        self.loss_type = "ForMaskedLM"
        self.post_init()

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.out_proj = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head.out_proj

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        # Delegates to the encoder's `get_image_features`.
        return self.get_encoder().get_image_features(pixel_values, **kwargs)

    @property
    def vision_tower(self):
        # Convenience accessor for the encoder's vision tower.
        return self.get_encoder().vision_tower

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        # encoder inputs
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        # decoder inputs
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        decoder_position_ids: torch.LongTensor | None = None,
        # others (mainly inference or cache related)
        encoder_outputs: BaseModelOutput | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        # Teacher forcing: when only labels are given, derive the decoder inputs from them.
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)

        decoder_outputs: Seq2SeqModelOutput = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = decoder_outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 keeps all, since slice(-0, None) == slice(0, None));
        # a tensor is used directly as the index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # Optional final-logit soft-capping: logits = cap * tanh(logits / cap).
        decoder_config = self.config.decoder
        if decoder_config.final_logit_softcapping is not None:
            logits = logits / decoder_config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * decoder_config.final_logit_softcapping

        loss = None
        if labels is not None:
            # Input has right-shifted so we directly perform masked lm loss
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.decoder_hidden_states,
            decoder_attentions=decoder_outputs.decoder_attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state,
            encoder_hidden_states=decoder_outputs.encoder_hidden_states,
            encoder_attentions=decoder_outputs.encoder_attentions,
        )

    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
        model_kwargs: dict,
        generation_mode: GenerationMode,
        batch_size: int,
        max_cache_length: int,
    ) -> bool:
        """Override cache preparation to support T5Gemma2-specific EncoderDecoder Cache."""
        # NOTE(review): annotated `-> bool` but every path returns None; callers presumably rely only on the
        # side effects on `model_kwargs` / `self._cache` — confirm against the GenerationMixin base signature.

        # Build cache and past_key_values structure first and then override as needed.
        super()._prepare_cache_for_generation(
            generation_config,
            model_kwargs,
            generation_mode,
            batch_size,
            max_cache_length,
        )

        # If use_cache is False, do not prepare the cache.
        if generation_config.use_cache is False:
            return

        # Offloading is enabled when the requested cache implementation name contains "offloaded".
        cache_implementation = generation_config.cache_implementation
        if cache_implementation is None:
            offload_cache = False
        else:
            offload_cache = "offloaded" in generation_config.cache_implementation

        # Main change: use full cache for cross-attention.
        # A deep copy is modified so the model's own decoder config is left untouched.
        cross_attn_config = copy.deepcopy(self.config.get_text_config(decoder=True))

        # cross-attention does not use sliding window
        del cross_attn_config.sliding_window
        del cross_attn_config.layer_types

        cross_attn_cache_kwargs = {
            "config": cross_attn_config,
            "offloading": offload_cache,
        }
        past_key_values = model_kwargs.get("past_key_values")
        if past_key_values is not None:
            if not isinstance(past_key_values, EncoderDecoderCache):
                raise ValueError(
                    "The `past_key_values` in `model_kwargs` must be of type `EncoderDecoderCache` for T5Gemma2 model."
                )

            # Cache already established, no need to re-initialize.
            if len(past_key_values.is_updated) > 0 and past_key_values.is_updated.get(0):
                return

            # Rebuild the cross-attention cache with the same class the base method chose.
            cross_attn_cls = type(past_key_values.cross_attention_cache)
            if cross_attn_cls == StaticCache:
                # A static cross-attention cache is sized to the encoder sequence length
                # (dim 1 of the first element of `encoder_outputs`).
                cross_attn_cache_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
            # Update cross-attention cache only (switch from sliding_window to full).
            past_key_values.cross_attention_cache = cross_attn_cls(**cross_attn_cache_kwargs)
        else:
            # Initialize new cache.
            model_kwargs["past_key_values"] = EncoderDecoderCache(
                DynamicCache(
                    **{
                        "config": self.config.get_text_config(decoder=True),
                        "offloading": offload_cache,
                    }
                ),  # self-attention cache
                DynamicCache(),  # cross-attention cache
            )

        # Keep an already-existing internal cache pointing at the (possibly rebuilt) cache.
        if hasattr(self, "_cache") and self._cache is not None:
            if not isinstance(self._cache, EncoderDecoderCache):
                raise ValueError("The internal cache must be of type `EncoderDecoderCache` for T5Gemma2 model.")

            self._cache = model_kwargs["past_key_values"]
  983. @auto_docstring
  984. class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel):
  985. def __init__(self, config: T5Gemma2Config):
  986. super().__init__(config)
  987. self.num_labels = config.num_labels
  988. self.hidden_size = config.decoder.hidden_size
  989. self.model = T5Gemma2Model(config)
  990. classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
  991. self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout)
  992. self.post_init()
  993. def get_input_embeddings(self):
  994. return self.model.get_input_embeddings()
  995. def set_input_embeddings(self, value):
  996. self.model.set_input_embeddings(value)
  997. @can_return_tuple
  998. @auto_docstring
  999. def forward(
  1000. self,
  1001. input_ids: torch.LongTensor | None = None,
  1002. pixel_values: torch.FloatTensor | None = None,
  1003. attention_mask: torch.Tensor | None = None,
  1004. position_ids: torch.LongTensor | None = None,
  1005. decoder_input_ids: torch.LongTensor | None = None,
  1006. decoder_attention_mask: torch.Tensor | None = None,
  1007. decoder_position_ids: torch.LongTensor | None = None,
  1008. encoder_outputs: BaseModelOutput | None = None,
  1009. inputs_embeds: torch.FloatTensor | None = None,
  1010. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1011. labels: torch.LongTensor | None = None,
  1012. **kwargs: Unpack[TransformersKwargs],
  1013. ) -> SequenceClassifierOutput:
  1014. r"""
  1015. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  1016. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  1017. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  1018. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1019. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1020. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1021. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1022. """
  1023. if inputs_embeds is not None or decoder_inputs_embeds is not None:
  1024. raise NotImplementedError(
  1025. f"Passing input embeddings is currently not supported for {self.__class__.__name__}."
  1026. )
  1027. if input_ids is None:
  1028. raise ValueError("You have to specify input_ids")
  1029. if decoder_input_ids is None:
  1030. decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids)
  1031. outputs: Seq2SeqModelOutput = self.model(
  1032. input_ids,
  1033. pixel_values=pixel_values,
  1034. attention_mask=attention_mask,
  1035. position_ids=position_ids,
  1036. decoder_input_ids=decoder_input_ids,
  1037. decoder_attention_mask=decoder_attention_mask,
  1038. decoder_position_ids=decoder_position_ids,
  1039. encoder_outputs=encoder_outputs,
  1040. inputs_embeds=inputs_embeds,
  1041. decoder_inputs_embeds=decoder_inputs_embeds,
  1042. use_cache=False,
  1043. **kwargs,
  1044. )
  1045. last_hidden_state = outputs.last_hidden_state
  1046. hidden_states = outputs.decoder_hidden_states
  1047. attentions = outputs.decoder_attentions
  1048. logits = self.score(last_hidden_state)
  1049. batch_size = input_ids.shape[0]
  1050. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  1051. non_pad_mask = (decoder_input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  1052. token_indices = torch.arange(decoder_input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  1053. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  1054. last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1)
  1055. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  1056. loss = None
  1057. if labels is not None:
  1058. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1059. return SequenceClassifierOutput(
  1060. loss=loss,
  1061. logits=pooled_logits,
  1062. hidden_states=hidden_states,
  1063. attentions=attentions,
  1064. )
  1065. @auto_docstring
  1066. class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel):
  1067. def __init__(self, config: T5Gemma2Config):
  1068. super().__init__(config)
  1069. self.num_labels = config.num_labels
  1070. self.hidden_size = config.decoder.hidden_size
  1071. self.model = T5Gemma2Model(config)
  1072. classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
  1073. self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout)
  1074. self.post_init()
  1075. def get_input_embeddings(self):
  1076. return self.model.get_input_embeddings()
  1077. def set_input_embeddings(self, value):
  1078. self.model.set_input_embeddings(value)
  1079. @can_return_tuple
  1080. @auto_docstring
  1081. def forward(
  1082. self,
  1083. input_ids: torch.LongTensor | None = None,
  1084. pixel_values: torch.FloatTensor | None = None,
  1085. attention_mask: torch.Tensor | None = None,
  1086. position_ids: torch.LongTensor | None = None,
  1087. decoder_input_ids: torch.LongTensor | None = None,
  1088. decoder_attention_mask: torch.Tensor | None = None,
  1089. decoder_position_ids: torch.LongTensor | None = None,
  1090. encoder_outputs: BaseModelOutput | None = None,
  1091. inputs_embeds: torch.FloatTensor | None = None,
  1092. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1093. labels: torch.LongTensor | None = None,
  1094. **kwargs: Unpack[TransformersKwargs],
  1095. ) -> TokenClassifierOutput:
  1096. r"""
  1097. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  1098. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  1099. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  1100. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1101. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1102. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1103. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1104. """
  1105. if inputs_embeds is not None or decoder_inputs_embeds is not None:
  1106. raise NotImplementedError(
  1107. f"Passing input embeddings is currently not supported for {self.__class__.__name__}."
  1108. )
  1109. if input_ids is None:
  1110. raise ValueError("You have to specify input_ids")
  1111. if decoder_input_ids is None:
  1112. decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids)
  1113. outputs: Seq2SeqModelOutput = self.model(
  1114. input_ids,
  1115. pixel_values=pixel_values,
  1116. attention_mask=attention_mask,
  1117. position_ids=position_ids,
  1118. decoder_input_ids=decoder_input_ids,
  1119. decoder_attention_mask=decoder_attention_mask,
  1120. decoder_position_ids=decoder_position_ids,
  1121. encoder_outputs=encoder_outputs,
  1122. inputs_embeds=inputs_embeds,
  1123. decoder_inputs_embeds=decoder_inputs_embeds,
  1124. use_cache=False,
  1125. **kwargs,
  1126. )
  1127. last_hidden_state = outputs.last_hidden_state
  1128. hidden_states = outputs.decoder_hidden_states
  1129. attentions = outputs.decoder_attentions
  1130. logits = self.score(last_hidden_state)
  1131. loss = None
  1132. if labels is not None:
  1133. loss = self.loss_function(logits, labels, self.config)
  1134. return TokenClassifierOutput(
  1135. loss=loss,
  1136. logits=logits,
  1137. hidden_states=hidden_states,
  1138. attentions=attentions,
  1139. )
# Public names exported by `from <this module> import *`.
__all__ = [
    "T5Gemma2Config",
    "T5Gemma2TextConfig",
    "T5Gemma2EncoderConfig",
    "T5Gemma2DecoderConfig",
    "T5Gemma2ForConditionalGeneration",
    "T5Gemma2Model",
    "T5Gemma2Encoder",
    "T5Gemma2PreTrainedModel",
    "T5Gemma2ForSequenceClassification",
    "T5Gemma2ForTokenClassification",
]