modeling_t5.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619
  1. # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch T5 model."""
  15. import copy
  16. import math
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. Seq2SeqLMOutput,
  30. Seq2SeqModelOutput,
  31. Seq2SeqQuestionAnsweringModelOutput,
  32. Seq2SeqSequenceClassifierOutput,
  33. TokenClassifierOutput,
  34. )
  35. from ...modeling_utils import PreTrainedModel
  36. from ...utils import DUMMY_INPUTS, DUMMY_MASK, auto_docstring, logging, torch_compilable_check
  37. from .configuration_t5 import T5Config
# Module-level logger for this file (used for the apex fallback messages and the missing-`layer_idx` warning).
logger = logging.get_logger(__name__)
  39. class T5LayerNorm(nn.Module):
  40. def __init__(self, hidden_size, eps=1e-6):
  41. """
  42. Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
  43. """
  44. super().__init__()
  45. self.weight = nn.Parameter(torch.ones(hidden_size))
  46. self.variance_epsilon = eps
  47. def forward(self, hidden_states):
  48. # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  49. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  50. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  51. # half-precision inputs is done in fp32
  52. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  53. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  54. # convert into half-precision if necessary
  55. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  56. hidden_states = hidden_states.to(self.weight.dtype)
  57. return self.weight * hidden_states
# Optional fast path: when NVIDIA apex is importable, rebind the module-level name `T5LayerNorm` to apex's
# FusedRMSNorm. Code in this module that looks up `T5LayerNorm` at call time then gets the fused class.
try:
    from apex.normalization import FusedRMSNorm

    T5LayerNorm = FusedRMSNorm
    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
    # apex is not installed: keep the pure-PyTorch T5LayerNorm class defined above
    pass
