modeling_t5gemma.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/t5gemma/modular_t5gemma.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_t5gemma.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. import torch.nn as nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  30. from ...masking_utils import (
  31. create_bidirectional_mask,
  32. create_bidirectional_sliding_window_mask,
  33. create_causal_mask,
  34. create_sliding_window_causal_mask,
  35. )
  36. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  37. from ...modeling_layers import GradientCheckpointingLayer
  38. from ...modeling_outputs import (
  39. BaseModelOutput,
  40. BaseModelOutputWithPastAndCrossAttentions,
  41. Seq2SeqLMOutput,
  42. Seq2SeqModelOutput,
  43. SequenceClassifierOutput,
  44. TokenClassifierOutput,
  45. )
  46. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  47. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  48. from ...processing_utils import Unpack
  49. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  50. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  51. from ...utils.output_capturing import OutputRecorder, capture_outputs
  52. from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
# Module-level logger created through the library's `...utils.logging` helper.
logger = logging.get_logger(__name__)
  54. class T5GemmaRMSNorm(nn.Module):
  55. def __init__(self, dim: int, eps: float = 1e-6):
  56. super().__init__()
  57. self.eps = eps
  58. self.weight = nn.Parameter(torch.zeros(dim))
  59. def _norm(self, x):
  60. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  61. def forward(self, x):
  62. output = self._norm(x.float())
  63. # Llama does x.to(float16) * w whilst T5Gemma is (x * w).to(float16)
  64. # See https://github.com/huggingface/transformers/pull/29402
  65. output = output * (1.0 + self.weight.float())
  66. return output.type_as(x)
  67. def extra_repr(self):
  68. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  69. class T5GemmaMLP(nn.Module):
  70. def __init__(self, config):
  71. super().__init__()
  72. self.config = config
  73. self.hidden_size = config.hidden_size
  74. self.intermediate_size = config.intermediate_size
  75. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  76. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  77. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  78. self.act_fn = ACT2FN[config.hidden_activation]
  79. self.dropout = nn.Dropout(config.dropout_rate)
  80. def forward(self, x):
  81. hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
  82. hidden_states = self.dropout(hidden_states)
  83. down_proj = self.down_proj(hidden_states)
  84. return down_proj
  85. class T5GemmaRotaryEmbedding(nn.Module):
  86. inv_freq: torch.Tensor # fix linting for `register_buffer`
  87. def __init__(self, config: T5GemmaConfig, device=None):
  88. super().__init__()
  89. self.max_seq_len_cached = config.max_position_embeddings
  90. self.original_max_seq_len = config.max_position_embeddings
  91. self.config = config
  92. self.rope_type = self.config.rope_parameters["rope_type"]
  93. rope_init_fn: Callable = self.compute_default_rope_parameters
  94. if self.rope_type != "default":
  95. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  96. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  97. self.register_buffer("inv_freq", inv_freq, persistent=False)
  98. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  99. @staticmethod
  100. def compute_default_rope_parameters(
  101. config: T5GemmaConfig | None = None,
  102. device: Optional["torch.device"] = None,
  103. seq_len: int | None = None,
  104. ) -> tuple["torch.Tensor", float]:
  105. """
  106. Computes the inverse frequencies according to the original RoPE implementation
  107. Args:
  108. config ([`~transformers.PreTrainedConfig`]):
  109. The model configuration.
  110. device (`torch.device`):
  111. The device to use for initialization of the inverse frequencies.
  112. seq_len (`int`, *optional*):
  113. The current sequence length. Unused for this type of RoPE.
  114. Returns:
  115. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  116. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  117. """
  118. base = config.rope_parameters["rope_theta"]
  119. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  120. attention_factor = 1.0 # Unused in this type of RoPE
  121. # Compute the inverse frequencies
  122. inv_freq = 1.0 / (
  123. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  124. )
  125. return inv_freq, attention_factor
  126. @torch.no_grad()
  127. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  128. def forward(self, x, position_ids):
  129. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  130. position_ids_expanded = position_ids[:, None, :].float()
  131. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  132. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  133. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  134. emb = torch.cat((freqs, freqs), dim=-1)
  135. cos = emb.cos() * self.attention_scaling
  136. sin = emb.sin() * self.attention_scaling
  137. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  138. def rotate_half(x):
  139. """Rotates half the hidden dims of the input."""
  140. x1 = x[..., : x.shape[-1] // 2]
  141. x2 = x[..., x.shape[-1] // 2 :]
  142. return torch.cat((-x2, x1), dim=-1)
  143. @use_kernel_func_from_hub("rotary_pos_emb")
  144. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  145. """Applies Rotary Position Embedding to the query and key tensors.
  146. Args:
  147. q (`torch.Tensor`): The query tensor.
  148. k (`torch.Tensor`): The key tensor.
  149. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  150. sin (`torch.Tensor`): The sine part of the rotary embedding.
  151. unsqueeze_dim (`int`, *optional*, defaults to 1):
  152. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  153. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  154. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  155. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  156. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  157. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  158. Returns:
  159. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  160. """
  161. cos = cos.unsqueeze(unsqueeze_dim)
  162. sin = sin.unsqueeze(unsqueeze_dim)
  163. q_embed = (q * cos) + (rotate_half(q) * sin)
  164. k_embed = (k * cos) + (rotate_half(k) * sin)
  165. return q_embed, k_embed
  166. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  167. """
  168. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  169. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  170. """
  171. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  172. if n_rep == 1:
  173. return hidden_states
  174. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  175. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  176. def eager_attention_forward(
  177. module: nn.Module,
  178. query: torch.Tensor,
  179. key: torch.Tensor,
  180. value: torch.Tensor,
  181. attention_mask: torch.Tensor | None,
  182. dropout: float | int = 0.0,
  183. scaling: float | None = None,
  184. softcap: float | None = None,
  185. **kwargs,
  186. ) -> tuple[torch.Tensor, torch.Tensor]:
  187. if scaling is None:
  188. scaling = module.head_dim**-0.5
  189. key_states = repeat_kv(key, module.num_key_value_groups)
  190. value_states = repeat_kv(value, module.num_key_value_groups)
  191. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  192. if softcap is not None:
  193. attn_weights = attn_weights / softcap
  194. attn_weights = torch.tanh(attn_weights)
  195. attn_weights = attn_weights * softcap
  196. if attention_mask is not None:
  197. attn_weights = attn_weights + attention_mask
  198. # upcast attention to fp32
  199. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  200. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  201. attn_output = torch.matmul(attn_weights, value_states)
  202. attn_output = attn_output.transpose(1, 2).contiguous()
  203. return attn_output, attn_weights
  204. @use_kernelized_func(apply_rotary_pos_emb)
  205. class T5GemmaSelfAttention(nn.Module):
  206. """Multi-headed attention from 'Attention Is All You Need' paper"""
  207. def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
  208. super().__init__()
  209. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  210. self.config = config
  211. self.layer_idx = layer_idx
  212. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  213. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  214. self.scaling = config.query_pre_attn_scalar**-0.5
  215. self.attention_dropout = self.config.attention_dropout
  216. # Required by flash attention: encoder selfattention is non-causal
  217. self.is_causal = config.is_decoder
  218. self.q_proj = nn.Linear(
  219. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  220. )
  221. self.k_proj = nn.Linear(
  222. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  223. )
  224. self.v_proj = nn.Linear(
  225. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  226. )
  227. self.o_proj = nn.Linear(
  228. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  229. )
  230. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  231. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  232. def forward(
  233. self,
  234. hidden_states: torch.Tensor,
  235. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  236. attention_mask: torch.Tensor | None = None,
  237. past_key_values: Cache | None = None,
  238. **kwargs: Unpack[FlashAttentionKwargs],
  239. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  240. input_shape = hidden_states.shape[:-1]
  241. hidden_shape = (*input_shape, -1, self.head_dim)
  242. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  243. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  244. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  245. cos, sin = position_embeddings
  246. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  247. if past_key_values is not None:
  248. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  249. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  250. self.config._attn_implementation, eager_attention_forward
  251. )
  252. attn_output, attn_weights = attention_interface(
  253. self,
  254. query_states,
  255. key_states,
  256. value_states,
  257. attention_mask,
  258. dropout=self.attention_dropout if self.training else 0.0,
  259. scaling=self.scaling,
  260. sliding_window=self.sliding_window,
  261. softcap=self.attn_logit_softcapping,
  262. **kwargs,
  263. )
  264. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  265. attn_output = self.o_proj(attn_output)
  266. return attn_output, attn_weights
  267. @use_kernelized_func(apply_rotary_pos_emb)
  268. class T5GemmaCrossAttention(nn.Module):
  269. """Multi-headed attention from 'Attention Is All You Need' paper"""
  270. def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
  271. super().__init__()
  272. self.config = config
  273. self.layer_idx = layer_idx
  274. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  275. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  276. self.scaling = config.query_pre_attn_scalar**-0.5
  277. self.attention_dropout = self.config.attention_dropout
  278. self.is_causal = False
  279. self.q_proj = nn.Linear(
  280. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  281. )
  282. self.k_proj = nn.Linear(
  283. config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  284. )
  285. self.v_proj = nn.Linear(
  286. config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  287. )
  288. self.o_proj = nn.Linear(
  289. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  290. )
  291. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  292. if config.cross_attention_hidden_size is None:
  293. raise ValueError("Cross-attention needs cross_attention_hidden_size to be specified.")
  294. def forward(
  295. self,
  296. hidden_states: torch.Tensor,
  297. attention_mask: torch.Tensor | None,
  298. encoder_hidden_states: torch.Tensor | None,
  299. past_key_values: Cache | None = None,
  300. **kwargs: Unpack[FlashAttentionKwargs],
  301. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  302. if encoder_hidden_states is None:
  303. raise ValueError("Encoder hidden state is required for cross attention.")
  304. input_shape = hidden_states.shape[:-1]
  305. hidden_shape = (*input_shape, -1, self.head_dim)
  306. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  307. if past_key_values is not None:
  308. is_updated = past_key_values.is_updated.get(self.layer_idx)
  309. curr_past_key_values = past_key_values.cross_attention_cache
  310. if past_key_values is None or not is_updated:
  311. encoder_input_shape = encoder_hidden_states.shape[:-1]
  312. encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim)
  313. key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
  314. value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
  315. if past_key_values is not None:
  316. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  317. past_key_values.is_updated[self.layer_idx] = True
  318. else:
  319. key_states = curr_past_key_values.layers[self.layer_idx].keys
  320. value_states = curr_past_key_values.layers[self.layer_idx].values
  321. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  322. self.config._attn_implementation, eager_attention_forward
  323. )
  324. attn_output, attn_weights = attention_interface(
  325. self,
  326. query_states,
  327. key_states,
  328. value_states,
  329. attention_mask,
  330. dropout=self.attention_dropout if self.training else 0.0,
  331. scaling=self.scaling,
  332. sliding_window=None,
  333. softcap=self.attn_logit_softcapping,
  334. **kwargs,
  335. )
  336. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  337. attn_output = self.o_proj(attn_output)
  338. return attn_output, attn_weights
  339. class T5GemmaEncoderLayer(GradientCheckpointingLayer):
  340. """Encoder sub-layer."""
  341. def __init__(self, config, layer_idx: int):
  342. super().__init__()
  343. self.hidden_size = config.hidden_size
  344. self.config = config
  345. self.layer_idx = layer_idx
  346. self.attention_type = config.layer_types[layer_idx]
  347. self.self_attn = T5GemmaSelfAttention(
  348. config=config,
  349. layer_idx=layer_idx,
  350. )
  351. self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  352. self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  353. self.mlp = T5GemmaMLP(config)
  354. self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  355. self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  356. self.dropout = nn.Dropout(config.dropout_rate)
  357. def forward(
  358. self,
  359. hidden_states: torch.Tensor,
  360. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  361. attention_mask: torch.Tensor | None = None,
  362. position_ids: torch.LongTensor | None = None,
  363. **kwargs,
  364. ) -> tuple[torch.FloatTensor,]:
  365. residual = hidden_states
  366. hidden_states = self.pre_self_attn_layernorm(hidden_states)
  367. hidden_states, _ = self.self_attn(
  368. hidden_states=hidden_states,
  369. position_embeddings=position_embeddings,
  370. attention_mask=attention_mask,
  371. position_ids=position_ids,
  372. past_key_values=None,
  373. **kwargs,
  374. )
  375. hidden_states = self.post_self_attn_layernorm(hidden_states)
  376. hidden_states = residual + self.dropout(hidden_states)
  377. residual = hidden_states
  378. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  379. hidden_states = self.mlp(hidden_states)
  380. hidden_states = self.post_feedforward_layernorm(hidden_states)
  381. hidden_states = residual + self.dropout(hidden_states)
  382. return hidden_states
  383. class T5GemmaDecoderLayer(GradientCheckpointingLayer):
  384. """Decoder sub-layer: an extra cross-attention layer."""
  385. def __init__(self, config, layer_idx: int):
  386. super().__init__()
  387. self.hidden_size = config.hidden_size
  388. self.config = config
  389. self.layer_idx = layer_idx
  390. self.attention_type = config.layer_types[layer_idx]
  391. self.self_attn = T5GemmaSelfAttention(
  392. config=config,
  393. layer_idx=layer_idx,
  394. )
  395. self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  396. self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  397. self.mlp = T5GemmaMLP(config)
  398. self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  399. self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  400. self.dropout = nn.Dropout(config.dropout_rate)
  401. self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx)
  402. self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  403. self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  404. def forward(
  405. self,
  406. hidden_states: torch.Tensor,
  407. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  408. attention_mask: torch.Tensor | None = None,
  409. position_ids: torch.LongTensor | None = None,
  410. past_key_values: EncoderDecoderCache | None = None,
  411. use_cache: bool | None = False,
  412. encoder_hidden_states: torch.Tensor | None = None,
  413. encoder_attention_mask: torch.Tensor | None = None,
  414. **kwargs,
  415. ) -> torch.FloatTensor:
  416. residual = hidden_states
  417. hidden_states = self.pre_self_attn_layernorm(hidden_states)
  418. hidden_states, _ = self.self_attn(
  419. hidden_states=hidden_states,
  420. position_embeddings=position_embeddings,
  421. attention_mask=attention_mask,
  422. position_ids=position_ids,
  423. past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
  424. use_cache=use_cache,
  425. **kwargs,
  426. )
  427. hidden_states = self.post_self_attn_layernorm(hidden_states)
  428. hidden_states = residual + self.dropout(hidden_states)
  429. residual = hidden_states
  430. hidden_states = self.pre_cross_attn_layernorm(hidden_states)
  431. hidden_states, _ = self.cross_attn(
  432. hidden_states=hidden_states,
  433. encoder_hidden_states=encoder_hidden_states,
  434. attention_mask=encoder_attention_mask,
  435. past_key_values=past_key_values,
  436. use_cache=use_cache,
  437. **kwargs,
  438. )
  439. hidden_states = self.post_cross_attn_layernorm(hidden_states)
  440. hidden_states = residual + self.dropout(hidden_states)
  441. residual = hidden_states
  442. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  443. hidden_states = self.mlp(hidden_states)
  444. hidden_states = self.post_feedforward_layernorm(hidden_states)
  445. hidden_states = residual + self.dropout(hidden_states)
  446. return hidden_states
  447. class T5GemmaClassificationHead(nn.Module):
  448. """Head for sentence-level classification tasks."""
  449. def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0):
  450. super().__init__()
  451. self.dropout = nn.Dropout(p=classifier_dropout_rate)
  452. self.out_proj = nn.Linear(hidden_size, num_labels)
  453. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  454. hidden_states = self.dropout(hidden_states)
  455. hidden_states = self.out_proj(hidden_states)
  456. return hidden_states
  457. class T5GemmaLMHead(nn.Module):
  458. """Head for language modeling (generation) tasks."""
  459. def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False):
  460. super().__init__()
  461. self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias)
  462. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  463. logits = self.out_proj(hidden_states)
  464. return logits
  465. @auto_docstring
  466. class T5GemmaPreTrainedModel(PreTrainedModel):
  467. config: T5GemmaConfig
  468. base_model_prefix = "model"
  469. supports_gradient_checkpointing = True
  470. _no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"]
  471. _skip_keys_device_placement = ["past_key_values"]
  472. _supports_flash_attn = True
  473. _supports_sdpa = True
  474. _supports_flex_attn = True
  475. _can_compile_fullgraph = True
  476. _supports_attention_backend = True
  477. _can_record_outputs = {
  478. "hidden_states": T5GemmaDecoderLayer,
  479. "attentions": [
  480. OutputRecorder(T5GemmaSelfAttention, index=1, layer_name="self_attn"),
  481. OutputRecorder(T5GemmaSelfAttention, index=1, layer_name="cross_attn"),
  482. OutputRecorder(T5GemmaCrossAttention, index=1, layer_name="cross_attn"),
  483. ],
  484. }
  485. @torch.no_grad()
  486. def _init_weights(self, module):
  487. # TODO: support initialization for encoders and decoders separately(?)
  488. super()._init_weights(module)
  489. std = self.config.initializer_range
  490. if isinstance(module, T5GemmaClassificationHead):
  491. scale = module.out_proj.weight.shape[0] ** -0.5
  492. init.normal_(module.out_proj.weight, mean=0.0, std=std * scale)
  493. if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
  494. init.zeros_(module.out_proj.bias)
  495. elif isinstance(module, T5GemmaLMHead):
  496. if not self.config.tie_word_embeddings:
  497. scale = module.out_proj.weight.shape[0] ** -0.5
  498. init.normal_(module.out_proj.weight, mean=0.0, std=std * scale)
  499. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  500. elif "RMSNorm" in module.__class__.__name__:
  501. init.zeros_(module.weight)
  502. def _shift_right(self, input_ids):
  503. """
  504. Shifts input_ids to the right, prepends the decoder_start_token_id, and handles
  505. pad_token_id replacement for labels that were -100.
  506. This is a common preparation step for decoder inputs in sequence-to-sequence models.
  507. """
  508. decoder_start_token_id = self.config.decoder.bos_token_id
  509. pad_token_id = self.config.decoder.pad_token_id
  510. if decoder_start_token_id is None:
  511. raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ")
  512. # shift inputs to the right
  513. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  514. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  515. shifted_input_ids[..., 0] = decoder_start_token_id
  516. if pad_token_id is None:
  517. raise ValueError("self.model.config.decoder.pad_token_id has to be defined.")
  518. # Is this T5 specific?
  519. # replace possible -100 values in labels by `pad_token_id`
  520. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  521. return shifted_input_ids
  522. def make_default_2d_attention_mask(
  523. token_ids: torch.LongTensor | None,
  524. hidden_states: torch.Tensor,
  525. pad_token_id: int | None,
  526. ) -> torch.Tensor:
  527. """Construct the default attention mask."""
  528. if token_ids is not None:
  529. if pad_token_id is None:
  530. raise ValueError("`pad_token_id` is required for padding information.")
  531. attention_mask = (token_ids != pad_token_id).to(hidden_states.device, torch.long)
  532. else:
  533. attention_mask = torch.ones(
  534. (hidden_states.shape[0], hidden_states.shape[1]), device=hidden_states.device, dtype=torch.long
  535. )
  536. return attention_mask
  537. class T5GemmaEncoder(T5GemmaPreTrainedModel):
  538. _can_record_outputs = {
  539. "attentions": T5GemmaSelfAttention,
  540. "hidden_states": T5GemmaEncoderLayer,
  541. }
  542. def __init__(self, config):
  543. super().__init__(config)
  544. self.padding_idx = config.pad_token_id
  545. self.vocab_size = config.vocab_size
  546. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  547. self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  548. self.gradient_checkpointing = False
  549. self.layers = nn.ModuleList(
  550. [T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  551. )
  552. self.dropout = nn.Dropout(config.dropout_rate)
  553. self.rotary_emb = T5GemmaRotaryEmbedding(config=config)
  554. # Initialize weights and apply final processing
  555. self.post_init()
  556. @merge_with_config_defaults
  557. @capture_outputs
  558. def forward(
  559. self,
  560. input_ids: torch.LongTensor | None = None,
  561. attention_mask: torch.Tensor | None = None,
  562. position_ids: torch.LongTensor | None = None,
  563. inputs_embeds: torch.FloatTensor | None = None,
  564. **kwargs: Unpack[TransformersKwargs],
  565. ) -> tuple | BaseModelOutput:
  566. if (input_ids is None) ^ (inputs_embeds is not None):
  567. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  568. # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present
  569. kwargs.pop("past_key_values", None)
  570. if inputs_embeds is None:
  571. inputs_embeds = self.embed_tokens(input_ids)
  572. if position_ids is None:
  573. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
  574. position_ids = position_ids.unsqueeze(0)
  575. if attention_mask is None:
  576. attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)
  577. if not isinstance(self_attn_mask_mapping := attention_mask, dict):
  578. mask_kwargs = {
  579. "config": self.config,
  580. "inputs_embeds": inputs_embeds,
  581. "attention_mask": attention_mask,
  582. }
  583. self_attn_mask_mapping = {
  584. "full_attention": create_bidirectional_mask(**mask_kwargs),
  585. "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
  586. }
  587. hidden_states = inputs_embeds
  588. normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
  589. hidden_states = hidden_states * normalizer
  590. hidden_states = self.dropout(hidden_states)
  591. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  592. for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
  593. hidden_states = layer_module(
  594. hidden_states,
  595. position_embeddings,
  596. self_attn_mask_mapping[self.config.layer_types[i]],
  597. position_ids,
  598. **kwargs,
  599. )
  600. hidden_states = self.norm(hidden_states)
  601. hidden_states = self.dropout(hidden_states)
  602. return BaseModelOutput(
  603. last_hidden_state=hidden_states,
  604. )
  605. class T5GemmaDecoder(T5GemmaPreTrainedModel):
  606. _can_record_outputs = {
  607. "attentions": OutputRecorder(T5GemmaSelfAttention, index=1),
  608. "cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1),
  609. "hidden_states": T5GemmaDecoderLayer,
  610. }
  611. def __init__(self, config):
  612. super().__init__(config)
  613. self.padding_idx = config.pad_token_id
  614. self.vocab_size = config.vocab_size
  615. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  616. self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  617. self.gradient_checkpointing = False
  618. self.layers = nn.ModuleList(
  619. [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  620. )
  621. self.dropout = nn.Dropout(config.dropout_rate)
  622. self.rotary_emb = T5GemmaRotaryEmbedding(config=config)
  623. # Initialize weights and apply final processing
  624. self.post_init()
  625. @merge_with_config_defaults
  626. @capture_outputs
  627. def forward(
  628. self,
  629. input_ids: torch.LongTensor | None = None,
  630. attention_mask: torch.Tensor | None = None,
  631. position_ids: torch.LongTensor | None = None,
  632. past_key_values: EncoderDecoderCache | None = None,
  633. inputs_embeds: torch.FloatTensor | None = None,
  634. use_cache: bool | None = None,
  635. encoder_hidden_states: torch.Tensor | None = None,
  636. encoder_attention_mask: torch.Tensor | None = None,
  637. **kwargs: Unpack[TransformersKwargs],
  638. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  639. if (input_ids is None) ^ (inputs_embeds is not None):
  640. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  641. if encoder_hidden_states is None:
  642. raise ValueError("`encoder_hidden_states` must be given in decoder")
  643. if inputs_embeds is None:
  644. inputs_embeds = self.embed_tokens(input_ids)
  645. if not self.training and use_cache and past_key_values is None:
  646. # We do not pass the config to the cross attn cache to avoid initializing SWA
  647. # --> we use full attention between our cross attentions
  648. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())
  649. if position_ids is None:
  650. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  651. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  652. position_ids = position_ids.unsqueeze(0)
  653. if attention_mask is None and past_key_values is None:
  654. attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)
  655. if not isinstance(self_attn_mask_mapping := attention_mask, dict):
  656. mask_kwargs = {
  657. "config": self.config,
  658. "inputs_embeds": inputs_embeds,
  659. "attention_mask": attention_mask,
  660. "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
  661. "position_ids": position_ids,
  662. }
  663. self_attn_mask_mapping = {
  664. "full_attention": create_causal_mask(**mask_kwargs),
  665. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  666. }
  667. if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
  668. cross_attn_mask_mapping = {
  669. "full_attention": create_bidirectional_mask(
  670. config=self.config,
  671. inputs_embeds=inputs_embeds,
  672. attention_mask=encoder_attention_mask,
  673. encoder_hidden_states=encoder_hidden_states,
  674. )
  675. }
  676. hidden_states = inputs_embeds
  677. normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
  678. hidden_states = hidden_states * normalizer
  679. hidden_states = self.dropout(hidden_states)
  680. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  681. for i, layer_module in enumerate(self.layers[: self.config.num_hidden_layers]):
  682. hidden_states = layer_module(
  683. hidden_states,
  684. position_embeddings,
  685. self_attn_mask_mapping[self.config.layer_types[i]],
  686. position_ids,
  687. past_key_values,
  688. use_cache,
  689. encoder_hidden_states,
  690. cross_attn_mask_mapping["full_attention"],
  691. **kwargs,
  692. )
  693. hidden_states = self.norm(hidden_states)
  694. hidden_states = self.dropout(hidden_states)
  695. return BaseModelOutputWithPastAndCrossAttentions(
  696. last_hidden_state=hidden_states,
  697. past_key_values=past_key_values,
  698. )
  699. @auto_docstring
  700. class T5GemmaModel(T5GemmaPreTrainedModel):
  701. def __init__(self, config: T5GemmaConfig):
  702. super().__init__(config)
  703. if not config.is_encoder_decoder:
  704. raise ValueError("T5GemmaModel only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.")
  705. self.encoder = T5GemmaEncoder(config.encoder)
  706. self.decoder = T5GemmaDecoder(config.decoder)
  707. self.post_init()
  708. def get_input_embeddings(self):
  709. return self.encoder.get_input_embeddings()
  710. def set_input_embeddings(self, new_embeddings):
  711. return self.encoder.set_input_embeddings(new_embeddings)
  712. @can_return_tuple
  713. @auto_docstring
  714. def forward(
  715. self,
  716. input_ids: torch.LongTensor | None = None,
  717. attention_mask: torch.FloatTensor | None = None,
  718. position_ids: torch.LongTensor | None = None,
  719. decoder_input_ids: torch.LongTensor | None = None,
  720. decoder_attention_mask: torch.BoolTensor | None = None,
  721. decoder_position_ids: torch.LongTensor | None = None,
  722. encoder_outputs: BaseModelOutput | None = None,
  723. past_key_values: EncoderDecoderCache | None = None,
  724. inputs_embeds: torch.Tensor | None = None,
  725. decoder_inputs_embeds: torch.Tensor | None = None,
  726. use_cache: bool | None = None,
  727. **kwargs: Unpack[TransformersKwargs],
  728. ) -> Seq2SeqModelOutput:
  729. r"""
  730. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  731. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  732. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  733. """
  734. if encoder_outputs is None:
  735. encoder_outputs = self.encoder(
  736. input_ids=input_ids,
  737. attention_mask=attention_mask,
  738. position_ids=position_ids,
  739. inputs_embeds=inputs_embeds,
  740. **kwargs,
  741. )
  742. encoder_hidden_states = encoder_outputs.last_hidden_state
  743. decoder_outputs = self.decoder(
  744. input_ids=decoder_input_ids,
  745. attention_mask=decoder_attention_mask,
  746. position_ids=decoder_position_ids,
  747. inputs_embeds=decoder_inputs_embeds,
  748. past_key_values=past_key_values,
  749. encoder_hidden_states=encoder_hidden_states,
  750. encoder_attention_mask=attention_mask,
  751. use_cache=use_cache,
  752. **kwargs,
  753. )
  754. return Seq2SeqModelOutput(
  755. last_hidden_state=decoder_outputs.last_hidden_state,
  756. past_key_values=decoder_outputs.past_key_values,
  757. decoder_hidden_states=decoder_outputs.hidden_states
  758. if kwargs.get("output_hidden_states", False)
  759. else (decoder_outputs.last_hidden_state,),
  760. decoder_attentions=decoder_outputs.attentions,
  761. cross_attentions=decoder_outputs.cross_attentions,
  762. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  763. encoder_hidden_states=encoder_outputs.hidden_states,
  764. encoder_attentions=encoder_outputs.attentions,
  765. )
  766. @auto_docstring
  767. class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
  768. def __init__(self, config: T5GemmaConfig):
  769. super().__init__(config)
  770. if config.is_encoder_decoder:
  771. raise ValueError("T5GemmaEncoderModel only supports encoder-only model. Use `T5GemmaModel` instead.")
  772. self.encoder = T5GemmaEncoder(config.encoder)
  773. self.post_init()
  774. def get_input_embeddings(self):
  775. return self.encoder.get_input_embeddings()
  776. def set_input_embeddings(self, new_embeddings):
  777. return self.encoder.set_input_embeddings(new_embeddings)
  778. @can_return_tuple
  779. @auto_docstring
  780. def forward(
  781. self,
  782. input_ids: torch.LongTensor | None = None,
  783. attention_mask: torch.FloatTensor | None = None,
  784. position_ids: torch.LongTensor | None = None,
  785. inputs_embeds: torch.Tensor | None = None,
  786. **kwargs: Unpack[TransformersKwargs],
  787. ) -> BaseModelOutput:
  788. encoder_outputs = self.encoder(
  789. input_ids=input_ids,
  790. attention_mask=attention_mask,
  791. position_ids=position_ids,
  792. inputs_embeds=inputs_embeds,
  793. **kwargs,
  794. )
  795. return encoder_outputs
  796. class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
  797. _tied_weights_keys = {"lm_head.out_proj.weight": "model.decoder.embed_tokens.weight"}
  798. _tp_plan = {"lm_head.out_proj": "colwise_gather_output"}
  799. _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])}
  800. def __init__(self, config: T5GemmaConfig):
  801. config.is_encoder_decoder = True
  802. super().__init__(config)
  803. self.model = T5GemmaModel(config)
  804. self.vocab_size = config.decoder.vocab_size
  805. self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size)
  806. self.loss_type = "ForMaskedLM"
  807. self.post_init()
  808. def set_output_embeddings(self, new_embeddings):
  809. self.lm_head.out_proj = new_embeddings
  810. def get_output_embeddings(self):
  811. return self.lm_head.out_proj
  812. @can_return_tuple
  813. @auto_docstring
  814. def forward(
  815. self,
  816. input_ids: torch.LongTensor | None = None,
  817. attention_mask: torch.FloatTensor | None = None,
  818. position_ids: torch.LongTensor | None = None,
  819. decoder_input_ids: torch.LongTensor | None = None,
  820. decoder_attention_mask: torch.BoolTensor | None = None,
  821. decoder_position_ids: torch.LongTensor | None = None,
  822. encoder_outputs: BaseModelOutput | None = None,
  823. past_key_values: EncoderDecoderCache | None = None,
  824. inputs_embeds: torch.FloatTensor | None = None,
  825. decoder_inputs_embeds: torch.FloatTensor | None = None,
  826. labels: torch.LongTensor | None = None,
  827. use_cache: bool | None = None,
  828. logits_to_keep: int | torch.Tensor = 0,
  829. **kwargs: Unpack[TransformersKwargs],
  830. ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
  831. r"""
  832. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  833. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  834. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  835. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  836. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  837. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  838. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  839. """
  840. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  841. # get decoder inputs from shifting lm labels to the right
  842. decoder_input_ids = self._shift_right(labels)
  843. decoder_outputs: Seq2SeqModelOutput = self.model(
  844. input_ids=input_ids,
  845. attention_mask=attention_mask,
  846. position_ids=position_ids,
  847. decoder_input_ids=decoder_input_ids,
  848. decoder_attention_mask=decoder_attention_mask,
  849. decoder_position_ids=decoder_position_ids,
  850. encoder_outputs=encoder_outputs,
  851. past_key_values=past_key_values,
  852. inputs_embeds=inputs_embeds,
  853. decoder_inputs_embeds=decoder_inputs_embeds,
  854. use_cache=use_cache,
  855. **kwargs,
  856. )
  857. hidden_states = decoder_outputs.last_hidden_state
  858. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  859. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  860. logits = self.lm_head(hidden_states[:, slice_indices, :])
  861. decoder_config = self.get_decoder().config
  862. if decoder_config.final_logit_softcapping is not None:
  863. logits = logits / decoder_config.final_logit_softcapping
  864. logits = torch.tanh(logits)
  865. logits = logits * decoder_config.final_logit_softcapping
  866. loss = None
  867. if labels is not None:
  868. # Input has right-shifted so we directly perform masked lm loss
  869. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  870. return Seq2SeqLMOutput(
  871. loss=loss,
  872. logits=logits,
  873. past_key_values=decoder_outputs.past_key_values,
  874. decoder_hidden_states=decoder_outputs.decoder_hidden_states,
  875. decoder_attentions=decoder_outputs.decoder_attentions,
  876. cross_attentions=decoder_outputs.cross_attentions,
  877. encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state,
  878. encoder_hidden_states=decoder_outputs.encoder_hidden_states,
  879. encoder_attentions=decoder_outputs.encoder_attentions,
  880. )
  881. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  882. return self._shift_right(labels)
  883. @auto_docstring
  884. class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
  885. def __init__(self, config: T5GemmaConfig, is_encoder_decoder: bool | None = None):
  886. r"""
  887. is_encoder_decoder (`Optional`, *optional*):
  888. Whether use encoder_decoder for sequence classification. When set to False, only encoder is used.
  889. """
  890. if is_encoder_decoder is not None:
  891. config.is_encoder_decoder = is_encoder_decoder
  892. super().__init__(config)
  893. self.num_labels = config.num_labels
  894. if config.is_encoder_decoder:
  895. self.model = T5GemmaModel(config)
  896. else:
  897. self.model = T5GemmaEncoderModel(config)
  898. hidden_size = config.encoder.hidden_size
  899. if config.is_encoder_decoder:
  900. hidden_size = config.decoder.hidden_size
  901. classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
  902. self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout)
  903. self.post_init()
  904. def get_input_embeddings(self):
  905. return self.model.get_input_embeddings()
  906. def set_input_embeddings(self, value):
  907. self.model.set_input_embeddings(value)
  908. @can_return_tuple
  909. @auto_docstring
  910. def forward(
  911. self,
  912. input_ids: torch.LongTensor | None = None,
  913. attention_mask: torch.Tensor | None = None,
  914. position_ids: torch.LongTensor | None = None,
  915. decoder_input_ids: torch.LongTensor | None = None,
  916. decoder_attention_mask: torch.Tensor | None = None,
  917. decoder_position_ids: torch.LongTensor | None = None,
  918. encoder_outputs: BaseModelOutput | None = None,
  919. inputs_embeds: torch.FloatTensor | None = None,
  920. decoder_inputs_embeds: torch.FloatTensor | None = None,
  921. labels: torch.LongTensor | None = None,
  922. **kwargs: Unpack[TransformersKwargs],
  923. ) -> SequenceClassifierOutput:
  924. r"""
  925. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  926. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  927. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  928. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  929. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  930. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  931. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  932. """
  933. if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
  934. raise NotImplementedError(
  935. f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
  936. )
  937. # Following T5, we automatically creates decoder_input_ids from input_ids if no decoder_input_ids are provided
  938. if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
  939. if input_ids is None:
  940. raise ValueError(
  941. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  942. "passed, `input_ids` cannot be `None`. Please pass either "
  943. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  944. )
  945. decoder_input_ids = self._shift_right(input_ids)
  946. if self.config.is_encoder_decoder:
  947. outputs: Seq2SeqModelOutput = self.model(
  948. input_ids,
  949. attention_mask=attention_mask,
  950. position_ids=position_ids,
  951. decoder_input_ids=decoder_input_ids,
  952. decoder_attention_mask=decoder_attention_mask,
  953. decoder_position_ids=decoder_position_ids,
  954. encoder_outputs=encoder_outputs,
  955. inputs_embeds=inputs_embeds,
  956. decoder_inputs_embeds=decoder_inputs_embeds,
  957. use_cache=False,
  958. **kwargs,
  959. )
  960. last_hidden_state = outputs.last_hidden_state
  961. hidden_states = outputs.decoder_hidden_states
  962. attentions = outputs.decoder_attentions
  963. else:
  964. outputs: BaseModelOutput = self.model(
  965. input_ids,
  966. attention_mask=attention_mask,
  967. position_ids=position_ids,
  968. inputs_embeds=inputs_embeds,
  969. **kwargs,
  970. )
  971. last_hidden_state = outputs.last_hidden_state
  972. hidden_states = outputs.hidden_states
  973. attentions = outputs.attentions
  974. logits = self.score(last_hidden_state)
  975. if input_ids is not None:
  976. batch_size = input_ids.shape[0]
  977. else:
  978. batch_size = inputs_embeds.shape[0]
  979. if self.config.pad_token_id is None and batch_size != 1:
  980. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  981. if self.config.pad_token_id is None:
  982. last_non_pad_token = -1
  983. elif input_ids is not None:
  984. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  985. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  986. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  987. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  988. if self.config.is_encoder_decoder:
  989. last_non_pad_token += 1 # due to the right shift.
  990. last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1)
  991. else:
  992. last_non_pad_token = -1
  993. logger.warning_once(
  994. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  995. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  996. )
  997. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  998. loss = None
  999. if labels is not None:
  1000. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1001. return SequenceClassifierOutput(
  1002. loss=loss,
  1003. logits=pooled_logits,
  1004. hidden_states=hidden_states,
  1005. attentions=attentions,
  1006. )
  1007. @auto_docstring
  1008. class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
  1009. def __init__(self, config: T5GemmaConfig, is_encoder_decoder: bool | None = None):
  1010. r"""
  1011. is_encoder_decoder (`Optional`, *optional*):
  1012. Whether use encoder_decoder for token classification. When set to False, only encoder is used.
  1013. """
  1014. if is_encoder_decoder is not None:
  1015. config.is_encoder_decoder = is_encoder_decoder
  1016. super().__init__(config)
  1017. self.num_labels = config.num_labels
  1018. if config.is_encoder_decoder:
  1019. self.model = T5GemmaModel(config)
  1020. else:
  1021. self.model = T5GemmaEncoderModel(config)
  1022. hidden_size = config.encoder.hidden_size
  1023. if config.is_encoder_decoder:
  1024. hidden_size = config.decoder.hidden_size
  1025. classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
  1026. self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout)
  1027. self.post_init()
  1028. def get_input_embeddings(self):
  1029. return self.model.get_input_embeddings()
  1030. def set_input_embeddings(self, value):
  1031. self.model.set_input_embeddings(value)
  1032. @can_return_tuple
  1033. @auto_docstring
  1034. def forward(
  1035. self,
  1036. input_ids: torch.LongTensor | None = None,
  1037. attention_mask: torch.Tensor | None = None,
  1038. position_ids: torch.LongTensor | None = None,
  1039. decoder_input_ids: torch.LongTensor | None = None,
  1040. decoder_attention_mask: torch.Tensor | None = None,
  1041. decoder_position_ids: torch.LongTensor | None = None,
  1042. encoder_outputs: BaseModelOutput | None = None,
  1043. inputs_embeds: torch.FloatTensor | None = None,
  1044. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1045. labels: torch.LongTensor | None = None,
  1046. **kwargs: Unpack[TransformersKwargs],
  1047. ) -> TokenClassifierOutput:
  1048. r"""
  1049. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
  1050. Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
  1051. config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  1052. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1053. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1054. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1055. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1056. """
  1057. if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
  1058. raise NotImplementedError(
  1059. f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
  1060. )
  1061. if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
  1062. if input_ids is None:
  1063. raise ValueError(
  1064. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1065. "passed, `input_ids` cannot be `None`. Please pass either "
  1066. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1067. )
  1068. decoder_input_ids = self._shift_right(input_ids)
  1069. if self.config.is_encoder_decoder:
  1070. outputs: Seq2SeqModelOutput = self.model(
  1071. input_ids,
  1072. attention_mask=attention_mask,
  1073. position_ids=position_ids,
  1074. decoder_input_ids=decoder_input_ids,
  1075. decoder_attention_mask=decoder_attention_mask,
  1076. decoder_position_ids=decoder_position_ids,
  1077. encoder_outputs=encoder_outputs,
  1078. inputs_embeds=inputs_embeds,
  1079. decoder_inputs_embeds=decoder_inputs_embeds,
  1080. use_cache=False,
  1081. **kwargs,
  1082. )
  1083. last_hidden_state = outputs.last_hidden_state
  1084. hidden_states = outputs.decoder_hidden_states
  1085. attentions = outputs.decoder_attentions
  1086. else:
  1087. outputs: BaseModelOutput = self.model(
  1088. input_ids,
  1089. attention_mask=attention_mask,
  1090. position_ids=position_ids,
  1091. inputs_embeds=inputs_embeds,
  1092. **kwargs,
  1093. )
  1094. last_hidden_state = outputs.last_hidden_state
  1095. hidden_states = outputs.hidden_states
  1096. attentions = outputs.attentions
  1097. logits = self.score(last_hidden_state)
  1098. loss = None
  1099. if labels is not None:
  1100. loss = self.loss_function(logits, labels, self.config)
  1101. return TokenClassifierOutput(
  1102. loss=loss,
  1103. logits=logits,
  1104. hidden_states=hidden_states,
  1105. attentions=attentions,
  1106. )
# Public names exported by `from ... import *`: the shared pretrained base class, the
# encoder-decoder and encoder-only backbones, and the generation / classification heads.
__all__ = [
    "T5GemmaForConditionalGeneration",
    "T5GemmaModel",
    "T5GemmaEncoderModel",
    "T5GemmaPreTrainedModel",
    "T5GemmaForSequenceClassification",
    "T5GemmaForTokenClassification",
]