modeling_mt5.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682
  1. # Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch mT5 model."""
  15. import copy
  16. import math
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. Seq2SeqLMOutput,
  30. Seq2SeqModelOutput,
  31. Seq2SeqQuestionAnsweringModelOutput,
  32. Seq2SeqSequenceClassifierOutput,
  33. TokenClassifierOutput,
  34. )
  35. from ...modeling_utils import PreTrainedModel
  36. from ...utils import DUMMY_INPUTS, DUMMY_MASK, auto_docstring, logging, torch_compilable_check
  37. from .configuration_mt5 import MT5Config
  38. logger = logging.get_logger(__name__)
  39. # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->MT5
  class MT5LayerNorm(nn.Module):
      """RMS-style layer norm used by MT5: scale only, no bias, no mean subtraction.

      The mean of squares is accumulated in float32 even for half-precision inputs
      (Root Mean Square Layer Normalization, https://huggingface.co/papers/1910.07467).
      """

      def __init__(self, hidden_size, eps=1e-6):
          super().__init__()
          # Learnable per-channel gain, initialised to ones.
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states):
          # Mean of squares over the last axis, computed on an fp32 copy of the input.
          mean_square = torch.mean(torch.square(hidden_states.to(torch.float32)), dim=-1, keepdim=True)
          normalized = hidden_states * (mean_square + self.variance_epsilon).rsqrt()
          # Only cast back down when the gain itself is stored in half precision.
          half_precision = (torch.float16, torch.bfloat16)
          if self.weight.dtype in half_precision:
              normalized = normalized.to(self.weight.dtype)
          return self.weight * normalized
  59. # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->MT5
  class MT5DenseActDense(nn.Module):
      """Bias-free position-wise feed-forward: ``wo(dropout(act(wi(x))))``."""

      def __init__(self, config: MT5Config):
          super().__init__()
          d_model, d_ff = config.d_model, config.d_ff
          self.wi = nn.Linear(d_model, d_ff, bias=False)
          self.wo = nn.Linear(d_ff, d_model, bias=False)
          self.dropout = nn.Dropout(config.dropout_rate)
          self.act = ACT2FN[config.dense_act_fn]

      def forward(self, hidden_states):
          projected = self.dropout(self.act(self.wi(hidden_states)))
          out_weight = self.wo.weight
          # Align activations with ``wo``'s dtype (it may be kept in a different precision),
          # but never cast activations to int8 quantized storage.
          needs_cast = (
              isinstance(out_weight, torch.Tensor)
              and projected.dtype != out_weight.dtype
              and out_weight.dtype != torch.int8
          )
          if needs_cast:
              projected = projected.to(out_weight.dtype)
          return self.wo(projected)
  79. # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->MT5
  class MT5DenseGatedActDense(nn.Module):
      """Gated feed-forward: ``wo(dropout(act(wi_0(x)) * wi_1(x)))``, all projections bias-free."""

      def __init__(self, config: MT5Config):
          super().__init__()
          d_model, d_ff = config.d_model, config.d_ff
          self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
          self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
          self.wo = nn.Linear(d_ff, d_model, bias=False)
          self.dropout = nn.Dropout(config.dropout_rate)
          self.act = ACT2FN[config.dense_act_fn]

      def forward(self, hidden_states):
          gate = self.act(self.wi_0(hidden_states))
          linear = self.wi_1(hidden_states)
          mixed = self.dropout(gate * linear)
          # To make 8bit quantization work for google/flan-t5-xxl, ``wo`` is kept in float32
          # (see https://github.com/huggingface/transformers/issues/20287). Activations are cast to
          # ``wo``'s dtype unless it is int8, in case users force `_keep_in_fp32_modules` to be `None`.
          out_weight = self.wo.weight
          needs_cast = (
              isinstance(out_weight, torch.Tensor)
              and mixed.dtype != out_weight.dtype
              and out_weight.dtype != torch.int8
          )
          if needs_cast:
              mixed = mixed.to(out_weight.dtype)
          return self.wo(mixed)
  104. # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->MT5
  class MT5LayerFF(nn.Module):
      """Pre-norm feed-forward sublayer with a residual connection."""

      def __init__(self, config: MT5Config):
          super().__init__()
          # Gated vs. plain feed-forward is chosen by the config flag.
          ff_class = MT5DenseGatedActDense if config.is_gated_act else MT5DenseActDense
          self.DenseReluDense = ff_class(config)
          self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(self, hidden_states):
          update = self.DenseReluDense(self.layer_norm(hidden_states))
          return hidden_states + self.dropout(update)
  119. # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5
  class MT5Attention(nn.Module):
      """Multi-head attention with optional T5-style relative position bias.

      Used for encoder self-attention, decoder self-attention and, when ``key_value_states`` is
      passed to ``forward``, decoder cross-attention. Scores are the raw ``q @ k^T`` products (no
      ``1/sqrt(d_kv)`` scaling) plus the position bias and the additive mask.
      """

      def __init__(
          self,
          config: MT5Config,
          has_relative_attention_bias=False,
          layer_idx: int | None = None,
      ):
          super().__init__()
          self.is_decoder = config.is_decoder
          self.has_relative_attention_bias = has_relative_attention_bias
          self.relative_attention_num_buckets = config.relative_attention_num_buckets
          self.relative_attention_max_distance = config.relative_attention_max_distance
          self.d_model = config.d_model
          self.key_value_proj_dim = config.d_kv
          self.n_heads = config.num_heads
          self.dropout = config.dropout_rate
          self.inner_dim = self.n_heads * self.key_value_proj_dim
          self.layer_idx = layer_idx
          if layer_idx is None and self.is_decoder:
              # The cache is indexed by layer_idx, so a decoder layer without one cannot use caching.
              logger.warning_once(
                  f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                  "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                  "when creating this class."
              )

          self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

          if self.has_relative_attention_bias:
              # One learned scalar per (bucket, head).
              self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
          self.gradient_checkpointing = False

      @staticmethod
      def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
          """
          Adapted from Mesh Tensorflow:
          https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

          Translate relative position to a bucket number for relative attention. The relative position is defined as
          memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
          position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
          small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
          positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
          This should allow for more graceful generalization to longer sequences than the model has been trained on

          Args:
              relative_position: an integer Tensor
              bidirectional: a boolean - whether the attention is bidirectional
              num_buckets: an integer
              max_distance: an integer

          Returns:
              a `torch.long` Tensor with the same shape as relative_position, containing values in the range
              [0, num_buckets)
          """
          relative_buckets = 0
          if bidirectional:
              # Half of the buckets encode positive offsets, the other half non-positive ones.
              num_buckets //= 2
              relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
              relative_position = torch.abs(relative_position)
          else:
              relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
          # now relative_position is in the range [0, inf)

          # half of the buckets are for exact increments in positions
          max_exact = num_buckets // 2
          is_small = relative_position < max_exact

          # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.
          # (log(0) for small positions is discarded by the torch.where below.)
          relative_position_if_large = max_exact + (
              torch.log(relative_position.float() / max_exact)
              / math.log(max_distance / max_exact)
              * (num_buckets - max_exact)
          ).to(torch.long)
          relative_position_if_large = torch.min(
              relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
          )

          relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
          return relative_buckets

      def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
          """Compute binned relative position bias of shape (1, num_heads, query_length, key_length).

          ``past_seen_tokens`` offsets the query positions so cached decoding sees the correct distances.
          """
          if device is None:
              device = self.relative_attention_bias.weight.device
          context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
          memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
          relative_position = memory_position - context_position  # shape (query_length, key_length)
          relative_position_bucket = self._relative_position_bucket(
              relative_position,  # shape (query_length, key_length)
              bidirectional=(not self.is_decoder),
              num_buckets=self.relative_attention_num_buckets,
              max_distance=self.relative_attention_max_distance,
          )
          values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
          values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
          return values

      def forward(
          self,
          hidden_states,
          mask=None,
          key_value_states=None,
          position_bias=None,
          past_key_values=None,
          output_attentions=False,
          **kwargs,
      ):
          """
          Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

          Returns ``(attn_output, position_bias)``, extended with ``attn_weights`` when ``output_attentions`` is
          true. When ``position_bias`` is computed in this call, the returned bias already includes ``mask``.
          """
          # Input is (batch_size, seq_length, dim)
          # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.key_value_proj_dim)
          past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
          # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
          past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens

          # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
          is_cross_attention = key_value_states is not None

          query_states = self.q(hidden_states).view(hidden_shape).transpose(1, 2)

          # Check if an encoder-decoder cache is being used. Otherwise we'll get a plain `DynamicCache`
          is_updated = False
          if isinstance(past_key_values, EncoderDecoderCache):
              is_updated = past_key_values.is_updated.get(self.layer_idx)
              if is_cross_attention:
                  # after the first generated id, we can subsequently re-use all key/value_states from cache
                  curr_past_key_values = past_key_values.cross_attention_cache
              else:
                  curr_past_key_values = past_key_values.self_attention_cache
          else:
              curr_past_key_values = past_key_values

          current_states = key_value_states if is_cross_attention else hidden_states
          if is_cross_attention and past_key_values is not None and is_updated:
              # reuse k,v, cross_attentions
              key_states = curr_past_key_values.layers[self.layer_idx].keys
              value_states = curr_past_key_values.layers[self.layer_idx].values
          else:
              kv_shape = (*current_states.shape[:-1], -1, self.key_value_proj_dim)
              key_states = self.k(current_states).view(kv_shape).transpose(1, 2)
              value_states = self.v(current_states).view(kv_shape).transpose(1, 2)

              if past_key_values is not None:
                  key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                  # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                  if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                      past_key_values.is_updated[self.layer_idx] = True

          # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
          scores = torch.matmul(query_states, key_states.transpose(3, 2))

          if position_bias is None:
              key_length = key_states.shape[-2]
              if not self.has_relative_attention_bias:
                  position_bias = torch.zeros(
                      (1, query_states.shape[1], input_shape[1], key_length), device=scores.device, dtype=scores.dtype
                  )
                  if self.gradient_checkpointing and self.training:
                      position_bias.requires_grad = True
              else:
                  position_bias = self.compute_bias(
                      input_shape[1], key_length, device=scores.device, past_seen_tokens=past_seen_tokens
                  )

              # The additive mask is folded into the bias only when the bias is built here, so a bias
              # handed in by the caller is used as-is.
              if mask is not None:
                  causal_mask = mask[:, :, :, : key_states.shape[-2]]
                  position_bias = position_bias + causal_mask

          scores += position_bias

          # (batch_size, n_heads, seq_length, key_length)
          attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
          attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

          attn_output = torch.matmul(attn_weights, value_states)

          attn_output = attn_output.transpose(1, 2).contiguous()
          attn_output = attn_output.reshape(*input_shape, -1)
          attn_output = self.o(attn_output)

          outputs = (attn_output, position_bias)

          if output_attentions:
              outputs = outputs + (attn_weights,)
          return outputs
  286. # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5
  287. class MT5LayerSelfAttention(nn.Module):
  288. def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
  289. super().__init__()
  290. self.SelfAttention = MT5Attention(
  291. config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
  292. )
  293. self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  294. self.dropout = nn.Dropout(config.dropout_rate)
  295. def forward(
  296. self,
  297. hidden_states,
  298. attention_mask=None,
  299. position_bias=None,
  300. past_key_values=None,
  301. use_cache=False,
  302. output_attentions=False,
  303. **kwargs,
  304. ):
  305. normed_hidden_states = self.layer_norm(hidden_states)
  306. attention_output = self.SelfAttention(
  307. normed_hidden_states,
  308. mask=attention_mask,
  309. position_bias=position_bias,
  310. past_key_values=past_key_values,
  311. use_cache=use_cache,
  312. output_attentions=output_attentions,
  313. )
  314. hidden_states = hidden_states + self.dropout(attention_output[0])
  315. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  316. return outputs
  317. # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5
  318. class MT5LayerCrossAttention(nn.Module):
  319. def __init__(self, config, layer_idx: int | None = None):
  320. super().__init__()
  321. self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
  322. self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  323. self.dropout = nn.Dropout(config.dropout_rate)
  324. def forward(
  325. self,
  326. hidden_states,
  327. key_value_states,
  328. attention_mask=None,
  329. position_bias=None,
  330. past_key_values=None,
  331. output_attentions=False,
  332. **kwargs,
  333. ):
  334. normed_hidden_states = self.layer_norm(hidden_states)
  335. attention_output = self.EncDecAttention(
  336. normed_hidden_states,
  337. mask=attention_mask,
  338. key_value_states=key_value_states,
  339. position_bias=position_bias,
  340. past_key_values=past_key_values,
  341. output_attentions=output_attentions,
  342. )
  343. layer_output = hidden_states + self.dropout(attention_output[0])
  344. outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
  345. return outputs
  346. # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5
  class MT5Block(GradientCheckpointingLayer):
      """One MT5 layer: self-attention, cross-attention (decoder only), then feed-forward.

      ``self.layer`` holds the sublayers in that order, so ``self.layer[-1]`` is always the
      feed-forward sublayer and ``self.layer[1]`` is cross-attention only for decoders.
      """

      def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
          super().__init__()
          self.is_decoder = config.is_decoder
          sublayers = [
              MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
          ]
          if self.is_decoder:
              sublayers.append(MT5LayerCrossAttention(config, layer_idx=layer_idx))
          sublayers.append(MT5LayerFF(config))
          self.layer = nn.ModuleList(sublayers)

      @staticmethod
      def _clamp_inf_fp16(hidden_states):
          """Clamp fp16 activations to +/-(max - 1000) if any inf is present, else to +/-max.

          Non-fp16 tensors are returned unchanged. This keeps fp16 training from propagating infs.
          """
          if hidden_states.dtype != torch.float16:
              return hidden_states
          fp16_max = torch.finfo(hidden_states.dtype).max
          clamp_value = torch.where(torch.isinf(hidden_states).any(), fp16_max - 1000, fp16_max)
          return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          position_bias=None,
          encoder_hidden_states=None,
          encoder_attention_mask=None,
          encoder_decoder_position_bias=None,
          past_key_values=None,
          use_cache=False,
          output_attentions=False,
          return_dict=True,
          **kwargs,
      ):
          self_attn_result = self.layer[0](
              hidden_states,
              attention_mask=attention_mask,
              position_bias=position_bias,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
          )
          hidden_states = self._clamp_inf_fp16(self_attn_result[0])
          # Self-attention position bias and (optionally) attention weights.
          extra_outputs = self_attn_result[1:]

          if self.is_decoder and encoder_hidden_states is not None:
              cross_attn_result = self.layer[1](
                  hidden_states,
                  key_value_states=encoder_hidden_states,
                  attention_mask=encoder_attention_mask,
                  position_bias=encoder_decoder_position_bias,
                  past_key_values=past_key_values,
                  output_attentions=output_attentions,
              )
              hidden_states = self._clamp_inf_fp16(cross_attn_result[0])
              # Cross-attention position bias and (optionally) attention weights.
              extra_outputs = extra_outputs + cross_attn_result[1:]

          hidden_states = self._clamp_inf_fp16(self.layer[-1](hidden_states))

          # hidden-states, (self-attention position bias), (self-attention weights),
          # (cross-attention position bias), (cross-attention weights)
          return (hidden_states,) + extra_outputs
  425. # Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->MT5
  class MT5ClassificationHead(nn.Module):
      """Sentence-level classification head: dropout -> dense -> tanh -> dropout -> projection."""

      def __init__(self, config: MT5Config):
          super().__init__()
          width = config.d_model
          self.dense = nn.Linear(width, width)
          self.dropout = nn.Dropout(p=config.classifier_dropout)
          self.out_proj = nn.Linear(width, config.num_labels)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          features = self.dropout(hidden_states)
          features = torch.tanh(self.dense(features))
          return self.out_proj(self.dropout(features))
  @auto_docstring
  # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->MT5, t5->mt5
  class MT5PreTrainedModel(PreTrainedModel):
      config: MT5Config
      base_model_prefix = "transformer"
      supports_gradient_checkpointing = True
      _can_compile_fullgraph = True
      _no_split_modules = ["MT5Block"]
      # Feed-forward output projections stay in fp32 (the dense layers cast activations to match).
      _keep_in_fp32_modules = ["wo"]

      @property
      def dummy_inputs(self):
          token_ids = torch.tensor(DUMMY_INPUTS)
          token_mask = torch.tensor(DUMMY_MASK)
          return {
              "decoder_input_ids": token_ids,
              "input_ids": token_ids,
              "decoder_attention_mask": token_mask,
          }

      @torch.no_grad()
      def _init_weights(self, module):
          """Initialize the weights of ``module`` according to its type.

          Standard deviations follow the Mesh-TF/T5 scheme, scaled by ``config.initializer_factor``.
          The order of the random draws matches the per-branch statement order below.
          """
          factor = self.config.initializer_factor  # Used for testing weights initialization

          def _normal_linear(linear, std):
              # Draw the weight from N(0, std^2), then zero the bias if the layer has one.
              init.normal_(linear.weight, mean=0.0, std=std)
              if getattr(linear, "bias", None) is not None:
                  init.zeros_(linear.bias)

          if isinstance(module, MT5LayerNorm):
              init.constant_(module.weight, factor * 1.0)
          elif isinstance(module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering)):
              init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
              if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                  init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0)
              if hasattr(module, "qa_outputs"):
                  init.normal_(module.qa_outputs.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
                  init.zeros_(module.qa_outputs.bias)
          elif isinstance(module, MT5ForTokenClassification):
              if hasattr(module, "classifier"):
                  init.normal_(module.classifier.weight, mean=0.0, std=factor * 1.0)
                  init.zeros_(module.classifier.bias)
          elif isinstance(module, MT5ClassificationHead):
              head_std = factor * ((self.config.d_model) ** -0.5)
              _normal_linear(module.dense, head_std)
              _normal_linear(module.out_proj, head_std)
          elif isinstance(module, MT5DenseActDense):
              _normal_linear(module.wi, factor * ((self.config.d_model) ** -0.5))
              _normal_linear(module.wo, factor * ((self.config.d_ff) ** -0.5))
          elif isinstance(module, MT5DenseGatedActDense):
              in_std = factor * ((self.config.d_model) ** -0.5)
              _normal_linear(module.wi_0, in_std)
              _normal_linear(module.wi_1, in_std)
              _normal_linear(module.wo, factor * ((self.config.d_ff) ** -0.5))
          elif isinstance(module, MT5Attention):
              d_model = self.config.d_model
              key_value_proj_dim = self.config.d_kv
              n_heads = self.config.num_heads
              # q's std also folds in d_kv because attention scores are not scaled at runtime.
              init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
              init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
              init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
              init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
              if module.has_relative_attention_bias:
                  init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5))

      def _shift_right(self, input_ids):
          """Build decoder inputs: prepend ``decoder_start_token_id``, drop the last token, map -100 to pad."""
          start_token = self.config.decoder_start_token_id
          pad_token = self.config.pad_token_id

          if start_token is None:
              raise ValueError(
                  "self.model.config.decoder_start_token_id has to be defined. In MT5 it is usually set to the pad_token_id. "
                  "See MT5 docs for more information."
              )

          shifted = input_ids.new_zeros(input_ids.shape)
          shifted[..., 1:] = input_ids[..., :-1].clone()
          shifted[..., 0] = start_token

          if pad_token is None:
              raise ValueError("self.model.config.pad_token_id has to be defined.")
          # replace possible -100 values in labels by `pad_token_id`
          return shifted.masked_fill(shifted == -100, pad_token)
# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->MT5
class MT5Stack(MT5PreTrainedModel):
    """A stack of `MT5Block` layers, used as the encoder or the decoder of the MT5 models.

    The role is fixed at construction time by `config.is_decoder`. Only the first block is built with
    `has_relative_attention_bias=True`; the position bias it returns is fed to every later block in `forward`.
    """

    def __init__(self, config):
        """Build the token embeddings, `config.num_layers` blocks, the final layer norm and dropout.

        Args:
            config: model configuration; `is_decoder`, `vocab_size`, `d_model`, `num_layers`,
                `layer_norm_epsilon` and `dropout_rate` are read here.
        """
        super().__init__(config)
        # Own embedding table; the parent models replace it with their shared table via `set_input_embeddings`.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.is_decoder = config.is_decoder
        # Only block 0 owns a relative-attention-bias table; the others reuse the bias block 0 produces.
        self.block = nn.ModuleList(
            [MT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
        )
        self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)
        # Initialize weights and apply final processing
        self.post_init()
        # Set after `post_init`; `forward` reads it to decide whether caching must be disabled during training.
        self.gradient_checkpointing = False

    def set_input_embeddings(self, new_embeddings):
        """Replace the embedding module used to embed `input_ids` when `inputs_embeds` is not given."""
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Embed the inputs, run every block, then apply the final layer norm and dropout.

        Args:
            input_ids: token ids. Exactly one of `input_ids` / `inputs_embeds` must be given. Ids are reshaped with
                `view(-1, input_shape[-1])` before embedding.
            attention_mask: padding mask. It is turned into a causal mask when the config says decoder and into a
                bidirectional mask otherwise.
            encoder_hidden_states: states attended to by cross-attention. They only produce a cross-attention mask
                when this stack is a decoder.
            encoder_attention_mask: padding mask over `encoder_hidden_states`.
            inputs_embeds: precomputed embeddings, used instead of `self.embed_tokens(input_ids)`.
            past_key_values: cache object. For a decoder with `use_cache` set and no cache given, a new
                `EncoderDecoderCache` (if `config.is_encoder_decoder`) or `DynamicCache` is created. It is always
                reset to `None` for an encoder stack.
            use_cache, output_attentions, output_hidden_states, return_dict: fall back to the config values when
                `None`. `use_cache` is forced to `False` (with a warning) when gradient checkpointing is active in
                training mode.
            **kwargs: accepted and not used in this method.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` is truthy. Otherwise, a tuple of the
            non-`None` values among (hidden_states, past_key_values, all_hidden_states, all_attentions,
            all_cross_attentions).

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, if embeddings are needed
                while `embed_tokens` is `None`, or if `use_cache is True` on a non-decoder stack.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            # Drop the hidden dimension so `input_shape` matches the `input_ids` case.
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        # NOTE(review): neither local is used below; the unpacking only fails (ValueError) if input_shape is not 2-D.
        batch_size, seq_length = input_shape

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        if self.is_decoder:
            # Lazily create a cache for the decoder when caching is requested but none was passed in.
            if use_cache and past_key_values is None:
                if self.config.is_encoder_decoder:
                    past_key_values = EncoderDecoderCache(
                        DynamicCache(config=self.config), DynamicCache(config=self.config)
                    )
                else:
                    past_key_values = DynamicCache(config=self.config)
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        # NOTE(review): this reads `self.config.is_decoder` while the rest of the method uses `self.is_decoder`;
        # both come from the same config value in `__init__` unless the config is mutated afterwards.
        if self.config.is_decoder:
            # Self-attention mask for the decoder, built against the self-attention half of an EncoderDecoderCache.
            attention_mask = create_causal_mask(
                config=self.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values.self_attention_cache
                if isinstance(past_key_values, EncoderDecoderCache)
                else past_key_values,
            )
        else:
            attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )

        # Cross-attention mask over the encoder states; only built for a decoder that receives encoder states.
        encoder_extended_attention_mask = None
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_extended_attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        # Both biases start as None; the first block computes them and later blocks reuse them.
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for layer_module in self.block:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                position_bias,
                encoder_hidden_states,
                encoder_extended_attention_mask,
                encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                return_dict=return_dict,
            )

            hidden_states = layer_outputs[0]

            # Position biases are shared between layers: the first block (the only one with a relative-attention
            # table) produces them and they are passed back into every following block.
            # Layout of layer_outputs as indexed here (defined by MT5Block): [0] hidden states,
            # [1] self-attention position bias, [2] self-attention weights (only with output_attentions),
            # then the cross-attention position bias ([3] with output_attentions, else [2]) and
            # [4] cross-attention weights.
            position_bias = layer_outputs[1]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output keeps only the entries that were actually produced.
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
  683. @auto_docstring
  684. class MT5Model(MT5PreTrainedModel):
  685. r"""
  686. Examples:
  687. ```python
  688. >>> from transformers import MT5Model, AutoTokenizer
  689. >>> model = MT5Model.from_pretrained("google/mt5-small")
  690. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  691. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  692. >>> summary = "Weiter Verhandlung in Syrien."
  693. >>> inputs = tokenizer(article, return_tensors="pt")
  694. >>> labels = tokenizer(text_target=summary, return_tensors="pt")
  695. >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
  696. >>> hidden_states = outputs.last_hidden_state
  697. ```"""
  698. model_type = "mt5"
  699. config: MT5Config
  700. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  701. _tied_weights_keys = {
  702. "encoder.embed_tokens.weight": "shared.weight",
  703. "decoder.embed_tokens.weight": "shared.weight",
  704. }
  705. # Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5
  706. def __init__(self, config: MT5Config):
  707. super().__init__(config)
  708. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  709. encoder_config = copy.deepcopy(config)
  710. encoder_config.is_decoder = False
  711. encoder_config.use_cache = False
  712. self.encoder = MT5Stack(encoder_config)
  713. decoder_config = copy.deepcopy(config)
  714. decoder_config.is_decoder = True
  715. decoder_config.num_layers = config.num_decoder_layers
  716. self.decoder = MT5Stack(decoder_config)
  717. # Initialize weights and apply final processing
  718. self.post_init()
  719. # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings
  720. def get_input_embeddings(self):
  721. return self.shared
  722. # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings
  723. def set_input_embeddings(self, new_embeddings):
  724. self.shared = new_embeddings
  725. self.encoder.set_input_embeddings(new_embeddings)
  726. self.decoder.set_input_embeddings(new_embeddings)
  727. @auto_docstring
  728. # Copied from transformers.models.t5.modeling_t5.T5Model.forward with google-t5/->google/, T5->MT5, t5->mt5
  729. def forward(
  730. self,
  731. input_ids: torch.LongTensor | None = None,
  732. attention_mask: torch.FloatTensor | None = None,
  733. decoder_input_ids: torch.LongTensor | None = None,
  734. decoder_attention_mask: torch.BoolTensor | None = None,
  735. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  736. past_key_values: Cache | None = None,
  737. inputs_embeds: torch.Tensor | None = None,
  738. decoder_inputs_embeds: torch.Tensor | None = None,
  739. use_cache: bool | None = None,
  740. output_attentions: bool | None = None,
  741. output_hidden_states: bool | None = None,
  742. return_dict: bool | None = None,
  743. **kwargs,
  744. ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput:
  745. r"""
  746. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  747. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  748. should be able to pad the inputs on both the right and the left.
  749. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  750. [`PreTrainedTokenizer.__call__`] for detail.
  751. [What are input IDs?](../glossary#input-ids)
  752. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
  753. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  754. Indices of decoder input sequence tokens in the vocabulary.
  755. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  756. [`PreTrainedTokenizer.__call__`] for details.
  757. [What are decoder input IDs?](../glossary#decoder-input-ids)
  758. MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  759. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  760. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
  761. Training](./mt5#training).
  762. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  763. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  764. be used by default.
  765. Example:
  766. ```python
  767. >>> from transformers import AutoTokenizer, MT5Model
  768. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  769. >>> model = MT5Model.from_pretrained("google/mt5-small")
  770. >>> input_ids = tokenizer(
  771. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  772. ... ).input_ids # Batch size 1
  773. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  774. >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for MT5Model.
  775. >>> # This is not needed for torch's MT5ForConditionalGeneration as it does this internally using labels arg.
  776. >>> decoder_input_ids = model._shift_right(decoder_input_ids)
  777. >>> # forward pass
  778. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  779. >>> last_hidden_states = outputs.last_hidden_state
  780. ```"""
  781. use_cache = use_cache if use_cache is not None else self.config.use_cache
  782. return_dict = return_dict if return_dict is not None else self.config.return_dict
  783. # Encode if needed (training, first prediction pass)
  784. if encoder_outputs is None:
  785. encoder_outputs = self.encoder(
  786. input_ids=input_ids,
  787. attention_mask=attention_mask,
  788. inputs_embeds=inputs_embeds,
  789. output_attentions=output_attentions,
  790. output_hidden_states=output_hidden_states,
  791. return_dict=return_dict,
  792. )
  793. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  794. encoder_outputs = BaseModelOutput(
  795. last_hidden_state=encoder_outputs[0],
  796. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  797. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  798. )
  799. hidden_states = encoder_outputs[0]
  800. # Decode
  801. decoder_outputs = self.decoder(
  802. input_ids=decoder_input_ids,
  803. attention_mask=decoder_attention_mask,
  804. inputs_embeds=decoder_inputs_embeds,
  805. past_key_values=past_key_values,
  806. encoder_hidden_states=hidden_states,
  807. encoder_attention_mask=attention_mask,
  808. use_cache=use_cache,
  809. output_attentions=output_attentions,
  810. output_hidden_states=output_hidden_states,
  811. return_dict=return_dict,
  812. )
  813. if not return_dict:
  814. return decoder_outputs + encoder_outputs
  815. return Seq2SeqModelOutput(
  816. last_hidden_state=decoder_outputs.last_hidden_state,
  817. past_key_values=decoder_outputs.past_key_values,
  818. decoder_hidden_states=decoder_outputs.hidden_states,
  819. decoder_attentions=decoder_outputs.attentions,
  820. cross_attentions=decoder_outputs.cross_attentions,
  821. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  822. encoder_hidden_states=encoder_outputs.hidden_states,
  823. encoder_attentions=encoder_outputs.attentions,
  824. )
  825. @auto_docstring(
  826. custom_intro="""
  827. MT5 Model with a `language modeling` head on top.
  828. """
  829. )
  830. class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin):
  831. r"""
  832. Examples:
  833. ```python
  834. >>> from transformers import MT5ForConditionalGeneration, AutoTokenizer
  835. >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
  836. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  837. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  838. >>> summary = "Weiter Verhandlung in Syrien."
  839. >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
  840. >>> outputs = model(**inputs)
  841. >>> loss = outputs.loss
  842. ```"""
  843. model_type = "mt5"
  844. config: MT5Config
  845. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  846. _tied_weights_keys = {
  847. "encoder.embed_tokens.weight": "shared.weight",
  848. "decoder.embed_tokens.weight": "shared.weight",
  849. "lm_head.weight": "shared.weight",
  850. }
  851. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
  852. def __init__(self, config: MT5Config):
  853. super().__init__(config)
  854. self.model_dim = config.d_model
  855. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  856. encoder_config = copy.deepcopy(config)
  857. encoder_config.is_decoder = False
  858. encoder_config.use_cache = False
  859. self.encoder = MT5Stack(encoder_config)
  860. decoder_config = copy.deepcopy(config)
  861. decoder_config.is_decoder = True
  862. decoder_config.num_layers = config.num_decoder_layers
  863. self.decoder = MT5Stack(decoder_config)
  864. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  865. # Initialize weights and apply final processing
  866. self.post_init()
  867. def get_input_embeddings(self):
  868. return self.shared
  869. def set_input_embeddings(self, new_embeddings):
  870. self.shared = new_embeddings
  871. self.encoder.set_input_embeddings(new_embeddings)
  872. self.decoder.set_input_embeddings(new_embeddings)
  873. @auto_docstring
  874. def forward(
  875. self,
  876. input_ids: torch.LongTensor | None = None,
  877. attention_mask: torch.FloatTensor | None = None,
  878. decoder_input_ids: torch.LongTensor | None = None,
  879. decoder_attention_mask: torch.BoolTensor | None = None,
  880. encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
  881. past_key_values: Cache | None = None,
  882. inputs_embeds: torch.FloatTensor | None = None,
  883. decoder_inputs_embeds: torch.FloatTensor | None = None,
  884. labels: torch.LongTensor | None = None,
  885. use_cache: bool | None = None,
  886. output_attentions: bool | None = None,
  887. output_hidden_states: bool | None = None,
  888. return_dict: bool | None = None,
  889. **kwargs,
  890. ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
  891. r"""
  892. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  893. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  894. should be able to pad the inputs on both the right and the left.
  895. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  896. [`PreTrainedTokenizer.__call__`] for detail.
  897. [What are input IDs?](../glossary#input-ids)
  898. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
  899. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  900. Indices of decoder input sequence tokens in the vocabulary.
  901. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  902. [`PreTrainedTokenizer.__call__`] for details.
  903. [What are decoder input IDs?](../glossary#decoder-input-ids)
  904. MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  905. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  906. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
  907. Training](./mt5#training).
  908. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  909. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  910. be used by default.
  911. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  912. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  913. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  914. labels in `[0, ..., config.vocab_size]`
  915. Examples:
  916. ```python
  917. >>> from transformers import AutoTokenizer, MT5ForConditionalGeneration
  918. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  919. >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
  920. >>> # training
  921. >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  922. >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  923. >>> outputs = model(input_ids=input_ids, labels=labels)
  924. >>> loss = outputs.loss
  925. >>> logits = outputs.logits
  926. >>> # inference
  927. >>> input_ids = tokenizer(
  928. ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
  929. ... ).input_ids # Batch size 1
  930. >>> outputs = model.generate(input_ids)
  931. >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  932. >>> # studies have shown that owning a dog is good for you.
  933. ```"""
  934. use_cache = use_cache if use_cache is not None else self.config.use_cache
  935. return_dict = return_dict if return_dict is not None else self.config.return_dict
  936. # Encode if needed (training, first prediction pass)
  937. if encoder_outputs is None:
  938. # Convert encoder inputs in embeddings if needed
  939. encoder_outputs = self.encoder(
  940. input_ids=input_ids,
  941. attention_mask=attention_mask,
  942. inputs_embeds=inputs_embeds,
  943. output_attentions=output_attentions,
  944. output_hidden_states=output_hidden_states,
  945. return_dict=return_dict,
  946. )
  947. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  948. encoder_outputs = BaseModelOutput(
  949. last_hidden_state=encoder_outputs[0],
  950. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  951. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  952. )
  953. hidden_states = encoder_outputs[0]
  954. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  955. # get decoder inputs from shifting lm labels to the right
  956. decoder_input_ids = self._shift_right(labels)
  957. # Decode
  958. decoder_outputs = self.decoder(
  959. input_ids=decoder_input_ids,
  960. attention_mask=decoder_attention_mask,
  961. inputs_embeds=decoder_inputs_embeds,
  962. past_key_values=past_key_values,
  963. encoder_hidden_states=hidden_states,
  964. encoder_attention_mask=attention_mask,
  965. use_cache=use_cache,
  966. output_attentions=output_attentions,
  967. output_hidden_states=output_hidden_states,
  968. return_dict=return_dict,
  969. )
  970. sequence_output = decoder_outputs[0]
  971. lm_logits = self.lm_head(sequence_output)
  972. loss = None
  973. if labels is not None:
  974. loss_fct = CrossEntropyLoss(ignore_index=-100)
  975. # move labels to correct device to enable PP
  976. labels = labels.to(lm_logits.device)
  977. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  978. if not return_dict:
  979. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  980. return ((loss,) + output) if loss is not None else output
  981. return Seq2SeqLMOutput(
  982. loss=loss,
  983. logits=lm_logits,
  984. past_key_values=decoder_outputs.past_key_values,
  985. decoder_hidden_states=decoder_outputs.hidden_states,
  986. decoder_attentions=decoder_outputs.attentions,
  987. cross_attentions=decoder_outputs.cross_attentions,
  988. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  989. encoder_hidden_states=encoder_outputs.hidden_states,
  990. encoder_attentions=encoder_outputs.attentions,
  991. )
  992. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
  993. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  994. return self._shift_right(labels)
  995. @auto_docstring
  996. class MT5EncoderModel(MT5PreTrainedModel):
  997. r"""
  998. Examples:
  999. ```python
  1000. >>> from transformers import MT5EncoderModel, AutoTokenizer
  1001. >>> model = MT5EncoderModel.from_pretrained("google/mt5-small")
  1002. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  1003. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  1004. >>> input_ids = tokenizer(article, return_tensors="pt").input_ids
  1005. >>> outputs = model(input_ids)
  1006. >>> hidden_state = outputs.last_hidden_state
  1007. ```"""
  1008. model_type = "mt5"
  1009. config: MT5Config
  1010. _tied_weights_keys = {
  1011. "encoder.embed_tokens.weight": "shared.weight",
  1012. }
  1013. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.__init__ with T5->MT5
  1014. def __init__(self, config: MT5Config):
  1015. super().__init__(config)
  1016. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1017. encoder_config = config
  1018. encoder_config.use_cache = False
  1019. encoder_config.is_encoder_decoder = False
  1020. self.encoder = MT5Stack(encoder_config)
  1021. # Initialize weights and apply final processing
  1022. self.post_init()
  1023. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_input_embeddings
  1024. def get_input_embeddings(self):
  1025. return self.shared
  1026. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.set_input_embeddings
  1027. def set_input_embeddings(self, new_embeddings):
  1028. self.shared = new_embeddings
  1029. self.encoder.set_input_embeddings(new_embeddings)
  1030. @auto_docstring
  1031. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.forward with google-t5/->google/, T5->MT5, t5->mt5
  1032. def forward(
  1033. self,
  1034. input_ids: torch.LongTensor | None = None,
  1035. attention_mask: torch.FloatTensor | None = None,
  1036. inputs_embeds: torch.FloatTensor | None = None,
  1037. output_attentions: bool | None = None,
  1038. output_hidden_states: bool | None = None,
  1039. return_dict: bool | None = None,
  1040. **kwargs,
  1041. ) -> tuple[torch.FloatTensor] | BaseModelOutput:
  1042. r"""
  1043. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1044. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  1045. should be able to pad the inputs on both the right and the left.
  1046. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1047. [`PreTrainedTokenizer.__call__`] for detail.
  1048. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
  1049. Example:
  1050. ```python
  1051. >>> from transformers import AutoTokenizer, MT5EncoderModel
  1052. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  1053. >>> model = MT5EncoderModel.from_pretrained("google/mt5-small")
  1054. >>> input_ids = tokenizer(
  1055. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1056. ... ).input_ids # Batch size 1
  1057. >>> outputs = model(input_ids=input_ids)
  1058. >>> last_hidden_states = outputs.last_hidden_state
  1059. ```"""
  1060. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1061. encoder_outputs = self.encoder(
  1062. input_ids=input_ids,
  1063. attention_mask=attention_mask,
  1064. inputs_embeds=inputs_embeds,
  1065. output_attentions=output_attentions,
  1066. output_hidden_states=output_hidden_states,
  1067. return_dict=return_dict,
  1068. )
  1069. return encoder_outputs
  1070. @auto_docstring(
  1071. custom_intro="""
  1072. MT5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  1073. tasks.
  1074. """
  1075. )
  1076. class MT5ForSequenceClassification(MT5PreTrainedModel):
  1077. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1078. # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->MT5
  1079. def __init__(self, config: MT5Config):
  1080. super().__init__(config)
  1081. self.transformer = MT5Model(config)
  1082. self.classification_head = MT5ClassificationHead(config)
  1083. # Initialize weights and apply final processing
  1084. self.post_init()
  1085. @auto_docstring
  1086. # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.forward with T5->MT5, t5->mt5
  1087. def forward(
  1088. self,
  1089. input_ids: torch.LongTensor | None = None,
  1090. attention_mask: torch.Tensor | None = None,
  1091. decoder_input_ids: torch.LongTensor | None = None,
  1092. decoder_attention_mask: torch.LongTensor | None = None,
  1093. encoder_outputs: list[torch.FloatTensor] | None = None,
  1094. inputs_embeds: torch.FloatTensor | None = None,
  1095. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1096. labels: torch.LongTensor | None = None,
  1097. use_cache: bool | None = None,
  1098. output_attentions: bool | None = None,
  1099. output_hidden_states: bool | None = None,
  1100. return_dict: bool | None = None,
  1101. **kwargs,
  1102. ) -> tuple | Seq2SeqSequenceClassifierOutput:
  1103. r"""
  1104. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1105. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  1106. should be able to pad the inputs on both the right and the left.
  1107. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1108. [`PreTrainedTokenizer.__call__`] for detail.
  1109. [What are input IDs?](../glossary#input-ids)
  1110. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
  1111. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1112. Indices of decoder input sequence tokens in the vocabulary.
  1113. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1114. [`PreTrainedTokenizer.__call__`] for details.
  1115. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1116. MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1117. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1118. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
  1119. Training](./mt5#training).
  1120. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1121. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1122. be used by default.
  1123. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1124. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1125. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1126. """
  1127. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1128. if labels is not None:
  1129. use_cache = False
  1130. if input_ids is None and inputs_embeds is not None:
  1131. raise NotImplementedError(
  1132. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1133. )
  1134. # Copied from models.bart.modeling_bart.BartModel.forward different to other models, MT5 automatically creates
  1135. # decoder_input_ids from input_ids if no decoder_input_ids are provided
  1136. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1137. if input_ids is None:
  1138. raise ValueError(
  1139. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1140. "passed, `input_ids` cannot be `None`. Please pass either "
  1141. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1142. )
  1143. decoder_input_ids = self._shift_right(input_ids)
  1144. outputs = self.transformer(
  1145. input_ids,
  1146. attention_mask=attention_mask,
  1147. decoder_input_ids=decoder_input_ids,
  1148. decoder_attention_mask=decoder_attention_mask,
  1149. encoder_outputs=encoder_outputs,
  1150. inputs_embeds=inputs_embeds,
  1151. decoder_inputs_embeds=decoder_inputs_embeds,
  1152. use_cache=use_cache,
  1153. output_attentions=output_attentions,
  1154. output_hidden_states=output_hidden_states,
  1155. return_dict=return_dict,
  1156. )
  1157. sequence_output = outputs[0]
  1158. eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
  1159. torch_compilable_check(
  1160. torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
  1161. "All examples must have the same number of <eos> tokens.",
  1162. )
  1163. batch_size, _, hidden_size = sequence_output.shape
  1164. sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
  1165. logits = self.classification_head(sentence_representation)
  1166. loss = None
  1167. if labels is not None:
  1168. labels = labels.to(logits.device)
  1169. if self.config.problem_type is None:
  1170. if self.config.num_labels == 1:
  1171. self.config.problem_type = "regression"
  1172. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1173. self.config.problem_type = "single_label_classification"
  1174. else:
  1175. self.config.problem_type = "multi_label_classification"
  1176. if self.config.problem_type == "regression":
  1177. loss_fct = MSELoss()
  1178. if self.config.num_labels == 1:
  1179. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1180. else:
  1181. loss = loss_fct(logits, labels)
  1182. elif self.config.problem_type == "single_label_classification":
  1183. loss_fct = CrossEntropyLoss()
  1184. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1185. elif self.config.problem_type == "multi_label_classification":
  1186. loss_fct = BCEWithLogitsLoss()
  1187. loss = loss_fct(logits, labels)
  1188. if not return_dict:
  1189. output = (logits,) + outputs[1:]
  1190. return ((loss,) + output) if loss is not None else output
  1191. return Seq2SeqSequenceClassifierOutput(
  1192. loss=loss,
  1193. logits=logits,
  1194. past_key_values=outputs.past_key_values,
  1195. decoder_hidden_states=outputs.decoder_hidden_states,
  1196. decoder_attentions=outputs.decoder_attentions,
  1197. cross_attentions=outputs.cross_attentions,
  1198. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1199. encoder_hidden_states=outputs.encoder_hidden_states,
  1200. encoder_attentions=outputs.encoder_attentions,
  1201. )
  @auto_docstring
  class MT5ForTokenClassification(MT5PreTrainedModel):
      # Token-level classification head on top of the MT5 encoder only:
      # encoder hidden states -> dropout (`config.classifier_dropout`) -> linear layer producing
      # `config.num_labels` logits for every input position.

      # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->MT5
      def __init__(self, config: MT5Config):
          super().__init__(config)
          self.num_labels = config.num_labels

          self.transformer = MT5EncoderModel(config)
          self.dropout = nn.Dropout(config.classifier_dropout)
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      # Adapted from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->MT5.
      # NOTE: deliberately not a "Copied from" marker. The T5 original builds its tuple output as
      # `(logits, outputs[2:-1])`, which nests a tuple and drops the encoder's extra outputs. The fix below
      # must not be overwritten by `make fix-copies`.
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple[torch.Tensor] | TokenClassifierOutput:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
              should be able to pad the inputs on both the right and the left.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)

              To know more on how to prepare `input_ids` for pretraining take a look at [MT5 Training](./t5#training).
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
          """
          return_dict = return_dict if return_dict is not None else self.config.return_dict

          outputs = self.transformer(
              input_ids,
              attention_mask=attention_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Index 0 of the encoder output is its last hidden state; classify every position.
          hidden_states = outputs[0]
          hidden_states = self.dropout(hidden_states)
          logits = self.classifier(hidden_states)

          loss = None
          if labels is not None:
              # Put the labels on the logits' device so the loss also works when the model is split across devices.
              labels = labels.to(logits.device)
              loss_fct = CrossEntropyLoss()
              # Flatten (batch, seq) so each token is scored as an independent classification example.
              loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

          if not return_dict:
              # Flat tuple: logits followed by whatever extra encoder outputs were returned after index 0
              # (presumably hidden states / attentions when requested), same layout as the other heads.
              output = (logits,) + outputs[1:]
              return ((loss,) + output) if loss is not None else output

          return TokenClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  @auto_docstring
  class MT5ForQuestionAnswering(MT5PreTrainedModel):
      # Extractive question-answering head on top of the full MT5 encoder-decoder. The decoder's last hidden
      # states are projected by `qa_outputs` into per-position start/end logits.

      # Checkpoint key(s) that loading should not report as unexpected.
      _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
      # Encoder and decoder token embeddings are tied to the single `shared` embedding matrix.
      _tied_weights_keys = {
          "encoder.embed_tokens.weight": "shared.weight",
          "decoder.embed_tokens.weight": "shared.weight",
      }

      # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.__init__ with T5->MT5
      def __init__(self, config: MT5Config):
          super().__init__(config)
          self.model_dim = config.d_model

          self.shared = nn.Embedding(config.vocab_size, config.d_model)

          encoder_config = copy.deepcopy(config)
          encoder_config.is_decoder = False
          encoder_config.use_cache = False
          self.encoder = MT5Stack(encoder_config)

          decoder_config = copy.deepcopy(config)
          decoder_config.is_decoder = True
          decoder_config.num_layers = config.num_decoder_layers
          self.decoder = MT5Stack(decoder_config)

          self.num_labels = config.num_labels
          self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings
      def get_input_embeddings(self):
          return self.shared

      # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.set_input_embeddings
      def set_input_embeddings(self, new_embeddings):
          self.shared = new_embeddings
          self.encoder.set_input_embeddings(new_embeddings)
          self.decoder.set_input_embeddings(new_embeddings)

      @auto_docstring
      # Adapted from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.forward. It diverges from the T5
      # original in three ways:
      #   - config defaults are resolved once;
      #   - position targets are always moved to the logits' device;
      #   - the tuple output accepts a caller-supplied `BaseModelOutput`.
      # It is therefore deliberately not marked "Copied from".
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          decoder_input_ids: torch.LongTensor | None = None,
          decoder_attention_mask: torch.BoolTensor | None = None,
          encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
          start_positions: torch.LongTensor | None = None,
          end_positions: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          decoder_inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple[torch.FloatTensor] | Seq2SeqQuestionAnsweringModelOutput:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
              should be able to pad the inputs on both the right and the left.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)

              To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of decoder input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)

              MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
              is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

              To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
              Training](./t5#training).
          decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
              be used by default.
          """
          # Resolve config defaults exactly once (the T5 original repeated both lines further down).
          return_dict = return_dict if return_dict is not None else self.config.return_dict
          use_cache = use_cache if use_cache is not None else self.config.use_cache
          # When computing the span loss, the decoder cache is disabled.
          if start_positions is not None and end_positions is not None:
              use_cache = False

          # Copied from models.bart.modeling_bart.BartModel.forward
          # different to other models, MT5 automatically creates decoder_input_ids from
          # input_ids if no decoder_input_ids are provided
          if decoder_input_ids is None and decoder_inputs_embeds is None:
              if input_ids is None:
                  raise ValueError(
                      "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                      "passed, `input_ids` cannot be `None`. Please pass either "
                      "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                  )
              decoder_input_ids = self._shift_right(input_ids)

          # Encode if needed (training, first prediction pass)
          if encoder_outputs is None:
              encoder_outputs = self.encoder(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  inputs_embeds=inputs_embeds,
                  output_attentions=output_attentions,
                  output_hidden_states=output_hidden_states,
                  return_dict=return_dict,
              )
          elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
              # Wrap a caller-supplied tuple so the attribute access in the return_dict branch below works.
              encoder_outputs = BaseModelOutput(
                  last_hidden_state=encoder_outputs[0],
                  hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                  attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
              )

          hidden_states = encoder_outputs[0]

          # Decode. No past key/values are passed to the decoder from this head.
          decoder_outputs = self.decoder(
              input_ids=decoder_input_ids,
              attention_mask=decoder_attention_mask,
              inputs_embeds=decoder_inputs_embeds,
              past_key_values=None,
              encoder_hidden_states=hidden_states,
              encoder_attention_mask=attention_mask,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          sequence_output = decoder_outputs[0]

          # Unpacking into exactly two tensors requires `qa_outputs` to have two output channels,
          # i.e. this head assumes `config.num_labels == 2` (start, end).
          logits = self.qa_outputs(sequence_output)
          start_logits, end_logits = logits.split(1, dim=-1)
          start_logits = start_logits.squeeze(-1).contiguous()
          end_logits = end_logits.squeeze(-1).contiguous()

          total_loss = None
          if start_positions is not None and end_positions is not None:
              # Targets may carry an extra trailing dimension (e.g. after a multi-GPU gather); drop it.
              if start_positions.dim() > 1:
                  start_positions = start_positions.squeeze(-1)
              if end_positions.dim() > 1:
                  end_positions = end_positions.squeeze(-1)
              # Always move targets to the logits' device. The T5 original did this only when squeezing, so 1-D
              # targets on another device made the loss fail.
              start_positions = start_positions.to(start_logits.device)
              end_positions = end_positions.to(end_logits.device)
              # Positions past the end of the sequence are clamped to `ignored_index`, which the loss skips.
              ignored_index = start_logits.size(1)
              start_positions = start_positions.clamp(0, ignored_index)
              end_positions = end_positions.clamp(0, ignored_index)

              loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
              start_loss = loss_fct(start_logits, start_positions)
              end_loss = loss_fct(end_logits, end_positions)
              total_loss = (start_loss + end_loss) / 2

          if not return_dict:
              # With return_dict=False, a caller-supplied `BaseModelOutput` is not converted above.
              # Flatten it so it can be concatenated onto the tuple; `tuple + BaseModelOutput` raises TypeError.
              if isinstance(encoder_outputs, BaseModelOutput):
                  encoder_outputs = encoder_outputs.to_tuple()
              output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
              return ((total_loss,) + output) if total_loss is not None else output

          return Seq2SeqQuestionAnsweringModelOutput(
              loss=total_loss,
              start_logits=start_logits,
              end_logits=end_logits,
              past_key_values=decoder_outputs.past_key_values,
              decoder_hidden_states=decoder_outputs.hidden_states,
              decoder_attentions=decoder_outputs.attentions,
              cross_attentions=decoder_outputs.cross_attentions,
              encoder_last_hidden_state=encoder_outputs.last_hidden_state,
              encoder_hidden_states=encoder_outputs.hidden_states,
              encoder_attentions=encoder_outputs.attentions,
          )
  # Public names exported by `from ... import *` for this module, listed in alphabetical order.
  # Kept as a plain literal list: repository tooling reads `__all__` statically, so it should not be computed.
  __all__ = [
      "MT5EncoderModel",
      "MT5ForConditionalGeneration",
      "MT5ForQuestionAnswering",
      "MT5ForSequenceClassification",
      "MT5ForTokenClassification",
      "MT5Model",
      "MT5PreTrainedModel",
  ]