except Exception:
    # apex is present but importing it raised something other than ImportError: keep the pure-PyTorch
    # T5LayerNorm and only warn, so loading this module is not blocked (the exception itself is not logged)
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
  67. class T5DenseActDense(nn.Module):
  68. def __init__(self, config: T5Config):
  69. super().__init__()
  70. self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
  71. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  72. self.dropout = nn.Dropout(config.dropout_rate)
  73. self.act = ACT2FN[config.dense_act_fn]
  74. def forward(self, hidden_states):
  75. hidden_states = self.wi(hidden_states)
  76. hidden_states = self.act(hidden_states)
  77. hidden_states = self.dropout(hidden_states)
  78. if (
  79. isinstance(self.wo.weight, torch.Tensor)
  80. and hidden_states.dtype != self.wo.weight.dtype
  81. and self.wo.weight.dtype != torch.int8
  82. ):
  83. hidden_states = hidden_states.to(self.wo.weight.dtype)
  84. hidden_states = self.wo(hidden_states)
  85. return hidden_states
  86. class T5DenseGatedActDense(nn.Module):
  87. def __init__(self, config: T5Config):
  88. super().__init__()
  89. self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
  90. self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
  91. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  92. self.dropout = nn.Dropout(config.dropout_rate)
  93. self.act = ACT2FN[config.dense_act_fn]
  94. def forward(self, hidden_states):
  95. hidden_gelu = self.act(self.wi_0(hidden_states))
  96. hidden_linear = self.wi_1(hidden_states)
  97. hidden_states = hidden_gelu * hidden_linear
  98. hidden_states = self.dropout(hidden_states)
  99. # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
  100. # See https://github.com/huggingface/transformers/issues/20287
  101. # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
  102. if (
  103. isinstance(self.wo.weight, torch.Tensor)
  104. and hidden_states.dtype != self.wo.weight.dtype
  105. and self.wo.weight.dtype != torch.int8
  106. ):
  107. hidden_states = hidden_states.to(self.wo.weight.dtype)
  108. hidden_states = self.wo(hidden_states)
  109. return hidden_states
  110. class T5LayerFF(nn.Module):
  111. def __init__(self, config: T5Config):
  112. super().__init__()
  113. if config.is_gated_act:
  114. self.DenseReluDense = T5DenseGatedActDense(config)
  115. else:
  116. self.DenseReluDense = T5DenseActDense(config)
  117. self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  118. self.dropout = nn.Dropout(config.dropout_rate)
  119. def forward(self, hidden_states):
  120. forwarded_states = self.layer_norm(hidden_states)
  121. forwarded_states = self.DenseReluDense(forwarded_states)
  122. hidden_states = hidden_states + self.dropout(forwarded_states)
  123. return hidden_states
  124. class T5Attention(nn.Module):
  125. def __init__(
  126. self,
  127. config: T5Config,
  128. has_relative_attention_bias=False,
  129. layer_idx: int | None = None,
  130. ):
  131. super().__init__()
  132. self.is_decoder = config.is_decoder
  133. self.has_relative_attention_bias = has_relative_attention_bias
  134. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  135. self.relative_attention_max_distance = config.relative_attention_max_distance
  136. self.d_model = config.d_model
  137. self.key_value_proj_dim = config.d_kv
  138. self.n_heads = config.num_heads
  139. self.dropout = config.dropout_rate
  140. self.inner_dim = self.n_heads * self.key_value_proj_dim
  141. self.layer_idx = layer_idx
  142. if layer_idx is None and self.is_decoder:
  143. logger.warning_once(
  144. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  145. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  146. "when creating this class."
  147. )
  148. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  149. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  150. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  151. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  152. if self.has_relative_attention_bias:
  153. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  154. self.gradient_checkpointing = False
  155. @staticmethod
  156. def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  157. """
  158. Adapted from Mesh Tensorflow:
  159. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  160. Translate relative position to a bucket number for relative attention. The relative position is defined as
  161. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  162. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  163. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  164. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  165. This should allow for more graceful generalization to longer sequences than the model has been trained on
  166. Args:
  167. relative_position: an int32 Tensor
  168. bidirectional: a boolean - whether the attention is bidirectional
  169. num_buckets: an integer
  170. max_distance: an integer
  171. Returns:
  172. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  173. """
  174. relative_buckets = 0
  175. if bidirectional:
  176. num_buckets //= 2
  177. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  178. relative_position = torch.abs(relative_position)
  179. else:
  180. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  181. # now relative_position is in the range [0, inf)
  182. # half of the buckets are for exact increments in positions
  183. max_exact = num_buckets // 2
  184. is_small = relative_position < max_exact
  185. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  186. relative_position_if_large = max_exact + (
  187. torch.log(relative_position.float() / max_exact)
  188. / math.log(max_distance / max_exact)
  189. * (num_buckets - max_exact)
  190. ).to(torch.long)
  191. relative_position_if_large = torch.min(
  192. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  193. )
  194. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  195. return relative_buckets
  196. def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
  197. """Compute binned relative position bias"""
  198. if device is None:
  199. device = self.relative_attention_bias.weight.device
  200. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
  201. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  202. relative_position = memory_position - context_position # shape (query_length, key_length)
  203. relative_position_bucket = self._relative_position_bucket(
  204. relative_position, # shape (query_length, key_length)
  205. bidirectional=(not self.is_decoder),
  206. num_buckets=self.relative_attention_num_buckets,
  207. max_distance=self.relative_attention_max_distance,
  208. )
  209. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  210. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  211. return values
  212. def forward(
  213. self,
  214. hidden_states,
  215. mask=None,
  216. key_value_states=None,
  217. position_bias=None,
  218. past_key_values=None,
  219. output_attentions=False,
  220. **kwargs,
  221. ):
  222. """
  223. Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
  224. """
  225. # Input is (batch_size, seq_length, dim)
  226. # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
  227. input_shape = hidden_states.shape[:-1]
  228. hidden_shape = (*input_shape, -1, self.key_value_proj_dim)
  229. past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
  230. # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
  231. past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens
  232. # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
  233. is_cross_attention = key_value_states is not None
  234. query_states = self.q(hidden_states).view(hidden_shape).transpose(1, 2)
  235. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  236. is_updated = False
  237. if isinstance(past_key_values, EncoderDecoderCache):
  238. is_updated = past_key_values.is_updated.get(self.layer_idx)
  239. if is_cross_attention:
  240. # after the first generated id, we can subsequently re-use all key/value_states from cache
  241. curr_past_key_values = past_key_values.cross_attention_cache
  242. else:
  243. curr_past_key_values = past_key_values.self_attention_cache
  244. else:
  245. curr_past_key_values = past_key_values
  246. current_states = key_value_states if is_cross_attention else hidden_states
  247. if is_cross_attention and past_key_values is not None and is_updated:
  248. # reuse k,v, cross_attentions
  249. key_states = curr_past_key_values.layers[self.layer_idx].keys
  250. value_states = curr_past_key_values.layers[self.layer_idx].values
  251. else:
  252. kv_shape = (*current_states.shape[:-1], -1, self.key_value_proj_dim)
  253. key_states = self.k(current_states).view(kv_shape).transpose(1, 2)
  254. value_states = self.v(current_states).view(kv_shape).transpose(1, 2)
  255. if past_key_values is not None:
  256. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  257. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  258. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  259. past_key_values.is_updated[self.layer_idx] = True
  260. # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  261. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  262. if position_bias is None:
  263. key_length = key_states.shape[-2]
  264. if not self.has_relative_attention_bias:
  265. position_bias = torch.zeros(
  266. (1, query_states.shape[1], input_shape[1], key_length), device=scores.device, dtype=scores.dtype
  267. )
  268. if self.gradient_checkpointing and self.training:
  269. position_bias.requires_grad = True
  270. else:
  271. position_bias = self.compute_bias(
  272. input_shape[1], key_length, device=scores.device, past_seen_tokens=past_seen_tokens
  273. )
  274. if mask is not None:
  275. causal_mask = mask[:, :, :, : key_states.shape[-2]]
  276. position_bias = position_bias + causal_mask
  277. position_bias_masked = position_bias
  278. scores += position_bias_masked
  279. # (batch_size, n_heads, seq_length, key_length)
  280. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  281. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  282. attn_output = torch.matmul(attn_weights, value_states)
  283. attn_output = attn_output.transpose(1, 2).contiguous()
  284. attn_output = attn_output.reshape(*input_shape, -1)
  285. attn_output = self.o(attn_output)
  286. outputs = (attn_output, position_bias)
  287. if output_attentions:
  288. outputs = outputs + (attn_weights,)
  289. return outputs
  290. class T5LayerSelfAttention(nn.Module):
  291. def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
  292. super().__init__()
  293. self.SelfAttention = T5Attention(
  294. config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
  295. )
  296. self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  297. self.dropout = nn.Dropout(config.dropout_rate)
  298. def forward(
  299. self,
  300. hidden_states,
  301. attention_mask=None,
  302. position_bias=None,
  303. past_key_values=None,
  304. use_cache=False,
  305. output_attentions=False,
  306. **kwargs,
  307. ):
  308. normed_hidden_states = self.layer_norm(hidden_states)
  309. attention_output = self.SelfAttention(
  310. normed_hidden_states,
  311. mask=attention_mask,
  312. position_bias=position_bias,
  313. past_key_values=past_key_values,
  314. use_cache=use_cache,
  315. output_attentions=output_attentions,
  316. )
  317. hidden_states = hidden_states + self.dropout(attention_output[0])
  318. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  319. return outputs
  320. class T5LayerCrossAttention(nn.Module):
  321. def __init__(self, config, layer_idx: int | None = None):
  322. super().__init__()
  323. self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
  324. self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  325. self.dropout = nn.Dropout(config.dropout_rate)
  326. def forward(
  327. self,
  328. hidden_states,
  329. key_value_states,
  330. attention_mask=None,
  331. position_bias=None,
  332. past_key_values=None,
  333. output_attentions=False,
  334. **kwargs,
  335. ):
  336. normed_hidden_states = self.layer_norm(hidden_states)
  337. attention_output = self.EncDecAttention(
  338. normed_hidden_states,
  339. mask=attention_mask,
  340. key_value_states=key_value_states,
  341. position_bias=position_bias,
  342. past_key_values=past_key_values,
  343. output_attentions=output_attentions,
  344. )
  345. layer_output = hidden_states + self.dropout(attention_output[0])
  346. outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
  347. return outputs
  348. class T5Block(GradientCheckpointingLayer):
  349. def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
  350. super().__init__()
  351. self.is_decoder = config.is_decoder
  352. self.layer = nn.ModuleList()
  353. self.layer.append(
  354. T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
  355. )
  356. if self.is_decoder:
  357. self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx))
  358. self.layer.append(T5LayerFF(config))
  359. def forward(
  360. self,
  361. hidden_states,
  362. attention_mask=None,
  363. position_bias=None,
  364. encoder_hidden_states=None,
  365. encoder_attention_mask=None,
  366. encoder_decoder_position_bias=None,
  367. past_key_values=None,
  368. use_cache=False,
  369. output_attentions=False,
  370. return_dict=True,
  371. **kwargs,
  372. ):
  373. self_attention_outputs = self.layer[0](
  374. hidden_states,
  375. attention_mask=attention_mask,
  376. position_bias=position_bias,
  377. past_key_values=past_key_values,
  378. use_cache=use_cache,
  379. output_attentions=output_attentions,
  380. )
  381. hidden_states = self_attention_outputs[0]
  382. attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
  383. # clamp inf values to enable fp16 training
  384. if hidden_states.dtype == torch.float16:
  385. clamp_value = torch.where(
  386. torch.isinf(hidden_states).any(),
  387. torch.finfo(hidden_states.dtype).max - 1000,
  388. torch.finfo(hidden_states.dtype).max,
  389. )
  390. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  391. do_cross_attention = self.is_decoder and encoder_hidden_states is not None
  392. if do_cross_attention:
  393. cross_attention_outputs = self.layer[1](
  394. hidden_states,
  395. key_value_states=encoder_hidden_states,
  396. attention_mask=encoder_attention_mask,
  397. position_bias=encoder_decoder_position_bias,
  398. past_key_values=past_key_values,
  399. output_attentions=output_attentions,
  400. )
  401. hidden_states = cross_attention_outputs[0]
  402. # clamp inf values to enable fp16 training
  403. if hidden_states.dtype == torch.float16:
  404. clamp_value = torch.where(
  405. torch.isinf(hidden_states).any(),
  406. torch.finfo(hidden_states.dtype).max - 1000,
  407. torch.finfo(hidden_states.dtype).max,
  408. )
  409. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  410. # Keep cross-attention outputs and relative position weights
  411. attention_outputs = attention_outputs + cross_attention_outputs[1:]
  412. # Apply Feed Forward layer
  413. hidden_states = self.layer[-1](hidden_states)
  414. # clamp inf values to enable fp16 training
  415. if hidden_states.dtype == torch.float16:
  416. clamp_value = torch.where(
  417. torch.isinf(hidden_states).any(),
  418. torch.finfo(hidden_states.dtype).max - 1000,
  419. torch.finfo(hidden_states.dtype).max,
  420. )
  421. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  422. outputs = (hidden_states,)
  423. return (
  424. outputs + attention_outputs
  425. ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  426. class T5ClassificationHead(nn.Module):
  427. """Head for sentence-level classification tasks."""
  428. def __init__(self, config: T5Config):
  429. super().__init__()
  430. self.dense = nn.Linear(config.d_model, config.d_model)
  431. self.dropout = nn.Dropout(p=config.classifier_dropout)
  432. self.out_proj = nn.Linear(config.d_model, config.num_labels)
  433. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  434. hidden_states = self.dropout(hidden_states)
  435. hidden_states = self.dense(hidden_states)
  436. hidden_states = torch.tanh(hidden_states)
  437. hidden_states = self.dropout(hidden_states)
  438. hidden_states = self.out_proj(hidden_states)
  439. return hidden_states
  440. @auto_docstring
  441. class T5PreTrainedModel(PreTrainedModel):
  442. config: T5Config
  443. base_model_prefix = "transformer"
  444. supports_gradient_checkpointing = True
  445. _can_compile_fullgraph = True
  446. _no_split_modules = ["T5Block"]
  447. _keep_in_fp32_modules = ["wo"]
  448. @property
  449. def dummy_inputs(self):
  450. input_ids = torch.tensor(DUMMY_INPUTS)
  451. input_mask = torch.tensor(DUMMY_MASK)
  452. dummy_inputs = {
  453. "decoder_input_ids": input_ids,
  454. "input_ids": input_ids,
  455. "decoder_attention_mask": input_mask,
  456. }
  457. return dummy_inputs
  458. @torch.no_grad()
  459. def _init_weights(self, module):
  460. """Initialize the weights"""
  461. factor = self.config.initializer_factor # Used for testing weights initialization
  462. if isinstance(module, T5LayerNorm):
  463. init.constant_(module.weight, factor * 1.0)
  464. elif isinstance(
  465. module,
  466. (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering),
  467. ):
  468. init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
  469. if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
  470. init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0)
  471. if hasattr(module, "qa_outputs"):
  472. init.normal_(module.qa_outputs.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  473. init.zeros_(module.qa_outputs.bias)
  474. elif isinstance(module, T5ForTokenClassification):
  475. if hasattr(module, "classifier"):
  476. init.normal_(module.classifier.weight, mean=0.0, std=factor * 1.0)
  477. init.zeros_(module.classifier.bias)
  478. elif isinstance(module, T5ClassificationHead):
  479. init.normal_(module.dense.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  480. if hasattr(module.dense, "bias") and module.dense.bias is not None:
  481. init.zeros_(module.dense.bias)
  482. init.normal_(module.out_proj.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  483. if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
  484. init.zeros_(module.out_proj.bias)
  485. elif isinstance(module, T5DenseActDense):
  486. init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  487. if hasattr(module.wi, "bias") and module.wi.bias is not None:
  488. init.zeros_(module.wi.bias)
  489. init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  490. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  491. init.zeros_(module.wo.bias)
  492. elif isinstance(module, T5DenseGatedActDense):
  493. init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  494. if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
  495. init.zeros_(module.wi_0.bias)
  496. init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  497. if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
  498. init.zeros_(module.wi_1.bias)
  499. init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  500. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  501. init.zeros_(module.wo.bias)
  502. elif isinstance(module, T5Attention):
  503. d_model = self.config.d_model
  504. key_value_proj_dim = self.config.d_kv
  505. n_heads = self.config.num_heads
  506. init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
  507. init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
  508. init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
  509. init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
  510. if module.has_relative_attention_bias:
  511. init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5))
  512. def _shift_right(self, input_ids):
  513. decoder_start_token_id = self.config.decoder_start_token_id
  514. pad_token_id = self.config.pad_token_id
  515. if decoder_start_token_id is None:
  516. raise ValueError(
  517. "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
  518. "See T5 docs for more information."
  519. )
  520. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  521. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  522. shifted_input_ids[..., 0] = decoder_start_token_id
  523. if pad_token_id is None:
  524. raise ValueError("self.model.config.pad_token_id has to be defined.")
  525. # replace possible -100 values in labels by `pad_token_id`
  526. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  527. return shifted_input_ids
  528. class T5Stack(T5PreTrainedModel):
  529. def __init__(self, config):
  530. super().__init__(config)
  531. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
  532. self.is_decoder = config.is_decoder
  533. self.block = nn.ModuleList(
  534. [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
  535. )
  536. self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  537. self.dropout = nn.Dropout(config.dropout_rate)
  538. # Initialize weights and apply final processing
  539. self.post_init()
  540. self.gradient_checkpointing = False
  541. def set_input_embeddings(self, new_embeddings):
  542. self.embed_tokens = new_embeddings
  543. def forward(
  544. self,
  545. input_ids=None,
  546. attention_mask=None,
  547. encoder_hidden_states=None,
  548. encoder_attention_mask=None,
  549. inputs_embeds=None,
  550. past_key_values=None,
  551. use_cache=None,
  552. output_attentions=None,
  553. output_hidden_states=None,
  554. return_dict=None,
  555. **kwargs,
  556. ):
  557. use_cache = use_cache if use_cache is not None else self.config.use_cache
  558. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  559. output_hidden_states = (
  560. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  561. )
  562. return_dict = return_dict if return_dict is not None else self.config.return_dict
  563. if input_ids is not None and inputs_embeds is not None:
  564. err_msg_prefix = "decoder_" if self.is_decoder else ""
  565. raise ValueError(
  566. f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
  567. )
  568. elif input_ids is not None:
  569. input_shape = input_ids.size()
  570. input_ids = input_ids.view(-1, input_shape[-1])
  571. elif inputs_embeds is not None:
  572. input_shape = inputs_embeds.size()[:-1]
  573. else:
  574. err_msg_prefix = "decoder_" if self.is_decoder else ""
  575. raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
  576. if self.gradient_checkpointing and self.training:
  577. if use_cache:
  578. logger.warning_once(
  579. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  580. )
  581. use_cache = False
  582. if inputs_embeds is None:
  583. if self.embed_tokens is None:
  584. raise ValueError("You have to initialize the model with valid token embeddings")
  585. inputs_embeds = self.embed_tokens(input_ids)
  586. batch_size, seq_length = input_shape
  587. if use_cache is True:
  588. if not self.is_decoder:
  589. raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
  590. if self.is_decoder:
  591. if use_cache and past_key_values is None:
  592. if self.config.is_encoder_decoder:
  593. past_key_values = EncoderDecoderCache(
  594. DynamicCache(config=self.config), DynamicCache(config=self.config)
  595. )
  596. else:
  597. past_key_values = DynamicCache(config=self.config)
  598. elif not self.is_decoder:
  599. # do not pass cache object down the line for encoder stack
  600. # it messes indexing later in decoder-stack because cache object is modified in-place
  601. past_key_values = None
  602. if self.config.is_decoder:
  603. attention_mask = create_causal_mask(
  604. config=self.config,
  605. inputs_embeds=inputs_embeds,
  606. attention_mask=attention_mask,
  607. past_key_values=past_key_values.self_attention_cache
  608. if isinstance(past_key_values, EncoderDecoderCache)
  609. else past_key_values,
  610. )
  611. else:
  612. attention_mask = create_bidirectional_mask(
  613. config=self.config,
  614. inputs_embeds=inputs_embeds,
  615. attention_mask=attention_mask,
  616. )
  617. encoder_extended_attention_mask = None
  618. if self.is_decoder and encoder_hidden_states is not None:
  619. encoder_extended_attention_mask = create_bidirectional_mask(
  620. config=self.config,
  621. inputs_embeds=inputs_embeds,
  622. attention_mask=encoder_attention_mask,
  623. encoder_hidden_states=encoder_hidden_states,
  624. )
  625. all_hidden_states = () if output_hidden_states else None
  626. all_attentions = () if output_attentions else None
  627. all_cross_attentions = () if (output_attentions and self.is_decoder) else None
  628. position_bias = None
  629. encoder_decoder_position_bias = None
  630. hidden_states = self.dropout(inputs_embeds)
  631. for layer_module in self.block:
  632. if output_hidden_states:
  633. all_hidden_states = all_hidden_states + (hidden_states,)
  634. layer_outputs = layer_module(
  635. hidden_states,
  636. attention_mask,
  637. position_bias,
  638. encoder_hidden_states,
  639. encoder_extended_attention_mask,
  640. encoder_decoder_position_bias, # as a positional argument for gradient checkpointing
  641. past_key_values=past_key_values,
  642. use_cache=use_cache,
  643. output_attentions=output_attentions,
  644. return_dict=return_dict,
  645. )
  646. hidden_states = layer_outputs[0]
  647. # We share the position biases between the layers - the first layer store them
  648. # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
  649. # (cross-attention position bias), (cross-attention weights)
  650. position_bias = layer_outputs[1]
  651. if self.is_decoder and encoder_hidden_states is not None:
  652. encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
  653. if output_attentions:
  654. all_attentions = all_attentions + (layer_outputs[2],)
  655. if self.is_decoder:
  656. all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
  657. hidden_states = self.final_layer_norm(hidden_states)
  658. hidden_states = self.dropout(hidden_states)
  659. # Add last layer
  660. if output_hidden_states:
  661. all_hidden_states = all_hidden_states + (hidden_states,)
  662. if not return_dict:
  663. return tuple(
  664. v
  665. for v in [
  666. hidden_states,
  667. past_key_values,
  668. all_hidden_states,
  669. all_attentions,
  670. all_cross_attentions,
  671. ]
  672. if v is not None
  673. )
  674. return BaseModelOutputWithPastAndCrossAttentions(
  675. last_hidden_state=hidden_states,
  676. past_key_values=past_key_values,
  677. hidden_states=all_hidden_states,
  678. attentions=all_attentions,
  679. cross_attentions=all_cross_attentions,
  680. )
  681. @auto_docstring
  682. class T5Model(T5PreTrainedModel):
  683. _keys_to_ignore_on_load_unexpected = [
  684. "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  685. ]
  686. _tied_weights_keys = {
  687. "encoder.embed_tokens.weight": "shared.weight",
  688. "decoder.embed_tokens.weight": "shared.weight",
  689. }
  690. def __init__(self, config: T5Config):
  691. super().__init__(config)
  692. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  693. encoder_config = copy.deepcopy(config)
  694. encoder_config.is_decoder = False
  695. encoder_config.use_cache = False
  696. self.encoder = T5Stack(encoder_config)
  697. decoder_config = copy.deepcopy(config)
  698. decoder_config.is_decoder = True
  699. decoder_config.num_layers = config.num_decoder_layers
  700. self.decoder = T5Stack(decoder_config)
  701. # Initialize weights and apply final processing
  702. self.post_init()
  703. def get_input_embeddings(self):
  704. return self.shared
  705. def set_input_embeddings(self, new_embeddings):
  706. self.shared = new_embeddings
  707. self.encoder.set_input_embeddings(new_embeddings)
  708. self.decoder.set_input_embeddings(new_embeddings)
  709. @auto_docstring
  710. def forward(
  711. self,
  712. input_ids: torch.LongTensor | None = None,
  713. attention_mask: torch.FloatTensor | None = None,
  714. decoder_input_ids: torch.LongTensor | None = None,
  715. decoder_attention_mask: torch.BoolTensor | None = None,
  716. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  717. past_key_values: Cache | None = None,
  718. inputs_embeds: torch.Tensor | None = None,
  719. decoder_inputs_embeds: torch.Tensor | None = None,
  720. use_cache: bool | None = None,
  721. output_attentions: bool | None = None,
  722. output_hidden_states: bool | None = None,
  723. return_dict: bool | None = None,
  724. **kwargs,
  725. ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput:
  726. r"""
  727. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  728. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  729. should be able to pad the inputs on both the right and the left.
  730. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  731. [`PreTrainedTokenizer.__call__`] for detail.
  732. [What are input IDs?](../glossary#input-ids)
  733. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  734. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  735. Indices of decoder input sequence tokens in the vocabulary.
  736. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  737. [`PreTrainedTokenizer.__call__`] for details.
  738. [What are decoder input IDs?](../glossary#decoder-input-ids)
  739. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  740. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  741. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  742. Training](./t5#training).
  743. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  744. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  745. be used by default.
  746. Example:
  747. ```python
  748. >>> from transformers import AutoTokenizer, T5Model
  749. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  750. >>> model = T5Model.from_pretrained("google-t5/t5-small")
  751. >>> input_ids = tokenizer(
  752. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  753. ... ).input_ids # Batch size 1
  754. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  755. >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
  756. >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
  757. >>> decoder_input_ids = model._shift_right(decoder_input_ids)
  758. >>> # forward pass
  759. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  760. >>> last_hidden_states = outputs.last_hidden_state
  761. ```"""
  762. use_cache = use_cache if use_cache is not None else self.config.use_cache
  763. return_dict = return_dict if return_dict is not None else self.config.return_dict
  764. # Encode if needed (training, first prediction pass)
  765. if encoder_outputs is None:
  766. encoder_outputs = self.encoder(
  767. input_ids=input_ids,
  768. attention_mask=attention_mask,
  769. inputs_embeds=inputs_embeds,
  770. output_attentions=output_attentions,
  771. output_hidden_states=output_hidden_states,
  772. return_dict=return_dict,
  773. )
  774. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  775. encoder_outputs = BaseModelOutput(
  776. last_hidden_state=encoder_outputs[0],
  777. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  778. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  779. )
  780. hidden_states = encoder_outputs[0]
  781. # Decode
  782. decoder_outputs = self.decoder(
  783. input_ids=decoder_input_ids,
  784. attention_mask=decoder_attention_mask,
  785. inputs_embeds=decoder_inputs_embeds,
  786. past_key_values=past_key_values,
  787. encoder_hidden_states=hidden_states,
  788. encoder_attention_mask=attention_mask,
  789. use_cache=use_cache,
  790. output_attentions=output_attentions,
  791. output_hidden_states=output_hidden_states,
  792. return_dict=return_dict,
  793. )
  794. if not return_dict:
  795. return decoder_outputs + encoder_outputs
  796. return Seq2SeqModelOutput(
  797. last_hidden_state=decoder_outputs.last_hidden_state,
  798. past_key_values=decoder_outputs.past_key_values,
  799. decoder_hidden_states=decoder_outputs.hidden_states,
  800. decoder_attentions=decoder_outputs.attentions,
  801. cross_attentions=decoder_outputs.cross_attentions,
  802. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  803. encoder_hidden_states=encoder_outputs.hidden_states,
  804. encoder_attentions=encoder_outputs.attentions,
  805. )
  806. @auto_docstring(
  807. custom_intro="""
  808. T5 Model with a `language modeling` head on top.
  809. """
  810. )
  811. class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
  812. _keys_to_ignore_on_load_unexpected = [
  813. "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  814. ]
  815. _tied_weights_keys = {
  816. "lm_head.weight": "shared.weight",
  817. "encoder.embed_tokens.weight": "shared.weight",
  818. "decoder.embed_tokens.weight": "shared.weight",
  819. }
  820. def __init__(self, config: T5Config):
  821. super().__init__(config)
  822. self.model_dim = config.d_model
  823. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  824. encoder_config = copy.deepcopy(config)
  825. encoder_config.is_decoder = False
  826. encoder_config.use_cache = False
  827. self.encoder = T5Stack(encoder_config)
  828. decoder_config = copy.deepcopy(config)
  829. decoder_config.is_decoder = True
  830. decoder_config.num_layers = config.num_decoder_layers
  831. self.decoder = T5Stack(decoder_config)
  832. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  833. # Initialize weights and apply final processing
  834. self.post_init()
  835. def get_input_embeddings(self):
  836. return self.shared
  837. def set_input_embeddings(self, new_embeddings):
  838. self.shared = new_embeddings
  839. self.encoder.set_input_embeddings(new_embeddings)
  840. self.decoder.set_input_embeddings(new_embeddings)
  841. @auto_docstring
  842. def forward(
  843. self,
  844. input_ids: torch.LongTensor | None = None,
  845. attention_mask: torch.FloatTensor | None = None,
  846. decoder_input_ids: torch.LongTensor | None = None,
  847. decoder_attention_mask: torch.BoolTensor | None = None,
  848. encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
  849. past_key_values: Cache | None = None,
  850. inputs_embeds: torch.FloatTensor | None = None,
  851. decoder_inputs_embeds: torch.FloatTensor | None = None,
  852. labels: torch.LongTensor | None = None,
  853. use_cache: bool | None = None,
  854. output_attentions: bool | None = None,
  855. output_hidden_states: bool | None = None,
  856. return_dict: bool | None = None,
  857. **kwargs,
  858. ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
  859. r"""
  860. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  861. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  862. should be able to pad the inputs on both the right and the left.
  863. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  864. [`PreTrainedTokenizer.__call__`] for detail.
  865. [What are input IDs?](../glossary#input-ids)
  866. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  867. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  868. Indices of decoder input sequence tokens in the vocabulary.
  869. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  870. [`PreTrainedTokenizer.__call__`] for details.
  871. [What are decoder input IDs?](../glossary#decoder-input-ids)
  872. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  873. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  874. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  875. Training](./t5#training).
  876. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  877. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  878. be used by default.
  879. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  880. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  881. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  882. labels in `[0, ..., config.vocab_size]`
  883. Examples:
  884. ```python
  885. >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
  886. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  887. >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
  888. >>> # training
  889. >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  890. >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  891. >>> outputs = model(input_ids=input_ids, labels=labels)
  892. >>> loss = outputs.loss
  893. >>> logits = outputs.logits
  894. >>> # inference
  895. >>> input_ids = tokenizer(
  896. ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
  897. ... ).input_ids # Batch size 1
  898. >>> outputs = model.generate(input_ids)
  899. >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  900. >>> # studies have shown that owning a dog is good for you.
  901. ```"""
  902. use_cache = use_cache if use_cache is not None else self.config.use_cache
  903. return_dict = return_dict if return_dict is not None else self.config.return_dict
  904. # Encode if needed (training, first prediction pass)
  905. if encoder_outputs is None:
  906. # Convert encoder inputs in embeddings if needed
  907. encoder_outputs = self.encoder(
  908. input_ids=input_ids,
  909. attention_mask=attention_mask,
  910. inputs_embeds=inputs_embeds,
  911. output_attentions=output_attentions,
  912. output_hidden_states=output_hidden_states,
  913. return_dict=return_dict,
  914. )
  915. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  916. encoder_outputs = BaseModelOutput(
  917. last_hidden_state=encoder_outputs[0],
  918. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  919. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  920. )
  921. hidden_states = encoder_outputs[0]
  922. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  923. # get decoder inputs from shifting lm labels to the right
  924. decoder_input_ids = self._shift_right(labels)
  925. # Decode
  926. decoder_outputs = self.decoder(
  927. input_ids=decoder_input_ids,
  928. attention_mask=decoder_attention_mask,
  929. inputs_embeds=decoder_inputs_embeds,
  930. past_key_values=past_key_values,
  931. encoder_hidden_states=hidden_states,
  932. encoder_attention_mask=attention_mask,
  933. use_cache=use_cache,
  934. output_attentions=output_attentions,
  935. output_hidden_states=output_hidden_states,
  936. return_dict=return_dict,
  937. )
  938. sequence_output = decoder_outputs[0]
  939. if self.config.scale_decoder_outputs:
  940. sequence_output = sequence_output * (self.model_dim**-0.5)
  941. lm_logits = self.lm_head(sequence_output)
  942. loss = None
  943. if labels is not None:
  944. loss_fct = CrossEntropyLoss(ignore_index=-100)
  945. # move labels to correct device to enable PP
  946. labels = labels.to(lm_logits.device)
  947. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  948. if not return_dict:
  949. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  950. return ((loss,) + output) if loss is not None else output
  951. return Seq2SeqLMOutput(
  952. loss=loss,
  953. logits=lm_logits,
  954. past_key_values=decoder_outputs.past_key_values,
  955. decoder_hidden_states=decoder_outputs.hidden_states,
  956. decoder_attentions=decoder_outputs.attentions,
  957. cross_attentions=decoder_outputs.cross_attentions,
  958. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  959. encoder_hidden_states=encoder_outputs.hidden_states,
  960. encoder_attentions=encoder_outputs.attentions,
  961. )
  962. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  963. return self._shift_right(labels)
  964. @auto_docstring
  965. class T5EncoderModel(T5PreTrainedModel):
  966. _tied_weights_keys = {"encoder.embed_tokens.weight": "shared.weight"}
  967. _keys_to_ignore_on_load_unexpected = [r"decoder"]
  968. def __init__(self, config: T5Config):
  969. super().__init__(config)
  970. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  971. encoder_config = config
  972. encoder_config.use_cache = False
  973. encoder_config.is_encoder_decoder = False
  974. self.encoder = T5Stack(encoder_config)
  975. # Initialize weights and apply final processing
  976. self.post_init()
  977. def get_input_embeddings(self):
  978. return self.shared
  979. def set_input_embeddings(self, new_embeddings):
  980. self.shared = new_embeddings
  981. self.encoder.set_input_embeddings(new_embeddings)
  982. @auto_docstring
  983. def forward(
  984. self,
  985. input_ids: torch.LongTensor | None = None,
  986. attention_mask: torch.FloatTensor | None = None,
  987. inputs_embeds: torch.FloatTensor | None = None,
  988. output_attentions: bool | None = None,
  989. output_hidden_states: bool | None = None,
  990. return_dict: bool | None = None,
  991. **kwargs,
  992. ) -> tuple[torch.FloatTensor] | BaseModelOutput:
  993. r"""
  994. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  995. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  996. should be able to pad the inputs on both the right and the left.
  997. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  998. [`PreTrainedTokenizer.__call__`] for detail.
  999. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1000. Example:
  1001. ```python
  1002. >>> from transformers import AutoTokenizer, T5EncoderModel
  1003. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  1004. >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
  1005. >>> input_ids = tokenizer(
  1006. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1007. ... ).input_ids # Batch size 1
  1008. >>> outputs = model(input_ids=input_ids)
  1009. >>> last_hidden_states = outputs.last_hidden_state
  1010. ```"""
  1011. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1012. encoder_outputs = self.encoder(
  1013. input_ids=input_ids,
  1014. attention_mask=attention_mask,
  1015. inputs_embeds=inputs_embeds,
  1016. output_attentions=output_attentions,
  1017. output_hidden_states=output_hidden_states,
  1018. return_dict=return_dict,
  1019. )
  1020. return encoder_outputs
  1021. @auto_docstring(
  1022. custom_intro="""
  1023. T5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  1024. tasks.
  1025. """
  1026. )
  1027. class T5ForSequenceClassification(T5PreTrainedModel):
  1028. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1029. def __init__(self, config: T5Config):
  1030. super().__init__(config)
  1031. self.transformer = T5Model(config)
  1032. self.classification_head = T5ClassificationHead(config)
  1033. # Initialize weights and apply final processing
  1034. self.post_init()
  1035. @auto_docstring
  1036. def forward(
  1037. self,
  1038. input_ids: torch.LongTensor | None = None,
  1039. attention_mask: torch.Tensor | None = None,
  1040. decoder_input_ids: torch.LongTensor | None = None,
  1041. decoder_attention_mask: torch.LongTensor | None = None,
  1042. encoder_outputs: list[torch.FloatTensor] | None = None,
  1043. inputs_embeds: torch.FloatTensor | None = None,
  1044. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1045. labels: torch.LongTensor | None = None,
  1046. use_cache: bool | None = None,
  1047. output_attentions: bool | None = None,
  1048. output_hidden_states: bool | None = None,
  1049. return_dict: bool | None = None,
  1050. **kwargs,
  1051. ) -> tuple | Seq2SeqSequenceClassifierOutput:
  1052. r"""
  1053. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1054. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1055. should be able to pad the inputs on both the right and the left.
  1056. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1057. [`PreTrainedTokenizer.__call__`] for detail.
  1058. [What are input IDs?](../glossary#input-ids)
  1059. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1060. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1061. Indices of decoder input sequence tokens in the vocabulary.
  1062. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1063. [`PreTrainedTokenizer.__call__`] for details.
  1064. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1065. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1066. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1067. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  1068. Training](./t5#training).
  1069. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1070. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1071. be used by default.
  1072. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1073. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1074. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1075. """
  1076. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1077. if labels is not None:
  1078. use_cache = False
  1079. if input_ids is None and inputs_embeds is not None:
  1080. raise NotImplementedError(
  1081. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1082. )
  1083. # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
  1084. # decoder_input_ids from input_ids if no decoder_input_ids are provided
  1085. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1086. if input_ids is None:
  1087. raise ValueError(
  1088. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1089. "passed, `input_ids` cannot be `None`. Please pass either "
  1090. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1091. )
  1092. decoder_input_ids = self._shift_right(input_ids)
  1093. outputs = self.transformer(
  1094. input_ids,
  1095. attention_mask=attention_mask,
  1096. decoder_input_ids=decoder_input_ids,
  1097. decoder_attention_mask=decoder_attention_mask,
  1098. encoder_outputs=encoder_outputs,
  1099. inputs_embeds=inputs_embeds,
  1100. decoder_inputs_embeds=decoder_inputs_embeds,
  1101. use_cache=use_cache,
  1102. output_attentions=output_attentions,
  1103. output_hidden_states=output_hidden_states,
  1104. return_dict=return_dict,
  1105. )
  1106. sequence_output = outputs[0]
  1107. eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
  1108. torch_compilable_check(
  1109. torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
  1110. "All examples must have the same number of <eos> tokens.",
  1111. )
  1112. batch_size, _, hidden_size = sequence_output.shape
  1113. sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
  1114. logits = self.classification_head(sentence_representation)
  1115. loss = None
  1116. if labels is not None:
  1117. labels = labels.to(logits.device)
  1118. if self.config.problem_type is None:
  1119. if self.config.num_labels == 1:
  1120. self.config.problem_type = "regression"
  1121. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1122. self.config.problem_type = "single_label_classification"
  1123. else:
  1124. self.config.problem_type = "multi_label_classification"
  1125. if self.config.problem_type == "regression":
  1126. loss_fct = MSELoss()
  1127. if self.config.num_labels == 1:
  1128. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1129. else:
  1130. loss = loss_fct(logits, labels)
  1131. elif self.config.problem_type == "single_label_classification":
  1132. loss_fct = CrossEntropyLoss()
  1133. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1134. elif self.config.problem_type == "multi_label_classification":
  1135. loss_fct = BCEWithLogitsLoss()
  1136. loss = loss_fct(logits, labels)
  1137. if not return_dict:
  1138. output = (logits,) + outputs[1:]
  1139. return ((loss,) + output) if loss is not None else output
  1140. return Seq2SeqSequenceClassifierOutput(
  1141. loss=loss,
  1142. logits=logits,
  1143. past_key_values=outputs.past_key_values,
  1144. decoder_hidden_states=outputs.decoder_hidden_states,
  1145. decoder_attentions=outputs.decoder_attentions,
  1146. cross_attentions=outputs.cross_attentions,
  1147. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1148. encoder_hidden_states=outputs.encoder_hidden_states,
  1149. encoder_attentions=outputs.encoder_attentions,
  1150. )
  1151. @auto_docstring
  1152. class T5ForTokenClassification(T5PreTrainedModel):
  1153. def __init__(self, config: T5Config):
  1154. super().__init__(config)
  1155. self.num_labels = config.num_labels
  1156. self.transformer = T5EncoderModel(config)
  1157. self.dropout = nn.Dropout(config.classifier_dropout)
  1158. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1159. # Initialize weights and apply final processing
  1160. self.post_init()
  1161. @auto_docstring
  1162. def forward(
  1163. self,
  1164. input_ids: torch.Tensor | None = None,
  1165. attention_mask: torch.Tensor | None = None,
  1166. inputs_embeds: torch.Tensor | None = None,
  1167. labels: torch.Tensor | None = None,
  1168. output_attentions: bool | None = None,
  1169. output_hidden_states: bool | None = None,
  1170. return_dict: bool | None = None,
  1171. **kwargs,
  1172. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  1173. r"""
  1174. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1175. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1176. should be able to pad the inputs on both the right and the left.
  1177. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1178. [`PreTrainedTokenizer.__call__`] for detail.
  1179. [What are input IDs?](../glossary#input-ids)
  1180. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1181. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1182. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1183. """
  1184. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1185. outputs = self.transformer(
  1186. input_ids,
  1187. attention_mask=attention_mask,
  1188. inputs_embeds=inputs_embeds,
  1189. output_attentions=output_attentions,
  1190. output_hidden_states=output_hidden_states,
  1191. return_dict=return_dict,
  1192. )
  1193. hidden_states = outputs[0]
  1194. hidden_states = self.dropout(hidden_states)
  1195. logits = self.classifier(hidden_states)
  1196. loss = None
  1197. if labels is not None:
  1198. loss_fct = CrossEntropyLoss()
  1199. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1200. if not return_dict:
  1201. output = (logits, outputs[2:-1])
  1202. return ((loss,) + output) if loss is not None else output
  1203. return TokenClassifierOutput(
  1204. loss=loss,
  1205. logits=logits,
  1206. hidden_states=outputs.hidden_states,
  1207. attentions=outputs.attentions,
  1208. )
  1209. @auto_docstring
  1210. class T5ForQuestionAnswering(T5PreTrainedModel):
  1211. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1212. _tied_weights_keys = {
  1213. "encoder.embed_tokens.weight": "shared.weight",
  1214. "decoder.embed_tokens.weight": "shared.weight",
  1215. }
  1216. def __init__(self, config: T5Config):
  1217. super().__init__(config)
  1218. self.model_dim = config.d_model
  1219. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1220. encoder_config = copy.deepcopy(config)
  1221. encoder_config.is_decoder = False
  1222. encoder_config.use_cache = False
  1223. self.encoder = T5Stack(encoder_config)
  1224. decoder_config = copy.deepcopy(config)
  1225. decoder_config.is_decoder = True
  1226. decoder_config.num_layers = config.num_decoder_layers
  1227. self.decoder = T5Stack(decoder_config)
  1228. self.num_labels = config.num_labels
  1229. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1230. # Initialize weights and apply final processing
  1231. self.post_init()
  1232. def get_input_embeddings(self):
  1233. return self.shared
  1234. def set_input_embeddings(self, new_embeddings):
  1235. self.shared = new_embeddings
  1236. self.encoder.set_input_embeddings(new_embeddings)
  1237. self.decoder.set_input_embeddings(new_embeddings)
  1238. @auto_docstring
  1239. def forward(
  1240. self,
  1241. input_ids: torch.LongTensor | None = None,
  1242. attention_mask: torch.FloatTensor | None = None,
  1243. decoder_input_ids: torch.LongTensor | None = None,
  1244. decoder_attention_mask: torch.BoolTensor | None = None,
  1245. encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
  1246. start_positions: torch.LongTensor | None = None,
  1247. end_positions: torch.LongTensor | None = None,
  1248. inputs_embeds: torch.FloatTensor | None = None,
  1249. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1250. use_cache: bool | None = None,
  1251. output_attentions: bool | None = None,
  1252. output_hidden_states: bool | None = None,
  1253. return_dict: bool | None = None,
  1254. **kwargs,
  1255. ) -> tuple[torch.FloatTensor] | Seq2SeqQuestionAnsweringModelOutput:
  1256. r"""
  1257. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1258. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1259. should be able to pad the inputs on both the right and the left.
  1260. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1261. [`PreTrainedTokenizer.__call__`] for detail.
  1262. [What are input IDs?](../glossary#input-ids)
  1263. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1264. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1265. Indices of decoder input sequence tokens in the vocabulary.
  1266. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1267. [`PreTrainedTokenizer.__call__`] for details.
  1268. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1269. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1270. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1271. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  1272. Training](./t5#training).
  1273. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1274. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1275. be used by default.
  1276. """
  1277. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1278. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1279. if start_positions is not None and end_positions is not None:
  1280. use_cache = False
  1281. # Copied from models.bart.modeling_bart.BartModel.forward
  1282. # different to other models, T5 automatically creates decoder_input_ids from
  1283. # input_ids if no decoder_input_ids are provided
  1284. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1285. if input_ids is None:
  1286. raise ValueError(
  1287. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1288. "passed, `input_ids` cannot be `None`. Please pass either "
  1289. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1290. )
  1291. decoder_input_ids = self._shift_right(input_ids)
  1292. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1293. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1294. # Encode if needed (training, first prediction pass)
  1295. if encoder_outputs is None:
  1296. encoder_outputs = self.encoder(
  1297. input_ids=input_ids,
  1298. attention_mask=attention_mask,
  1299. inputs_embeds=inputs_embeds,
  1300. output_attentions=output_attentions,
  1301. output_hidden_states=output_hidden_states,
  1302. return_dict=return_dict,
  1303. )
  1304. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1305. encoder_outputs = BaseModelOutput(
  1306. last_hidden_state=encoder_outputs[0],
  1307. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1308. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1309. )
  1310. hidden_states = encoder_outputs[0]
  1311. # Decode
  1312. decoder_outputs = self.decoder(
  1313. input_ids=decoder_input_ids,
  1314. attention_mask=decoder_attention_mask,
  1315. inputs_embeds=decoder_inputs_embeds,
  1316. past_key_values=None,
  1317. encoder_hidden_states=hidden_states,
  1318. encoder_attention_mask=attention_mask,
  1319. use_cache=use_cache,
  1320. output_attentions=output_attentions,
  1321. output_hidden_states=output_hidden_states,
  1322. return_dict=return_dict,
  1323. )
  1324. sequence_output = decoder_outputs[0]
  1325. logits = self.qa_outputs(sequence_output)
  1326. start_logits, end_logits = logits.split(1, dim=-1)
  1327. start_logits = start_logits.squeeze(-1).contiguous()
  1328. end_logits = end_logits.squeeze(-1).contiguous()
  1329. total_loss = None
  1330. if start_positions is not None and end_positions is not None:
  1331. # If we are on multi-GPU, split add a dimension
  1332. if len(start_positions.size()) > 1:
  1333. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  1334. if len(end_positions.size()) > 1:
  1335. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  1336. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1337. ignored_index = start_logits.size(1)
  1338. start_positions = start_positions.clamp(0, ignored_index)
  1339. end_positions = end_positions.clamp(0, ignored_index)
  1340. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1341. start_loss = loss_fct(start_logits, start_positions)
  1342. end_loss = loss_fct(end_logits, end_positions)
  1343. total_loss = (start_loss + end_loss) / 2
  1344. if not return_dict:
  1345. output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
  1346. return ((total_loss,) + output) if total_loss is not None else output
  1347. return Seq2SeqQuestionAnsweringModelOutput(
  1348. loss=total_loss,
  1349. start_logits=start_logits,
  1350. end_logits=end_logits,
  1351. past_key_values=decoder_outputs.past_key_values,
  1352. decoder_hidden_states=decoder_outputs.hidden_states,
  1353. decoder_attentions=decoder_outputs.attentions,
  1354. cross_attentions=decoder_outputs.cross_attentions,
  1355. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1356. encoder_hidden_states=encoder_outputs.hidden_states,
  1357. encoder_attentions=encoder_outputs.attentions,
  1358. )
  1359. __all__ = [
  1360. "T5EncoderModel",
  1361. "T5ForConditionalGeneration",
  1362. "T5Model",
  1363. "T5PreTrainedModel",
  1364. "T5ForQuestionAnswering",
  1365. "T5ForSequenceClassification",
  1366. "T5ForTokenClassification",
  1367. ]