modular_t5gemma.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151
  1. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from collections.abc import Callable
  16. from typing import Any
  17. import torch
  18. import torch.nn as nn
  19. from huggingface_hub.dataclasses import strict
  20. from ... import initialization as init
  21. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  22. from ...configuration_utils import PreTrainedConfig
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import (
  25. create_bidirectional_mask,
  26. create_bidirectional_sliding_window_mask,
  27. create_causal_mask,
  28. create_sliding_window_causal_mask,
  29. )
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import (
  33. BaseModelOutput,
  34. BaseModelOutputWithPastAndCrossAttentions,
  35. Seq2SeqLMOutput,
  36. Seq2SeqModelOutput,
  37. SequenceClassifierOutput,
  38. TokenClassifierOutput,
  39. )
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import (
  43. TransformersKwargs,
  44. auto_docstring,
  45. can_return_tuple,
  46. logging,
  47. )
  48. from ...utils.generic import merge_with_config_defaults
  49. from ...utils.output_capturing import OutputRecorder, capture_outputs
  50. from ..gemma2.configuration_gemma2 import Gemma2Config
  51. from ..gemma2.modeling_gemma2 import (
  52. Gemma2Attention,
  53. Gemma2MLP,
  54. Gemma2PreTrainedModel,
  55. Gemma2RMSNorm,
  56. Gemma2RotaryEmbedding,
  57. eager_attention_forward,
  58. )
  # Module-level logger for this file, obtained from the library's `logging` utilities.
  logger = logging.get_logger(__name__)
  @auto_docstring(checkpoint="google/t5_gemma_module-7b")
  @strict
  class T5GemmaModuleConfig(Gemma2Config):
      r"""
      query_pre_attn_scalar (`float`, *optional*, defaults to 256):
          scaling factor used on the attention scores
      final_logit_softcapping (`float`, *optional*, defaults to 30.0):
          scaling factor when applying tanh softcapping on the logits.
      attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
          scaling factor when applying tanh softcapping on the attention scores.

      ```python
      >>> from transformers import T5GemmaModuleModel, T5GemmaModuleConfig

      >>> # Initializing a T5GemmaModule t5_gemma_module-7b style configuration
      >>> configuration = T5GemmaModuleConfig()

      >>> # Initializing a model from the t5_gemma_module-7b style configuration
      >>> model = T5GemmaModuleModel(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      # Per-stack (encoder or decoder) configuration: the Gemma2 fields plus the two overrides below.

      # Selects causal (True) or non-causal (False) self-attention, because
      # T5GemmaSelfAttention sets `is_causal = config.is_decoder`.
      is_decoder: bool = False
      # NOTE(review): an `AttributeError()` instance as the value appears to be the modular-converter
      # convention for removing the `use_bidirectional_attention` attribute inherited from
      # Gemma2Config (causality is driven by `is_decoder` instead) — confirm.
      use_bidirectional_attention = AttributeError()
  @auto_docstring(checkpoint="google/t5_gemma_module-7b")
  @strict
  class T5GemmaConfig(PreTrainedConfig):
      r"""
      encoder (`Union[T5GemmaModuleConfig, dict]`, *optional*):
          Configuration for the encoder.
      decoder (`Union[T5GemmaModuleConfig, dict]`, *optional*):
          Configuration for the decoder.

      Example:

      ```python
      >>> from transformers import T5GemmaConfig, T5GemmaModel

      >>> t5gemma_config = T5GemmaConfig.from_pretrained("google/t5gemma-2b-2b-prefixlm-it")
      >>> model = T5GemmaModel(t5gemma_config)
      ```"""

      model_type = "t5gemma"
      keys_to_ignore_at_inference = ["past_key_values"]
      sub_configs = {"encoder": T5GemmaModuleConfig, "decoder": T5GemmaModuleConfig}

      encoder: T5GemmaModuleConfig | dict[Any, Any] | None = None
      decoder: T5GemmaModuleConfig | dict[Any, Any] | None = None
      is_encoder_decoder: bool = True
      dropout_rate: int | float = 0.0
      classifier_dropout_rate: int | float = 0.0
      attention_dropout: float | int = 0.0
      tie_word_embeddings: bool = True
      vocab_size: int = 256000

      def __post_init__(self, **kwargs):
          """Materialize the encoder/decoder sub-configs and copy shared settings into them.

          - ``encoder``/``decoder`` given as dicts become ``T5GemmaModuleConfig`` instances.
            ``None`` becomes a default ``T5GemmaModuleConfig``.
          - The top-level ``dropout_rate`` and ``attention_dropout`` are copied into both sides.
          - The decoder is flagged with ``is_decoder=True`` and ``use_cache=True``, and its
            ``cross_attention_hidden_size`` is set to the encoder's ``hidden_size``.
          - ``initializer_range`` and the bos/pad/eos token ids default to the decoder's
            values unless they are passed explicitly in ``kwargs``.
          """
          import copy

          if isinstance(self.encoder, dict):
              self.encoder = T5GemmaModuleConfig(**self.encoder)
          elif self.encoder is None:
              self.encoder = T5GemmaModuleConfig()

          # NOTE(review): an omitted decoder falls back to a *default* T5GemmaModuleConfig,
          # not to a copy of the encoder config — confirm this is intended.
          if isinstance(self.decoder, dict):
              self.decoder = T5GemmaModuleConfig(**self.decoder)
          elif self.decoder is None:
              self.decoder = T5GemmaModuleConfig()

          # The same config object may be passed for both sides. The per-side settings below
          # (`is_decoder`, `use_cache`, `cross_attention_hidden_size`) would then be written to one
          # shared object, and the encoder would end up with `is_decoder=True`, i.e. causal
          # self-attention. Give the decoder its own copy so the two stacks stay independent.
          if self.decoder is self.encoder:
              self.decoder = copy.deepcopy(self.encoder)

          self.encoder.is_decoder = False
          self.encoder.dropout_rate = self.dropout_rate
          self.encoder.attention_dropout = self.attention_dropout

          self.decoder.is_decoder = True
          self.decoder.use_cache = True
          self.decoder.dropout_rate = self.dropout_rate
          self.decoder.attention_dropout = self.attention_dropout
          # Cross-attention key/value projections read encoder states of this width.
          self.decoder.cross_attention_hidden_size = self.encoder.hidden_size

          self.initializer_range = kwargs.pop("initializer_range", self.decoder.initializer_range)

          # Special token ids default to the decoder's unless explicitly provided.
          for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id"]:
              if special_token_key not in kwargs:
                  kwargs[special_token_key] = getattr(self.decoder, special_token_key)

          super().__post_init__(**kwargs)
  class T5GemmaRMSNorm(Gemma2RMSNorm):
      """RMS normalization layer for T5Gemma; reuses the Gemma2 implementation unchanged."""
  class T5GemmaMLP(Gemma2MLP):
      """Gated feed-forward block of Gemma2, with dropout applied to the gated activations."""

      def __init__(self, config):
          super().__init__(config)
          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(self, x):
          """Compute ``down_proj(dropout(act_fn(gate_proj(x)) * up_proj(x)))``."""
          gate_activation = self.act_fn(self.gate_proj(x))
          up_projection = self.up_proj(x)
          gated = self.dropout(gate_activation * up_projection)
          return self.down_proj(gated)
  class T5GemmaRotaryEmbedding(Gemma2RotaryEmbedding):
      """Rotary position embedding for T5Gemma; identical to the Gemma2 implementation."""
  class T5GemmaSelfAttention(Gemma2Attention):
      """Gemma2 self-attention whose causality follows ``config.is_decoder``."""

      def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
          super().__init__(config, layer_idx)
          # Flash attention reads `is_causal`: decoder self-attention is causal,
          # encoder self-attention is not.
          causal = config.is_decoder
          self.is_causal = causal
  class T5GemmaCrossAttention(Gemma2Attention):
      """Decoder-to-encoder attention.

      Queries are projected from the decoder hidden states. Keys and values are projected from
      the encoder hidden states, whose width is ``config.cross_attention_hidden_size``.
      Attention is non-causal and is always run with ``sliding_window=None``.
      """

      def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
          super().__init__(config, layer_idx)
          # Sliding-window settings inherited from Gemma2Attention do not apply here:
          # `forward` always attends over the whole encoder sequence.
          del self.sliding_window
          del self.layer_type
          self.is_causal = False

          if config.cross_attention_hidden_size is None:
              raise ValueError("Cross-attention needs cross_attention_hidden_size to be specified.")
          # Re-create the key/value projections so their input width matches the encoder
          # states, which may differ from the decoder's `hidden_size`.
          self.k_proj = nn.Linear(
              config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.v_proj = nn.Linear(
              config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None,
          encoder_hidden_states: torch.Tensor | None,
          past_key_values: EncoderDecoderCache | None = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor | None]:
          """Attend from decoder states to encoder states.

          Args:
              hidden_states: Decoder hidden states; they provide the queries.
              attention_mask: Mask over encoder positions, forwarded unchanged to the attention backend.
              encoder_hidden_states: Encoder output; it provides the keys and values. Required.
              past_key_values: Optional encoder-decoder cache.
                  - Its `cross_attention_cache` stores the projected encoder keys/values.
                  - `is_updated[layer_idx]` records whether this layer has already filled it.
              **kwargs: Forwarded to the attention backend.

          Returns:
              A pair `(attn_output, attn_weights)`:
              - `attn_output` is the attention output after `o_proj`.
              - `attn_weights` is whatever the attention backend returned for the weights.

          Raises:
              ValueError: If `encoder_hidden_states` is `None`.
          """
          if encoder_hidden_states is None:
              raise ValueError("Encoder hidden state is required for cross attention.")

          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)
          query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

          if past_key_values is not None:
              is_updated = past_key_values.is_updated.get(self.layer_idx)
              curr_past_key_values = past_key_values.cross_attention_cache
          if past_key_values is None or not is_updated:
              # Project the encoder states. With a cache, store the projections and mark this
              # layer as filled so later calls can reuse them.
              encoder_input_shape = encoder_hidden_states.shape[:-1]
              encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim)
              key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
              value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
              if past_key_values is not None:
                  key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                  past_key_values.is_updated[self.layer_idx] = True
          else:
              # Reuse the projections cached by an earlier call for this layer.
              key_states = curr_past_key_values.layers[self.layer_idx].keys
              value_states = curr_past_key_values.layers[self.layer_idx].values

          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )
          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=self.attention_dropout if self.training else 0.0,
              scaling=self.scaling,
              sliding_window=None,
              softcap=self.attn_logit_softcapping,
              **kwargs,
          )
          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  class T5GemmaEncoderLayer(GradientCheckpointingLayer):
      """Encoder sub-layer: self-attention followed by an MLP.

      Each sub-block computes ``x + dropout(post_norm(sublayer(pre_norm(x))))``. No KV cache is
      used: ``past_key_values=None`` is always passed to self-attention.
      """

      def __init__(self, config, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.config = config
          self.layer_idx = layer_idx
          # Attention type of this layer, taken from `config.layer_types`.
          self.attention_type = config.layer_types[layer_idx]

          self.self_attn = T5GemmaSelfAttention(
              config=config,
              layer_idx=layer_idx,
          )
          self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

          self.mlp = T5GemmaMLP(config)
          self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          **kwargs,
      ) -> torch.FloatTensor:
          """Run one encoder layer.

          Args:
              hidden_states: Input hidden states.
              position_embeddings: Rotary embeddings passed through to self-attention.
              attention_mask: Self-attention mask for this layer's attention type.
              position_ids: Position indices passed through to self-attention.
              **kwargs: Forwarded to self-attention.

          Returns:
              The updated hidden states as a single tensor (not a tuple).
          """
          residual = hidden_states
          hidden_states = self.pre_self_attn_layernorm(hidden_states)
          hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
              position_embeddings=position_embeddings,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=None,
              **kwargs,
          )
          hidden_states = self.post_self_attn_layernorm(hidden_states)
          hidden_states = residual + self.dropout(hidden_states)

          residual = hidden_states
          hidden_states = self.pre_feedforward_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = self.post_feedforward_layernorm(hidden_states)
          hidden_states = residual + self.dropout(hidden_states)
          return hidden_states
  class T5GemmaDecoderLayer(GradientCheckpointingLayer):
      """Decoder sub-layer: self-attention, then cross-attention over the encoder output, then an MLP.

      Each of the three sub-blocks computes ``x + dropout(post_norm(sublayer(pre_norm(x))))``.
      """

      def __init__(self, config, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.config = config
          self.layer_idx = layer_idx
          self.attention_type = config.layer_types[layer_idx]

          # Submodules are registered in a fixed order (this determines parameter order).
          self.self_attn = T5GemmaSelfAttention(config=config, layer_idx=layer_idx)
          self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

          self.mlp = T5GemmaMLP(config)
          self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

          self.dropout = nn.Dropout(config.dropout_rate)

          self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx)
          self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: EncoderDecoderCache | None = None,
          use_cache: bool | None = False,
          encoder_hidden_states: torch.Tensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          **kwargs,
      ) -> torch.FloatTensor:
          """Run one decoder layer.

          - Self-attention receives only ``past_key_values.self_attention_cache``.
          - Cross-attention receives the whole ``past_key_values`` and the encoder mask.

          Returns the updated hidden states as a single tensor.
          """
          self_attn_cache = None if past_key_values is None else past_key_values.self_attention_cache

          # Self-attention over the decoder sequence.
          self_attn_out, _ = self.self_attn(
              hidden_states=self.pre_self_attn_layernorm(hidden_states),
              position_embeddings=position_embeddings,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=self_attn_cache,
              use_cache=use_cache,
              **kwargs,
          )
          hidden_states = hidden_states + self.dropout(self.post_self_attn_layernorm(self_attn_out))

          # Cross-attention: decoder queries attend over the encoder output.
          cross_attn_out, _ = self.cross_attn(
              hidden_states=self.pre_cross_attn_layernorm(hidden_states),
              encoder_hidden_states=encoder_hidden_states,
              attention_mask=encoder_attention_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              **kwargs,
          )
          hidden_states = hidden_states + self.dropout(self.post_cross_attn_layernorm(cross_attn_out))

          # Feed-forward sub-block.
          ffn_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
          return hidden_states + self.dropout(self.post_feedforward_layernorm(ffn_out))
  class T5GemmaClassificationHead(nn.Module):
      """Sentence-level classification head: dropout, then a linear map to ``num_labels`` logits."""

      def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0):
          super().__init__()
          self.dropout = nn.Dropout(p=classifier_dropout_rate)
          self.out_proj = nn.Linear(hidden_size, num_labels)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return classification logits for ``hidden_states``."""
          regularized = self.dropout(hidden_states)
          return self.out_proj(regularized)
  class T5GemmaLMHead(nn.Module):
      """Language-modeling head: one linear projection from hidden states to vocabulary logits."""

      def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False):
          super().__init__()
          self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Project ``hidden_states`` to logits over the vocabulary."""
          return self.out_proj(hidden_states)
  @auto_docstring
  class T5GemmaPreTrainedModel(Gemma2PreTrainedModel):
      # Shared base for the T5Gemma encoder, decoder and task models: weight init and
      # decoder-input preparation.
      config: T5GemmaConfig
      base_model_prefix = "model"
      supports_gradient_checkpointing = True
      _no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"]
      _can_record_outputs = {
          "hidden_states": T5GemmaDecoderLayer,
          "attentions": [
              OutputRecorder(T5GemmaSelfAttention, index=1, layer_name="self_attn"),
              OutputRecorder(T5GemmaSelfAttention, index=1, layer_name="cross_attn"),
              OutputRecorder(T5GemmaCrossAttention, index=1, layer_name="cross_attn"),
          ],
      }

      @torch.no_grad()
      def _init_weights(self, module):
          """Initialize ``module``'s parameters.

          The generic ``PreTrainedModel`` initialization runs first. Then:
          - classification heads get a normal init scaled by ``out_features ** -0.5`` and a zero bias;
          - LM heads get the same scaled init unless word embeddings are tied;
          - RMSNorm weights are zeroed.
          """
          # TODO: support initialization for encoders and decoders separately(?)
          PreTrainedModel._init_weights(self, module)
          std = self.config.initializer_range

          if isinstance(module, T5GemmaClassificationHead):
              head_weight = module.out_proj.weight
              init.normal_(head_weight, mean=0.0, std=std * head_weight.shape[0] ** -0.5)
              head_bias = getattr(module.out_proj, "bias", None)
              if head_bias is not None:
                  init.zeros_(head_bias)
          elif isinstance(module, T5GemmaLMHead):
              if not self.config.tie_word_embeddings:
                  head_weight = module.out_proj.weight
                  init.normal_(head_weight, mean=0.0, std=std * head_weight.shape[0] ** -0.5)
          elif "RMSNorm" in type(module).__name__:
              # This RMSNorm scales by (1 + weight), so a zero weight starts it at a scale of 1.
              init.zeros_(module.weight)

      def _shift_right(self, input_ids):
          """Build decoder inputs from ``input_ids``.

          Shifts the ids one step to the right and puts the decoder ``bos_token_id`` in the
          first position. Any ``-100`` values (the ignore index used in labels) are replaced
          by the decoder ``pad_token_id``.

          Raises:
              ValueError: If the decoder ``bos_token_id`` or ``pad_token_id`` is unset.
          """
          decoder_config = self.config.decoder
          start_token_id = decoder_config.bos_token_id
          pad_token_id = decoder_config.pad_token_id
          if start_token_id is None:
              raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ")

          shifted = input_ids.new_zeros(input_ids.shape)
          shifted[..., 1:] = input_ids[..., :-1].clone()
          shifted[..., 0] = start_token_id

          if pad_token_id is None:
              raise ValueError("self.model.config.decoder.pad_token_id has to be defined.")
          # -100 is not a valid token id, so map those positions to padding.
          shifted.masked_fill_(shifted.eq(-100), pad_token_id)
          return shifted
  382. def make_default_2d_attention_mask(
  383. token_ids: torch.LongTensor | None,
  384. hidden_states: torch.Tensor,
  385. pad_token_id: int | None,
  386. ) -> torch.Tensor:
  387. """Construct the default attention mask."""
  388. if token_ids is not None:
  389. if pad_token_id is None:
  390. raise ValueError("`pad_token_id` is required for padding information.")
  391. attention_mask = (token_ids != pad_token_id).to(hidden_states.device, torch.long)
  392. else:
  393. attention_mask = torch.ones(
  394. (hidden_states.shape[0], hidden_states.shape[1]), device=hidden_states.device, dtype=torch.long
  395. )
  396. return attention_mask
  class T5GemmaEncoder(T5GemmaPreTrainedModel):
      """Encoder stack: token embeddings, T5GemmaEncoderLayer blocks with bidirectional masks, final norm."""

      _can_record_outputs = {
          "attentions": T5GemmaSelfAttention,
          "hidden_states": T5GemmaEncoderLayer,
      }

      def __init__(self, config):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size

          # Submodules are created in a fixed order (this determines parameter order).
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
          self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.gradient_checkpointing = False
          self.layers = nn.ModuleList(
              [T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
          )
          self.dropout = nn.Dropout(config.dropout_rate)
          self.rotary_emb = T5GemmaRotaryEmbedding(config=config)

          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutput:
          """Encode ``input_ids`` or ``inputs_embeds`` (exactly one must be given).

          ``attention_mask`` may be a 2D padding mask, or a dict that maps each layer type to
          a ready-made mask (used as-is). When it is omitted, a padding mask is derived from
          ``input_ids``, or an all-ones mask is used when only ``inputs_embeds`` is given.
          """
          if (input_ids is None) == (inputs_embeds is None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          # Every layer receives `past_key_values=None` explicitly, so drop any value arriving via kwargs.
          kwargs.pop("past_key_values", None)

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          if position_ids is None:
              position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

          if attention_mask is None:
              attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)

          self_attn_mask_mapping = attention_mask
          if not isinstance(self_attn_mask_mapping, dict):
              mask_kwargs = {
                  "config": self.config,
                  "inputs_embeds": inputs_embeds,
                  "attention_mask": attention_mask,
              }
              self_attn_mask_mapping = {
                  "full_attention": create_bidirectional_mask(**mask_kwargs),
                  "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
              }

          # Scale the embeddings by sqrt(hidden_size), in the embeddings' dtype.
          normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)
          hidden_states = self.dropout(inputs_embeds * normalizer)
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          for layer_idx, encoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
              hidden_states = encoder_layer(
                  hidden_states,
                  position_embeddings,
                  self_attn_mask_mapping[self.config.layer_types[layer_idx]],
                  position_ids,
                  **kwargs,
              )

          hidden_states = self.dropout(self.norm(hidden_states))
          return BaseModelOutput(last_hidden_state=hidden_states)
  class T5GemmaDecoder(T5GemmaPreTrainedModel):
      """Decoder stack: token embeddings, causal T5GemmaDecoderLayer blocks with cross-attention, final norm."""

      _can_record_outputs = {
          "attentions": OutputRecorder(T5GemmaSelfAttention, index=1),
          "cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1),
          "hidden_states": T5GemmaDecoderLayer,
      }

      def __init__(self, config):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size

          # Submodules are created in a fixed order (this determines parameter order).
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
          self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.gradient_checkpointing = False
          self.layers = nn.ModuleList(
              [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
          )
          self.dropout = nn.Dropout(config.dropout_rate)
          self.rotary_emb = T5GemmaRotaryEmbedding(config=config)

          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: EncoderDecoderCache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          encoder_hidden_states: torch.Tensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
          """Decode ``input_ids`` or ``inputs_embeds`` (exactly one must be given) against ``encoder_hidden_states``.

          - A fresh ``EncoderDecoderCache`` is created when ``use_cache`` is set, no cache was
            passed, and the module is not training.
          - ``attention_mask`` and ``encoder_attention_mask`` may each be a dict of ready-made
            masks; such a dict is used as-is.
          """
          if (input_ids is None) == (inputs_embeds is None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
          if encoder_hidden_states is None:
              raise ValueError("`encoder_hidden_states` must be given in decoder")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          if past_key_values is None and use_cache and not self.training:
              # The cross-attention cache is built without the config so it never gets
              # sliding-window (SWA) behaviour: cross attention uses full attention.
              past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())

          self_attn_cache = None if past_key_values is None else past_key_values.self_attention_cache

          if position_ids is None:
              past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
              token_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
              position_ids = token_positions.unsqueeze(0)

          if attention_mask is None and past_key_values is None:
              attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)

          self_attn_mask_mapping = attention_mask
          if not isinstance(self_attn_mask_mapping, dict):
              mask_kwargs = {
                  "config": self.config,
                  "inputs_embeds": inputs_embeds,
                  "attention_mask": attention_mask,
                  "past_key_values": self_attn_cache,
                  "position_ids": position_ids,
              }
              self_attn_mask_mapping = {
                  "full_attention": create_causal_mask(**mask_kwargs),
                  "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
              }

          cross_attn_mask_mapping = encoder_attention_mask
          if not isinstance(cross_attn_mask_mapping, dict):
              cross_attn_mask_mapping = {
                  "full_attention": create_bidirectional_mask(
                      config=self.config,
                      inputs_embeds=inputs_embeds,
                      attention_mask=encoder_attention_mask,
                      encoder_hidden_states=encoder_hidden_states,
                  )
              }

          # Scale the embeddings by sqrt(hidden_size), in the embeddings' dtype.
          normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)
          hidden_states = self.dropout(inputs_embeds * normalizer)
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
              hidden_states = decoder_layer(
                  hidden_states,
                  position_embeddings,
                  self_attn_mask_mapping[self.config.layer_types[layer_idx]],
                  position_ids,
                  past_key_values,
                  use_cache,
                  encoder_hidden_states,
                  cross_attn_mask_mapping["full_attention"],
                  **kwargs,
              )

          hidden_states = self.dropout(self.norm(hidden_states))
          return BaseModelOutputWithPastAndCrossAttentions(
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
          )
  @auto_docstring
  class T5GemmaModel(T5GemmaPreTrainedModel):
      # Encoder-decoder backbone. Builds a T5GemmaEncoder from `config.encoder` and a
      # T5GemmaDecoder from `config.decoder`.

      def __init__(self, config: T5GemmaConfig):
          super().__init__(config)

          if not config.is_encoder_decoder:
              raise ValueError("T5GemmaModel only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.")

          self.encoder = T5GemmaEncoder(config.encoder)
          self.decoder = T5GemmaDecoder(config.decoder)

          self.post_init()

      def get_input_embeddings(self):
          # Input embeddings are the encoder's embedding table.
          return self.encoder.get_input_embeddings()

      def set_input_embeddings(self, new_embeddings):
          return self.encoder.set_input_embeddings(new_embeddings)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          position_ids: torch.LongTensor | None = None,
          decoder_input_ids: torch.LongTensor | None = None,
          decoder_attention_mask: torch.BoolTensor | None = None,
          decoder_position_ids: torch.LongTensor | None = None,
          encoder_outputs: BaseModelOutput | None = None,
          past_key_values: EncoderDecoderCache | None = None,
          inputs_embeds: torch.Tensor | None = None,
          decoder_inputs_embeds: torch.Tensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Seq2SeqModelOutput:
          r"""
          decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
              Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
              config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
          """
          # Encode, unless precomputed encoder outputs were supplied.
          if encoder_outputs is None:
              encoder_outputs = self.encoder(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  position_ids=position_ids,
                  inputs_embeds=inputs_embeds,
                  **kwargs,
              )
          elif isinstance(encoder_outputs, (tuple, list)):
              # Tuple-form encoder outputs (e.g. produced with `return_dict=False`) have no
              # `.last_hidden_state`. Wrap them; only the first entry is known to be the last
              # hidden state, so optional hidden states / attentions are not guessed at.
              encoder_outputs = BaseModelOutput(last_hidden_state=encoder_outputs[0])

          encoder_hidden_states = encoder_outputs.last_hidden_state

          # Decode. The encoder's `attention_mask` doubles as the cross-attention mask.
          decoder_outputs = self.decoder(
              input_ids=decoder_input_ids,
              attention_mask=decoder_attention_mask,
              position_ids=decoder_position_ids,
              inputs_embeds=decoder_inputs_embeds,
              past_key_values=past_key_values,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=attention_mask,
              use_cache=use_cache,
              **kwargs,
          )

          return Seq2SeqModelOutput(
              last_hidden_state=decoder_outputs.last_hidden_state,
              past_key_values=decoder_outputs.past_key_values,
              # Without `output_hidden_states` in kwargs, only the final decoder state is exposed here.
              decoder_hidden_states=decoder_outputs.hidden_states
              if kwargs.get("output_hidden_states", False)
              else (decoder_outputs.last_hidden_state,),
              decoder_attentions=decoder_outputs.attentions,
              cross_attentions=decoder_outputs.cross_attentions,
              encoder_last_hidden_state=encoder_outputs.last_hidden_state,
              encoder_hidden_states=encoder_outputs.hidden_states,
              encoder_attentions=encoder_outputs.attentions,
          )
  626. @auto_docstring
  627. class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
  628. def __init__(self, config: T5GemmaConfig):
  629. super().__init__(config)
  630. if config.is_encoder_decoder:
  631. raise ValueError("T5GemmaEncoderModel only supports encoder-only model. Use `T5GemmaModel` instead.")
  632. self.encoder = T5GemmaEncoder(config.encoder)
  633. self.post_init()
  634. def get_input_embeddings(self):
  635. return self.encoder.get_input_embeddings()
  636. def set_input_embeddings(self, new_embeddings):
  637. return self.encoder.set_input_embeddings(new_embeddings)
  638. @can_return_tuple
  639. @auto_docstring
  640. def forward(
  641. self,
  642. input_ids: torch.LongTensor | None = None,
  643. attention_mask: torch.FloatTensor | None = None,
  644. position_ids: torch.LongTensor | None = None,
  645. inputs_embeds: torch.Tensor | None = None,
  646. **kwargs: Unpack[TransformersKwargs],
  647. ) -> BaseModelOutput:
  648. encoder_outputs = self.encoder(
  649. input_ids=input_ids,
  650. attention_mask=attention_mask,
  651. position_ids=position_ids,
  652. inputs_embeds=inputs_embeds,
  653. **kwargs,
  654. )
  655. return encoder_outputs
  656. class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
  657. _tied_weights_keys = {"lm_head.out_proj.weight": "model.decoder.embed_tokens.weight"}
  658. _tp_plan = {"lm_head.out_proj": "colwise_gather_output"}
  659. _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])}
  660. def __init__(self, config: T5GemmaConfig):
  661. config.is_encoder_decoder = True
  662. super().__init__(config)
  663. self.model = T5GemmaModel(config)
  664. self.vocab_size = config.decoder.vocab_size
  665. self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size)
  666. self.loss_type = "ForMaskedLM"
  667. self.post_init()
  668. def set_output_embeddings(self, new_embeddings):
  669. self.lm_head.out_proj = new_embeddings
  670. def get_output_embeddings(self):
  671. return self.lm_head.out_proj
  672. @can_return_tuple
  673. @auto_docstring
  674. def forward(
  675. self,
  676. input_ids: torch.LongTensor | None = None,
  677. attention_mask: torch.FloatTensor | None = None,
  678. position_ids: torch.LongTensor | None = None,
  679. decoder_input_ids: torch.LongTensor | None = None,
  680. decoder_attention_mask: torch.BoolTensor | None = None,
  681. decoder_position_ids: torch.LongTensor | None = None,
  682. encoder_outputs: BaseModelOutput | None = None,
  683. past_key_values: EncoderDecoderCache | None = None,
  684. inputs_embeds: torch.FloatTensor | None = None,
  685. decoder_inputs_embeds: torch.FloatTensor | None = None,
  686. labels: torch.LongTensor | None = None,
  687. use_cache: bool | None = None,
  688. logits_to_keep: int | torch.Tensor = 0,
  689. **kwargs: Unpack[TransformersKwargs],
  690. ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
  691. r"""
  692. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  693. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  694. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  695. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  696. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  697. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  698. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  699. """
  700. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  701. # get decoder inputs from shifting lm labels to the right
  702. decoder_input_ids = self._shift_right(labels)
  703. decoder_outputs: Seq2SeqModelOutput = self.model(
  704. input_ids=input_ids,
  705. attention_mask=attention_mask,
  706. position_ids=position_ids,
  707. decoder_input_ids=decoder_input_ids,
  708. decoder_attention_mask=decoder_attention_mask,
  709. decoder_position_ids=decoder_position_ids,
  710. encoder_outputs=encoder_outputs,
  711. past_key_values=past_key_values,
  712. inputs_embeds=inputs_embeds,
  713. decoder_inputs_embeds=decoder_inputs_embeds,
  714. use_cache=use_cache,
  715. **kwargs,
  716. )
  717. hidden_states = decoder_outputs.last_hidden_state
  718. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  719. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  720. logits = self.lm_head(hidden_states[:, slice_indices, :])
  721. decoder_config = self.get_decoder().config
  722. if decoder_config.final_logit_softcapping is not None:
  723. logits = logits / decoder_config.final_logit_softcapping
  724. logits = torch.tanh(logits)
  725. logits = logits * decoder_config.final_logit_softcapping
  726. loss = None
  727. if labels is not None:
  728. # Input has right-shifted so we directly perform masked lm loss
  729. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  730. return Seq2SeqLMOutput(
  731. loss=loss,
  732. logits=logits,
  733. past_key_values=decoder_outputs.past_key_values,
  734. decoder_hidden_states=decoder_outputs.decoder_hidden_states,
  735. decoder_attentions=decoder_outputs.decoder_attentions,
  736. cross_attentions=decoder_outputs.cross_attentions,
  737. encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state,
  738. encoder_hidden_states=decoder_outputs.encoder_hidden_states,
  739. encoder_attentions=decoder_outputs.encoder_attentions,
  740. )
  741. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  742. return self._shift_right(labels)
  743. @auto_docstring
  744. class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
  745. def __init__(self, config: T5GemmaConfig, is_encoder_decoder: bool | None = None):
  746. r"""
  747. is_encoder_decoder (`Optional`, *optional*):
  748. Whether use encoder_decoder for sequence classification. When set to False, only encoder is used.
  749. """
  750. if is_encoder_decoder is not None:
  751. config.is_encoder_decoder = is_encoder_decoder
  752. super().__init__(config)
  753. self.num_labels = config.num_labels
  754. if config.is_encoder_decoder:
  755. self.model = T5GemmaModel(config)
  756. else:
  757. self.model = T5GemmaEncoderModel(config)
  758. hidden_size = config.encoder.hidden_size
  759. if config.is_encoder_decoder:
  760. hidden_size = config.decoder.hidden_size
  761. classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
  762. self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout)
  763. self.post_init()
  764. def get_input_embeddings(self):
  765. return self.model.get_input_embeddings()
  766. def set_input_embeddings(self, value):
  767. self.model.set_input_embeddings(value)
  768. @can_return_tuple
  769. @auto_docstring
  770. def forward(
  771. self,
  772. input_ids: torch.LongTensor | None = None,
  773. attention_mask: torch.Tensor | None = None,
  774. position_ids: torch.LongTensor | None = None,
  775. decoder_input_ids: torch.LongTensor | None = None,
  776. decoder_attention_mask: torch.Tensor | None = None,
  777. decoder_position_ids: torch.LongTensor | None = None,
  778. encoder_outputs: BaseModelOutput | None = None,
  779. inputs_embeds: torch.FloatTensor | None = None,
  780. decoder_inputs_embeds: torch.FloatTensor | None = None,
  781. labels: torch.LongTensor | None = None,
  782. **kwargs: Unpack[TransformersKwargs],
  783. ) -> SequenceClassifierOutput:
  784. r"""
  785. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  786. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  787. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  788. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  789. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  790. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  791. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  792. """
  793. if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
  794. raise NotImplementedError(
  795. f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
  796. )
  797. # Following T5, we automatically creates decoder_input_ids from input_ids if no decoder_input_ids are provided
  798. if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
  799. if input_ids is None:
  800. raise ValueError(
  801. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  802. "passed, `input_ids` cannot be `None`. Please pass either "
  803. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  804. )
  805. decoder_input_ids = self._shift_right(input_ids)
  806. if self.config.is_encoder_decoder:
  807. outputs: Seq2SeqModelOutput = self.model(
  808. input_ids,
  809. attention_mask=attention_mask,
  810. position_ids=position_ids,
  811. decoder_input_ids=decoder_input_ids,
  812. decoder_attention_mask=decoder_attention_mask,
  813. decoder_position_ids=decoder_position_ids,
  814. encoder_outputs=encoder_outputs,
  815. inputs_embeds=inputs_embeds,
  816. decoder_inputs_embeds=decoder_inputs_embeds,
  817. use_cache=False,
  818. **kwargs,
  819. )
  820. last_hidden_state = outputs.last_hidden_state
  821. hidden_states = outputs.decoder_hidden_states
  822. attentions = outputs.decoder_attentions
  823. else:
  824. outputs: BaseModelOutput = self.model(
  825. input_ids,
  826. attention_mask=attention_mask,
  827. position_ids=position_ids,
  828. inputs_embeds=inputs_embeds,
  829. **kwargs,
  830. )
  831. last_hidden_state = outputs.last_hidden_state
  832. hidden_states = outputs.hidden_states
  833. attentions = outputs.attentions
  834. logits = self.score(last_hidden_state)
  835. if input_ids is not None:
  836. batch_size = input_ids.shape[0]
  837. else:
  838. batch_size = inputs_embeds.shape[0]
  839. if self.config.pad_token_id is None and batch_size != 1:
  840. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  841. if self.config.pad_token_id is None:
  842. last_non_pad_token = -1
  843. elif input_ids is not None:
  844. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  845. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  846. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  847. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  848. if self.config.is_encoder_decoder:
  849. last_non_pad_token += 1 # due to the right shift.
  850. last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1)
  851. else:
  852. last_non_pad_token = -1
  853. logger.warning_once(
  854. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  855. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  856. )
  857. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  858. loss = None
  859. if labels is not None:
  860. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  861. return SequenceClassifierOutput(
  862. loss=loss,
  863. logits=pooled_logits,
  864. hidden_states=hidden_states,
  865. attentions=attentions,
  866. )
  867. @auto_docstring
  868. class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
  869. def __init__(self, config: T5GemmaConfig, is_encoder_decoder: bool | None = None):
  870. r"""
  871. is_encoder_decoder (`Optional`, *optional*):
  872. Whether use encoder_decoder for token classification. When set to False, only encoder is used.
  873. """
  874. if is_encoder_decoder is not None:
  875. config.is_encoder_decoder = is_encoder_decoder
  876. super().__init__(config)
  877. self.num_labels = config.num_labels
  878. if config.is_encoder_decoder:
  879. self.model = T5GemmaModel(config)
  880. else:
  881. self.model = T5GemmaEncoderModel(config)
  882. hidden_size = config.encoder.hidden_size
  883. if config.is_encoder_decoder:
  884. hidden_size = config.decoder.hidden_size
  885. classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
  886. self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout)
  887. self.post_init()
  888. def get_input_embeddings(self):
  889. return self.model.get_input_embeddings()
  890. def set_input_embeddings(self, value):
  891. self.model.set_input_embeddings(value)
  892. @can_return_tuple
  893. @auto_docstring
  894. def forward(
  895. self,
  896. input_ids: torch.LongTensor | None = None,
  897. attention_mask: torch.Tensor | None = None,
  898. position_ids: torch.LongTensor | None = None,
  899. decoder_input_ids: torch.LongTensor | None = None,
  900. decoder_attention_mask: torch.Tensor | None = None,
  901. decoder_position_ids: torch.LongTensor | None = None,
  902. encoder_outputs: BaseModelOutput | None = None,
  903. inputs_embeds: torch.FloatTensor | None = None,
  904. decoder_inputs_embeds: torch.FloatTensor | None = None,
  905. labels: torch.LongTensor | None = None,
  906. **kwargs: Unpack[TransformersKwargs],
  907. ) -> TokenClassifierOutput:
  908. r"""
  909. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  910. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  911. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  912. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  913. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  914. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  915. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  916. """
  917. if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
  918. raise NotImplementedError(
  919. f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
  920. )
  921. if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
  922. if input_ids is None:
  923. raise ValueError(
  924. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  925. "passed, `input_ids` cannot be `None`. Please pass either "
  926. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  927. )
  928. decoder_input_ids = self._shift_right(input_ids)
  929. if self.config.is_encoder_decoder:
  930. outputs: Seq2SeqModelOutput = self.model(
  931. input_ids,
  932. attention_mask=attention_mask,
  933. position_ids=position_ids,
  934. decoder_input_ids=decoder_input_ids,
  935. decoder_attention_mask=decoder_attention_mask,
  936. decoder_position_ids=decoder_position_ids,
  937. encoder_outputs=encoder_outputs,
  938. inputs_embeds=inputs_embeds,
  939. decoder_inputs_embeds=decoder_inputs_embeds,
  940. use_cache=False,
  941. **kwargs,
  942. )
  943. last_hidden_state = outputs.last_hidden_state
  944. hidden_states = outputs.decoder_hidden_states
  945. attentions = outputs.decoder_attentions
  946. else:
  947. outputs: BaseModelOutput = self.model(
  948. input_ids,
  949. attention_mask=attention_mask,
  950. position_ids=position_ids,
  951. inputs_embeds=inputs_embeds,
  952. **kwargs,
  953. )
  954. last_hidden_state = outputs.last_hidden_state
  955. hidden_states = outputs.hidden_states
  956. attentions = outputs.attentions
  957. logits = self.score(last_hidden_state)
  958. loss = None
  959. if labels is not None:
  960. loss = self.loss_function(logits, labels, self.config)
  961. return TokenClassifierOutput(
  962. loss=loss,
  963. logits=logits,
  964. hidden_states=hidden_states,
  965. attentions=attentions,
  966. )
  967. __all__ = [
  968. "T5GemmaConfig",
  969. "T5GemmaModuleConfig",
  970. "T5GemmaForConditionalGeneration",
  971. "T5GemmaModel",
  972. "T5GemmaEncoderModel",
  973. "T5GemmaPreTrainedModel",
  974. "T5GemmaForSequenceClassification",
  975. "T5GemmaForTokenClassification",
  976. ]