modeling_umt5.py 72 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630
  1. # Copyright 2023 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch UMT5 model."""
  15. import copy
  16. import math
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import create_causal_mask
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. Seq2SeqLMOutput,
  30. Seq2SeqModelOutput,
  31. Seq2SeqQuestionAnsweringModelOutput,
  32. Seq2SeqSequenceClassifierOutput,
  33. TokenClassifierOutput,
  34. )
  35. from ...modeling_utils import PreTrainedModel
  36. from ...utils import (
  37. DUMMY_INPUTS,
  38. DUMMY_MASK,
  39. auto_docstring,
  40. is_torchdynamo_compiling,
  41. logging,
  42. torch_compilable_check,
  43. )
  44. from .configuration_umt5 import UMT5Config
# Module-level logger; classes in this file use it for one-time warnings via `logger.warning_once`.
logger = logging.get_logger(__name__)
  46. # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->UMT5
  47. class UMT5LayerNorm(nn.Module):
  48. def __init__(self, hidden_size, eps=1e-6):
  49. """
  50. Construct a layernorm module in the UMT5 style. No bias and no subtraction of mean.
  51. """
  52. super().__init__()
  53. self.weight = nn.Parameter(torch.ones(hidden_size))
  54. self.variance_epsilon = eps
  55. def forward(self, hidden_states):
  56. # UMT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  57. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  58. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  59. # half-precision inputs is done in fp32
  60. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  61. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  62. # convert into half-precision if necessary
  63. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  64. hidden_states = hidden_states.to(self.weight.dtype)
  65. return self.weight * hidden_states
  66. # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->UMT5
  67. class UMT5DenseActDense(nn.Module):
  68. def __init__(self, config: UMT5Config):
  69. super().__init__()
  70. self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
  71. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  72. self.dropout = nn.Dropout(config.dropout_rate)
  73. self.act = ACT2FN[config.dense_act_fn]
  74. def forward(self, hidden_states):
  75. hidden_states = self.wi(hidden_states)
  76. hidden_states = self.act(hidden_states)
  77. hidden_states = self.dropout(hidden_states)
  78. if (
  79. isinstance(self.wo.weight, torch.Tensor)
  80. and hidden_states.dtype != self.wo.weight.dtype
  81. and self.wo.weight.dtype != torch.int8
  82. ):
  83. hidden_states = hidden_states.to(self.wo.weight.dtype)
  84. hidden_states = self.wo(hidden_states)
  85. return hidden_states
  86. # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->UMT5
  87. class UMT5DenseGatedActDense(nn.Module):
  88. def __init__(self, config: UMT5Config):
  89. super().__init__()
  90. self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
  91. self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
  92. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  93. self.dropout = nn.Dropout(config.dropout_rate)
  94. self.act = ACT2FN[config.dense_act_fn]
  95. def forward(self, hidden_states):
  96. hidden_gelu = self.act(self.wi_0(hidden_states))
  97. hidden_linear = self.wi_1(hidden_states)
  98. hidden_states = hidden_gelu * hidden_linear
  99. hidden_states = self.dropout(hidden_states)
  100. # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
  101. # See https://github.com/huggingface/transformers/issues/20287
  102. # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
  103. if (
  104. isinstance(self.wo.weight, torch.Tensor)
  105. and hidden_states.dtype != self.wo.weight.dtype
  106. and self.wo.weight.dtype != torch.int8
  107. ):
  108. hidden_states = hidden_states.to(self.wo.weight.dtype)
  109. hidden_states = self.wo(hidden_states)
  110. return hidden_states
  111. # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->UMT5
  112. class UMT5LayerFF(nn.Module):
  113. def __init__(self, config: UMT5Config):
  114. super().__init__()
  115. if config.is_gated_act:
  116. self.DenseReluDense = UMT5DenseGatedActDense(config)
  117. else:
  118. self.DenseReluDense = UMT5DenseActDense(config)
  119. self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  120. self.dropout = nn.Dropout(config.dropout_rate)
  121. def forward(self, hidden_states):
  122. forwarded_states = self.layer_norm(hidden_states)
  123. forwarded_states = self.DenseReluDense(forwarded_states)
  124. hidden_states = hidden_states + self.dropout(forwarded_states)
  125. return hidden_states
  126. class UMT5Attention(nn.Module):
  127. """
  128. T5's attention using relative_attention_bias.
  129. """
  130. def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
  131. super().__init__()
  132. self.is_decoder = config.is_decoder
  133. self.has_relative_attention_bias = has_relative_attention_bias
  134. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  135. self.relative_attention_max_distance = config.relative_attention_max_distance
  136. self.d_model = config.d_model
  137. self.key_value_proj_dim = config.d_kv
  138. self.n_heads = config.num_heads
  139. self.dropout = config.dropout_rate
  140. self.inner_dim = self.n_heads * self.key_value_proj_dim
  141. self.layer_idx = layer_idx
  142. if layer_idx is None and self.is_decoder:
  143. logger.warning_once(
  144. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  145. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  146. "when creating this class."
  147. )
  148. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  149. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  150. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  151. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  152. if self.has_relative_attention_bias:
  153. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  154. def _shape(self, projection: torch.Tensor) -> torch.Tensor:
  155. new_projection_shape = projection.size()[:-1] + (self.n_heads, self.key_value_proj_dim)
  156. # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
  157. new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
  158. return new_projection
  159. def _relative_position_bucket(self, relative_position):
  160. """
  161. Adapted from Mesh Tensorflow:
  162. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  163. Translate relative position to a bucket number for relative attention. The relative position is defined as
  164. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  165. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  166. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  167. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  168. This should allow for more graceful generalization to longer sequences than the model has been trained on
  169. Args:
  170. relative_position: an int32 Tensor
  171. bidirectional: a boolean - whether the attention is bidirectional
  172. num_buckets: an integer
  173. max_distance: an integer
  174. Returns:
  175. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  176. """
  177. relative_buckets = 0
  178. num_buckets = self.relative_attention_num_buckets
  179. max_distance = self.relative_attention_max_distance
  180. if not self.is_decoder:
  181. num_buckets //= 2
  182. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  183. relative_position = torch.abs(relative_position)
  184. else:
  185. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  186. # now relative_position is in the range [0, inf)
  187. # half of the buckets are for exact increments in positions
  188. max_exact = num_buckets // 2
  189. is_small = relative_position < max_exact
  190. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  191. log_ratio = torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact)
  192. log_ratio = log_ratio * (num_buckets - max_exact)
  193. relative_position_if_large = max_exact + log_ratio.to(torch.long)
  194. relative_position_if_large = torch.min(
  195. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  196. )
  197. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  198. return relative_buckets
  199. def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
  200. """Compute binned relative position bias"""
  201. if device is None:
  202. device = self.relative_attention_bias.weight.device
  203. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
  204. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  205. relative_position = memory_position - context_position # shape (query_length, key_length)
  206. relative_position_bucket = self._relative_position_bucket(relative_position)
  207. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  208. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  209. return values
  210. def forward(
  211. self,
  212. hidden_states: torch.Tensor,
  213. encoder_hidden_states: torch.Tensor | None = None,
  214. past_key_values: Cache | None = None,
  215. attention_mask: torch.Tensor | None = None,
  216. **kwargs,
  217. ):
  218. batch_size, seq_length = hidden_states.shape[:2]
  219. past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
  220. # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
  221. past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens
  222. # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder
  223. is_cross_attention = encoder_hidden_states is not None
  224. query_states = self.q(hidden_states)
  225. query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  226. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  227. is_updated = False
  228. if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
  229. is_updated = past_key_values.is_updated.get(self.layer_idx)
  230. if is_cross_attention:
  231. # after the first generated id, we can subsequently re-use all key/value_states from cache
  232. curr_past_key_values = past_key_values.cross_attention_cache
  233. else:
  234. curr_past_key_values = past_key_values.self_attention_cache
  235. else:
  236. curr_past_key_values = past_key_values
  237. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  238. if is_cross_attention and past_key_values is not None and is_updated:
  239. # reuse k,v, cross_attentions
  240. key_states = curr_past_key_values.layers[self.layer_idx].keys
  241. value_states = curr_past_key_values.layers[self.layer_idx].values
  242. else:
  243. key_states = self.k(current_states)
  244. value_states = self.v(current_states)
  245. key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  246. value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  247. if past_key_values is not None:
  248. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  249. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  250. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  251. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  252. past_key_values.is_updated[self.layer_idx] = True
  253. # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  254. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  255. key_length = key_states.shape[-2]
  256. if not self.has_relative_attention_bias:
  257. position_bias = torch.zeros(
  258. (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
  259. )
  260. else:
  261. position_bias = self.compute_bias(
  262. seq_length, key_length, device=scores.device, past_seen_tokens=past_seen_tokens
  263. )
  264. if attention_mask is not None:
  265. position_bias = position_bias + attention_mask
  266. position_bias_masked = position_bias
  267. scores += position_bias_masked
  268. # (batch_size, n_heads, seq_length, key_length)
  269. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  270. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  271. attn_output = torch.matmul(attn_weights, value_states)
  272. attn_output = attn_output.transpose(1, 2).contiguous()
  273. attn_output = attn_output.view(batch_size, seq_length, -1)
  274. attn_output = self.o(attn_output)
  275. return attn_output, attn_weights
  276. class UMT5LayerSelfAttention(nn.Module):
  277. def __init__(self, config, layer_idx: int | None = None):
  278. super().__init__()
  279. self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True, layer_idx=layer_idx)
  280. self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  281. self.dropout = nn.Dropout(config.dropout_rate)
  282. def forward(
  283. self,
  284. hidden_states,
  285. attention_mask=None,
  286. past_key_values=None,
  287. **kwargs,
  288. ):
  289. normed_hidden_states = self.layer_norm(hidden_states)
  290. attention_output = self.SelfAttention(
  291. normed_hidden_states,
  292. attention_mask=attention_mask,
  293. past_key_values=past_key_values,
  294. )
  295. hidden_states = hidden_states + self.dropout(attention_output[0])
  296. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  297. return outputs
  298. class UMT5LayerCrossAttention(nn.Module):
  299. def __init__(self, config, layer_idx: int | None = None):
  300. super().__init__()
  301. self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
  302. self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  303. self.dropout = nn.Dropout(config.dropout_rate)
  304. def forward(
  305. self,
  306. hidden_states,
  307. encoder_hidden_states=None,
  308. attention_mask=None,
  309. past_key_values=None,
  310. **kwargs,
  311. ):
  312. normed_hidden_states = self.layer_norm(hidden_states)
  313. attention_output = self.EncDecAttention(
  314. normed_hidden_states,
  315. encoder_hidden_states=encoder_hidden_states,
  316. attention_mask=attention_mask,
  317. past_key_values=past_key_values,
  318. )
  319. layer_output = hidden_states + self.dropout(attention_output[0])
  320. outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
  321. return outputs
  322. class UMT5Block(GradientCheckpointingLayer):
  323. def __init__(self, config, layer_idx: int | None = None):
  324. super().__init__()
  325. self.is_decoder = config.is_decoder
  326. self.layer = nn.ModuleList()
  327. self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx))
  328. if self.is_decoder:
  329. self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx))
  330. self.layer.append(UMT5LayerFF(config))
  331. def forward(
  332. self,
  333. hidden_states,
  334. attention_mask=None,
  335. encoder_hidden_states=None,
  336. encoder_attention_mask=None,
  337. past_key_values=None,
  338. use_cache=False,
  339. output_attentions=False,
  340. **kwargs,
  341. ):
  342. hidden_states, self_attn_weights = self.layer[0](
  343. hidden_states,
  344. attention_mask=attention_mask,
  345. past_key_values=past_key_values,
  346. )
  347. # clamp inf values to enable fp16 training
  348. if hidden_states.dtype == torch.float16:
  349. max_dtype = torch.finfo(hidden_states.dtype).max
  350. clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
  351. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  352. # Cross-Attention Block
  353. cross_attn_weights = None
  354. do_cross_attention = self.is_decoder and encoder_hidden_states is not None
  355. if do_cross_attention:
  356. hidden_states, cross_attn_weights = self.layer[1](
  357. hidden_states,
  358. encoder_hidden_states=encoder_hidden_states,
  359. attention_mask=encoder_attention_mask,
  360. past_key_values=past_key_values,
  361. )
  362. # clamp inf values to enable fp16 training
  363. if hidden_states.dtype == torch.float16:
  364. max_dtype = torch.finfo(hidden_states.dtype).max
  365. clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
  366. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  367. # Apply Feed Forward layer
  368. hidden_states = self.layer[-1](hidden_states)
  369. # clamp inf values to enable fp16 training
  370. if hidden_states.dtype == torch.float16:
  371. max_dtype = torch.finfo(hidden_states.dtype).max
  372. clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
  373. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  374. outputs = (hidden_states,)
  375. if output_attentions:
  376. outputs += (self_attn_weights, cross_attn_weights)
  377. return outputs
  378. # Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->UMT5
  379. class UMT5ClassificationHead(nn.Module):
  380. """Head for sentence-level classification tasks."""
  381. def __init__(self, config: UMT5Config):
  382. super().__init__()
  383. self.dense = nn.Linear(config.d_model, config.d_model)
  384. self.dropout = nn.Dropout(p=config.classifier_dropout)
  385. self.out_proj = nn.Linear(config.d_model, config.num_labels)
  386. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  387. hidden_states = self.dropout(hidden_states)
  388. hidden_states = self.dense(hidden_states)
  389. hidden_states = torch.tanh(hidden_states)
  390. hidden_states = self.dropout(hidden_states)
  391. hidden_states = self.out_proj(hidden_states)
  392. return hidden_states
  393. @auto_docstring
  394. class UMT5PreTrainedModel(PreTrainedModel):
  395. config: UMT5Config
  396. base_model_prefix = "transformer"
  397. supports_gradient_checkpointing = True
  398. _can_compile_fullgraph = True
  399. _no_split_modules = ["UMT5Block"]
  400. _keep_in_fp32_modules = ["wo"]
  401. @property
  402. def dummy_inputs(self):
  403. input_ids = torch.tensor(DUMMY_INPUTS)
  404. input_mask = torch.tensor(DUMMY_MASK)
  405. dummy_inputs = {
  406. "decoder_input_ids": input_ids,
  407. "input_ids": input_ids,
  408. "decoder_attention_mask": input_mask,
  409. }
  410. return dummy_inputs
  411. @torch.no_grad()
  412. def _init_weights(self, module):
  413. """Initialize the weights"""
  414. factor = self.config.initializer_factor # Used for testing weights initialization
  415. if isinstance(module, UMT5LayerNorm):
  416. init.constant_(module.weight, factor * 1.0)
  417. elif isinstance(
  418. module,
  419. (
  420. UMT5Model,
  421. UMT5ForConditionalGeneration,
  422. UMT5EncoderModel,
  423. UMT5ForQuestionAnswering,
  424. ),
  425. ):
  426. # Mesh TensorFlow embeddings initialization
  427. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
  428. init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
  429. if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
  430. init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0)
  431. if hasattr(module, "qa_outputs"):
  432. init.normal_(module.qa_outputs.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  433. init.zeros_(module.qa_outputs.bias)
  434. elif isinstance(module, UMT5ForTokenClassification):
  435. if hasattr(module, "classifier"):
  436. init.normal_(module.classifier.weight, mean=0.0, std=factor * 1.0)
  437. init.zeros_(module.classifier.bias)
  438. elif isinstance(module, UMT5ClassificationHead):
  439. init.normal_(module.dense.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  440. if hasattr(module.dense, "bias") and module.dense.bias is not None:
  441. init.zeros_(module.dense.bias)
  442. init.normal_(module.out_proj.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  443. if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
  444. init.zeros_(module.out_proj.bias)
  445. elif isinstance(module, UMT5DenseActDense):
  446. # Mesh TensorFlow FF initialization
  447. # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
  448. # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
  449. init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  450. if hasattr(module.wi, "bias") and module.wi.bias is not None:
  451. init.zeros_(module.wi.bias)
  452. init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  453. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  454. init.zeros_(module.wo.bias)
  455. elif isinstance(module, UMT5DenseGatedActDense):
  456. init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  457. if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
  458. init.zeros_(module.wi_0.bias)
  459. init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  460. if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
  461. init.zeros_(module.wi_1.bias)
  462. init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  463. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  464. init.zeros_(module.wo.bias)
  465. elif isinstance(module, UMT5Attention):
  466. # Mesh TensorFlow attention initialization to avoid scaling before softmax
  467. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
  468. d_model = self.config.d_model
  469. key_value_proj_dim = self.config.d_kv
  470. n_heads = self.config.num_heads
  471. init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
  472. init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
  473. init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
  474. init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
  475. if module.has_relative_attention_bias:
  476. init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5))
  477. def _shift_right(self, input_ids):
  478. decoder_start_token_id = self.config.decoder_start_token_id
  479. pad_token_id = self.config.pad_token_id
  480. if decoder_start_token_id is None:
  481. raise ValueError(
  482. "self.model.config.decoder_start_token_id has to be defined. In UMT5 it is usually set to the pad_token_id. "
  483. "See UMT5 docs for more information."
  484. )
  485. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  486. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  487. shifted_input_ids[..., 0] = decoder_start_token_id
  488. if pad_token_id is None:
  489. raise ValueError("self.model.config.pad_token_id has to be defined.")
  490. # replace possible -100 values in labels by `pad_token_id`
  491. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  492. return shifted_input_ids
class UMT5Stack(UMT5PreTrainedModel):
    """Stack of `UMT5Block` layers used as either the UMT5 encoder or the UMT5 decoder.

    The role is selected by `config.is_decoder`. A decoder stack builds its self-attention mask with
    `create_causal_mask`, may create a key/value cache, and attends to `encoder_hidden_states`. An encoder
    stack drops any cache and uses the padding mask only. The stack owns its own `embed_tokens` table
    and ends with a final layer norm followed by dropout.
    """

    def __init__(self, config):
        super().__init__(config)
        # Token embedding table owned by this stack (used when `inputs_embeds` is not supplied).
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        # Selects encoder vs. decoder behavior throughout `forward`.
        self.is_decoder = config.is_decoder
        # `config.num_layers` blocks; each one is told its own index via `layer_idx`.
        self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)])
        self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # Single dropout module applied both to the input embeddings and to the final normalized output.
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.gradient_checkpointing = False
        self.post_init()

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module used when `inputs_embeds` is not supplied."""
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Run the stack over token ids or precomputed embeddings.

        Exactly one of `input_ids` / `inputs_embeds` must be given. `use_cache`, `output_attentions`,
        `output_hidden_states` and `return_dict` fall back to the matching `self.config` values when `None`.
        Extra keyword arguments are accepted but not used here.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` is truthy; otherwise a tuple of the
            non-`None` entries among (last hidden state, cache, all hidden states, self-attentions,
            cross-attentions), in that order.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, if `embed_tokens` is
                `None` while embeddings must be computed, or if `use_cache=True` on a non-decoder stack.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            # Collapse any leading dimensions so the ids are 2D: (-1, last_dim).
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            # Everything except the trailing hidden dimension.
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            # Caching is switched off while training with gradient checkpointing (see warning text).
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        # NOTE(review): this unpacking assumes a 2D input shape (batch, seq) — confirm callers never pass more dims.
        batch_size, seq_length = input_shape

        # Caching is only supported for the decoder; reject an explicit request on the encoder.
        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        # initialize past_key_values
        if self.is_decoder:
            if use_cache and past_key_values is None:
                # Encoder-decoder models keep separate caches for self-attention and cross-attention.
                if self.config.is_encoder_decoder:
                    past_key_values = EncoderDecoderCache(
                        DynamicCache(config=self.config), DynamicCache(config=self.config)
                    )
                else:
                    past_key_values = DynamicCache(config=self.config)
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        # Number of positions already held in the cache (0 without a cache).
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            # Default: attend to every cached and current position.
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.is_decoder:
            # Decoder self-attention: the mask is built by `create_causal_mask` from the padding mask and cache.
            causal_mask = create_causal_mask(
                config=self.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
        elif attention_mask is not None:
            # Encoder: expand the padding mask to (batch, 1, 1, key_len) and convert it to an additive bias.
            # Positions with mask 1 get 0 and positions with mask 0 get the dtype's minimum value.
            # NOTE(review): the indexing assumes a 2D (batch, seq) mask of 0/1 values — confirm.
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
        else:
            causal_mask = None

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            # Default: attend to every encoder position.
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Per-layer collections are only built when requested; cross-attentions exist only in the decoder.
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.is_decoder else None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):
            # Hidden states are recorded *before* each layer; the final normalized output is appended after the loop.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                causal_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_extended_attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            # As indexed here: [0] new hidden states, [1] self-attention weights, [2] cross-attention weights.
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

                if self.is_decoder:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Legacy tuple output: entries that are `None` are dropped.
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
  639. @auto_docstring
  640. class UMT5Model(UMT5PreTrainedModel):
  641. r"""
  642. Examples:
  643. ```python
  644. >>> from transformers import UMT5Model, AutoTokenizer
  645. >>> model = UMT5Model.from_pretrained("google/umt5-small")
  646. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  647. >>> noisy_text = "UN Offizier sagt, dass weiter <extra_id_0> werden muss in Syrien."
  648. >>> label = "<extra_id_0> verhandelt"
  649. >>> inputs = tokenizer(inputs, return_tensors="pt")
  650. >>> labels = tokenizer(label=label, return_tensors="pt")
  651. >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
  652. >>> hidden_states = outputs.last_hidden_state
  653. ```"""
  654. model_type = "umt5"
  655. config: UMT5Config
  656. _tied_weights_keys = {
  657. "encoder.embed_tokens.weight": "shared.weight",
  658. "decoder.embed_tokens.weight": "shared.weight",
  659. }
  660. def __init__(self, config):
  661. super().__init__(config)
  662. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  663. encoder_config = copy.deepcopy(config)
  664. encoder_config.is_decoder = False
  665. encoder_config.use_cache = False
  666. self.encoder = UMT5Stack(encoder_config)
  667. decoder_config = copy.deepcopy(config)
  668. decoder_config.is_decoder = True
  669. decoder_config.num_layers = config.num_decoder_layers
  670. self.decoder = UMT5Stack(decoder_config)
  671. # Initialize weights and apply final processing
  672. self.post_init()
  673. # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings
  674. def get_input_embeddings(self):
  675. return self.shared
  676. # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings
  677. def set_input_embeddings(self, new_embeddings):
  678. self.shared = new_embeddings
  679. self.encoder.set_input_embeddings(new_embeddings)
  680. self.decoder.set_input_embeddings(new_embeddings)
  681. @auto_docstring
  682. def forward(
  683. self,
  684. input_ids: torch.LongTensor | None = None,
  685. attention_mask: torch.FloatTensor | None = None,
  686. decoder_input_ids: torch.LongTensor | None = None,
  687. decoder_attention_mask: torch.BoolTensor | None = None,
  688. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  689. past_key_values: Cache | None = None,
  690. inputs_embeds: torch.Tensor | None = None,
  691. decoder_inputs_embeds: torch.Tensor | None = None,
  692. use_cache: bool | None = None,
  693. output_attentions: bool | None = None,
  694. output_hidden_states: bool | None = None,
  695. return_dict: bool | None = None,
  696. **kwargs,
  697. ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput:
  698. r"""
  699. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  700. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
  701. you should be able to pad the inputs on both the right and the left.
  702. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  703. [`PreTrainedTokenizer.__call__`] for detail.
  704. [What are input IDs?](../glossary#input-ids)
  705. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  706. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  707. Indices of decoder input sequence tokens in the vocabulary.
  708. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  709. [`PreTrainedTokenizer.__call__`] for details.
  710. [What are decoder input IDs?](../glossary#decoder-input-ids)
  711. UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  712. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  713. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
  714. Training](./umt5#training).
  715. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  716. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  717. be used by default.
  718. Example:
  719. ```python
  720. >>> from transformers import AutoTokenizer, UMT5Model
  721. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  722. >>> model = UMT5Model.from_pretrained("google/umt5-small")
  723. >>> input_ids = tokenizer(
  724. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  725. ... ).input_ids # Batch size 1
  726. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  727. >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for UMT5Model.
  728. >>> # This is not needed for torch's UMT5ForConditionalGeneration as it does this internally using labels arg.
  729. >>> decoder_input_ids = model._shift_right(decoder_input_ids)
  730. >>> # forward pass
  731. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  732. >>> last_hidden_states = outputs.last_hidden_state
  733. ```"""
  734. use_cache = use_cache if use_cache is not None else self.config.use_cache
  735. return_dict = return_dict if return_dict is not None else self.config.return_dict
  736. # Encode if needed (training, first prediction pass)
  737. if encoder_outputs is None:
  738. encoder_outputs = self.encoder(
  739. input_ids=input_ids,
  740. attention_mask=attention_mask,
  741. inputs_embeds=inputs_embeds,
  742. output_attentions=output_attentions,
  743. output_hidden_states=output_hidden_states,
  744. return_dict=return_dict,
  745. )
  746. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  747. encoder_outputs = BaseModelOutput(
  748. last_hidden_state=encoder_outputs[0],
  749. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  750. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  751. )
  752. hidden_states = encoder_outputs[0]
  753. # Decode
  754. decoder_outputs = self.decoder(
  755. input_ids=decoder_input_ids,
  756. attention_mask=decoder_attention_mask,
  757. inputs_embeds=decoder_inputs_embeds,
  758. past_key_values=past_key_values,
  759. encoder_hidden_states=hidden_states,
  760. encoder_attention_mask=attention_mask,
  761. use_cache=use_cache,
  762. output_attentions=output_attentions,
  763. output_hidden_states=output_hidden_states,
  764. return_dict=return_dict,
  765. )
  766. if not return_dict:
  767. return decoder_outputs + encoder_outputs
  768. return Seq2SeqModelOutput(
  769. last_hidden_state=decoder_outputs.last_hidden_state,
  770. past_key_values=decoder_outputs.past_key_values,
  771. decoder_hidden_states=decoder_outputs.hidden_states,
  772. decoder_attentions=decoder_outputs.attentions,
  773. cross_attentions=decoder_outputs.cross_attentions,
  774. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  775. encoder_hidden_states=encoder_outputs.hidden_states,
  776. encoder_attentions=encoder_outputs.attentions,
  777. )
  778. @auto_docstring(
  779. custom_intro="""
  780. UMT5 Model with a `language modeling` head on top.
  781. """
  782. )
  783. class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
  784. r"""
  785. Examples:
  786. ```python
  787. >>> from transformers import UMT5ForConditionalGeneration, AutoTokenizer
  788. >>> model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small")
  789. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  790. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  791. >>> summary = "Weiter Verhandlung in Syrien."
  792. >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
  793. >>> outputs = model(**inputs)
  794. >>> loss = outputs.loss
  795. ```"""
  796. model_type = "umt5"
  797. _tied_weights_keys = {
  798. "encoder.embed_tokens.weight": "shared.weight",
  799. "decoder.embed_tokens.weight": "shared.weight",
  800. "lm_head.weight": "shared.weight",
  801. }
  802. def __init__(self, config):
  803. super().__init__(config)
  804. self.model_dim = config.d_model
  805. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  806. encoder_config = copy.deepcopy(config)
  807. encoder_config.is_decoder = False
  808. encoder_config.use_cache = False
  809. self.encoder = UMT5Stack(encoder_config)
  810. decoder_config = copy.deepcopy(config)
  811. decoder_config.is_decoder = True
  812. decoder_config.num_layers = config.num_decoder_layers
  813. self.decoder = UMT5Stack(decoder_config)
  814. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  815. # Initialize weights and apply final processing
  816. self.post_init()
  817. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_input_embeddings
  818. def get_input_embeddings(self):
  819. return self.shared
  820. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_input_embeddings
  821. def set_input_embeddings(self, new_embeddings):
  822. self.shared = new_embeddings
  823. self.encoder.set_input_embeddings(new_embeddings)
  824. self.decoder.set_input_embeddings(new_embeddings)
  825. @auto_docstring
  826. def forward(
  827. self,
  828. input_ids: torch.LongTensor | None = None,
  829. attention_mask: torch.FloatTensor | None = None,
  830. decoder_input_ids: torch.LongTensor | None = None,
  831. decoder_attention_mask: torch.BoolTensor | None = None,
  832. encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
  833. past_key_values: Cache | None = None,
  834. inputs_embeds: torch.FloatTensor | None = None,
  835. decoder_inputs_embeds: torch.FloatTensor | None = None,
  836. labels: torch.LongTensor | None = None,
  837. use_cache: bool | None = None,
  838. output_attentions: bool | None = None,
  839. output_hidden_states: bool | None = None,
  840. return_dict: bool | None = None,
  841. **kwargs,
  842. ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
  843. r"""
  844. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  845. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
  846. you should be able to pad the inputs on both the right and the left.
  847. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  848. [`PreTrainedTokenizer.__call__`] for detail.
  849. [What are input IDs?](../glossary#input-ids)
  850. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  851. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  852. Indices of decoder input sequence tokens in the vocabulary.
  853. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  854. [`PreTrainedTokenizer.__call__`] for details.
  855. [What are decoder input IDs?](../glossary#decoder-input-ids)
  856. UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  857. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  858. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
  859. Training](./umt5#training).
  860. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  861. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  862. be used by default.
  863. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  864. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  865. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  866. labels in `[0, ..., config.vocab_size]`
  867. Examples:
  868. ```python
  869. >>> from transformers import AutoTokenizer, UMT5ForConditionalGeneration
  870. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  871. >>> model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small")
  872. >>> # training
  873. >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  874. >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  875. >>> outputs = model(input_ids=input_ids, labels=labels)
  876. >>> loss = outputs.loss
  877. >>> logits = outputs.logits
  878. >>> # inference
  879. >>> input_ids = tokenizer("Studies have shown that <extra_id_0> good for you", return_tensors="pt").input_ids
  880. >>> outputs = model.generate(input_ids)
  881. >>> tokenizer.decode(outputs[0], skip_special_tokens=True)
  882. ```"""
  883. use_cache = use_cache if use_cache is not None else self.config.use_cache
  884. return_dict = return_dict if return_dict is not None else self.config.return_dict
  885. # Encode if needed (training, first prediction pass)
  886. if encoder_outputs is None:
  887. # Convert encoder inputs in embeddings if needed
  888. encoder_outputs = self.encoder(
  889. input_ids=input_ids,
  890. attention_mask=attention_mask,
  891. inputs_embeds=inputs_embeds,
  892. output_attentions=output_attentions,
  893. output_hidden_states=output_hidden_states,
  894. return_dict=return_dict,
  895. )
  896. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  897. encoder_outputs = BaseModelOutput(
  898. last_hidden_state=encoder_outputs[0],
  899. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  900. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  901. )
  902. hidden_states = encoder_outputs[0]
  903. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  904. # get decoder inputs from shifting lm labels to the right
  905. decoder_input_ids = self._shift_right(labels)
  906. # Decode
  907. decoder_outputs = self.decoder(
  908. input_ids=decoder_input_ids,
  909. attention_mask=decoder_attention_mask,
  910. inputs_embeds=decoder_inputs_embeds,
  911. past_key_values=past_key_values,
  912. encoder_hidden_states=hidden_states,
  913. encoder_attention_mask=attention_mask,
  914. use_cache=use_cache,
  915. output_attentions=output_attentions,
  916. output_hidden_states=output_hidden_states,
  917. return_dict=return_dict,
  918. )
  919. sequence_output = decoder_outputs[0]
  920. if self.config.tie_word_embeddings:
  921. # Rescale output before projecting on vocab
  922. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  923. sequence_output = sequence_output * (self.model_dim**-0.5)
  924. lm_logits = self.lm_head(sequence_output)
  925. loss = None
  926. if labels is not None:
  927. loss_fct = CrossEntropyLoss(ignore_index=-100)
  928. # move labels to correct device to enable PP
  929. labels = labels.to(lm_logits.device)
  930. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  931. if not return_dict:
  932. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  933. return ((loss,) + output) if loss is not None else output
  934. return Seq2SeqLMOutput(
  935. loss=loss,
  936. logits=lm_logits,
  937. past_key_values=decoder_outputs.past_key_values,
  938. decoder_hidden_states=decoder_outputs.hidden_states,
  939. decoder_attentions=decoder_outputs.attentions,
  940. cross_attentions=decoder_outputs.cross_attentions,
  941. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  942. encoder_hidden_states=encoder_outputs.hidden_states,
  943. encoder_attentions=encoder_outputs.attentions,
  944. )
  945. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
  946. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  947. return self._shift_right(labels)
@auto_docstring
class UMT5EncoderModel(UMT5PreTrainedModel):
    r"""
    Examples:

    ```python
    >>> from transformers import UMT5EncoderModel, AutoTokenizer

    >>> model = UMT5EncoderModel.from_pretrained("google/umt5-small")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
    >>> input_ids = tokenizer(article, return_tensors="pt").input_ids
    >>> outputs = model(input_ids)
    >>> hidden_state = outputs.last_hidden_state
    ```"""

    model_type = "umt5"
    # The encoder's token embeddings are tied to the `shared` embedding matrix.
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
    }

    def __init__(self, config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Deep-copy so the encoder-only overrides below do not mutate the caller's `config`.
        encoder_config = copy.deepcopy(config)
        # UMT5Stack raises if `use_cache` resolves to True on a non-decoder stack, so disable it here.
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = UMT5Stack(encoder_config)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.shared

    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.set_input_embeddings
    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.forward with T5->UMT5, google-t5/t5-small->google/umt5-small, t5#training->umt5#training
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor] | BaseModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for detail.

            To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).

        Example:

        ```python
        >>> from transformers import AutoTokenizer, UMT5EncoderModel

        >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
        >>> model = UMT5EncoderModel.from_pretrained("google/umt5-small")
        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs
  1022. @auto_docstring(
  1023. custom_intro="""
  1024. UMT5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  1025. tasks.
  1026. """
  1027. )
  1028. class UMT5ForSequenceClassification(UMT5PreTrainedModel):
  1029. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1030. # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->UMT5
  1031. def __init__(self, config: UMT5Config):
  1032. super().__init__(config)
  1033. self.transformer = UMT5Model(config)
  1034. self.classification_head = UMT5ClassificationHead(config)
  1035. # Initialize weights and apply final processing
  1036. self.post_init()
  1037. @auto_docstring
  1038. def forward(
  1039. self,
  1040. input_ids: torch.LongTensor | None = None,
  1041. attention_mask: torch.Tensor | None = None,
  1042. decoder_input_ids: torch.LongTensor | None = None,
  1043. decoder_attention_mask: torch.LongTensor | None = None,
  1044. encoder_outputs: list[torch.FloatTensor] | None = None,
  1045. inputs_embeds: torch.FloatTensor | None = None,
  1046. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1047. labels: torch.LongTensor | None = None,
  1048. use_cache: bool | None = None,
  1049. output_attentions: bool | None = None,
  1050. output_hidden_states: bool | None = None,
  1051. return_dict: bool | None = None,
  1052. **kwargs,
  1053. ) -> tuple | Seq2SeqSequenceClassifierOutput:
  1054. r"""
  1055. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1056. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
  1057. you should be able to pad the inputs on both the right and the left.
  1058. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1059. [`PreTrainedTokenizer.__call__`] for detail.
  1060. [What are input IDs?](../glossary#input-ids)
  1061. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  1062. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1063. Indices of decoder input sequence tokens in the vocabulary.
  1064. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1065. [`PreTrainedTokenizer.__call__`] for details.
  1066. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1067. UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1068. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1069. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
  1070. Training](./umt5#training).
  1071. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1072. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1073. be used by default.
  1074. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1075. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1076. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1077. """
  1078. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1079. if labels is not None:
  1080. use_cache = False
  1081. if input_ids is None and inputs_embeds is not None:
  1082. raise NotImplementedError(
  1083. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1084. )
  1085. # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
  1086. # decoder_input_ids from input_ids if no decoder_input_ids are provided
  1087. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1088. if input_ids is None:
  1089. raise ValueError(
  1090. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1091. "passed, `input_ids` cannot be `None`. Please pass either "
  1092. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1093. )
  1094. decoder_input_ids = self._shift_right(input_ids)
  1095. outputs = self.transformer(
  1096. input_ids,
  1097. attention_mask=attention_mask,
  1098. decoder_input_ids=decoder_input_ids,
  1099. decoder_attention_mask=decoder_attention_mask,
  1100. encoder_outputs=encoder_outputs,
  1101. inputs_embeds=inputs_embeds,
  1102. decoder_inputs_embeds=decoder_inputs_embeds,
  1103. use_cache=use_cache,
  1104. output_attentions=output_attentions,
  1105. output_hidden_states=output_hidden_states,
  1106. return_dict=return_dict,
  1107. )
  1108. sequence_output = outputs[0]
  1109. eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
  1110. torch_compilable_check(
  1111. torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
  1112. "All examples must have the same number of <eos> tokens.",
  1113. )
  1114. batch_size, _, hidden_size = sequence_output.shape
  1115. sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
  1116. logits = self.classification_head(sentence_representation)
  1117. loss = None
  1118. if labels is not None:
  1119. labels = labels.to(logits.device)
  1120. if self.config.problem_type is None:
  1121. if self.config.num_labels == 1:
  1122. self.config.problem_type = "regression"
  1123. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1124. self.config.problem_type = "single_label_classification"
  1125. else:
  1126. self.config.problem_type = "multi_label_classification"
  1127. if self.config.problem_type == "regression":
  1128. loss_fct = MSELoss()
  1129. if self.config.num_labels == 1:
  1130. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1131. else:
  1132. loss = loss_fct(logits, labels)
  1133. elif self.config.problem_type == "single_label_classification":
  1134. loss_fct = CrossEntropyLoss()
  1135. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1136. elif self.config.problem_type == "multi_label_classification":
  1137. loss_fct = BCEWithLogitsLoss()
  1138. loss = loss_fct(logits, labels)
  1139. if not return_dict:
  1140. output = (logits,) + outputs[1:]
  1141. return ((loss,) + output) if loss is not None else output
  1142. return Seq2SeqSequenceClassifierOutput(
  1143. loss=loss,
  1144. logits=logits,
  1145. past_key_values=outputs.past_key_values,
  1146. decoder_hidden_states=outputs.decoder_hidden_states,
  1147. decoder_attentions=outputs.decoder_attentions,
  1148. cross_attentions=outputs.cross_attentions,
  1149. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1150. encoder_hidden_states=outputs.encoder_hidden_states,
  1151. encoder_attentions=outputs.encoder_attentions,
  1152. )
  1153. @auto_docstring
  1154. class UMT5ForTokenClassification(UMT5PreTrainedModel):
  1155. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1156. # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->UMT5
  1157. def __init__(self, config: UMT5Config):
  1158. super().__init__(config)
  1159. self.num_labels = config.num_labels
  1160. self.transformer = UMT5EncoderModel(config)
  1161. self.dropout = nn.Dropout(config.classifier_dropout)
  1162. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1163. # Initialize weights and apply final processing
  1164. self.post_init()
  1165. @auto_docstring
  1166. # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->UMT5, t5->umt5
  1167. def forward(
  1168. self,
  1169. input_ids: torch.Tensor | None = None,
  1170. attention_mask: torch.Tensor | None = None,
  1171. inputs_embeds: torch.Tensor | None = None,
  1172. labels: torch.Tensor | None = None,
  1173. output_attentions: bool | None = None,
  1174. output_hidden_states: bool | None = None,
  1175. return_dict: bool | None = None,
  1176. **kwargs,
  1177. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  1178. r"""
  1179. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1180. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so you
  1181. should be able to pad the inputs on both the right and the left.
  1182. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1183. [`PreTrainedTokenizer.__call__`] for detail.
  1184. [What are input IDs?](../glossary#input-ids)
  1185. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  1186. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1187. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1188. """
  1189. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1190. outputs = self.transformer(
  1191. input_ids,
  1192. attention_mask=attention_mask,
  1193. inputs_embeds=inputs_embeds,
  1194. output_attentions=output_attentions,
  1195. output_hidden_states=output_hidden_states,
  1196. return_dict=return_dict,
  1197. )
  1198. hidden_states = outputs[0]
  1199. hidden_states = self.dropout(hidden_states)
  1200. logits = self.classifier(hidden_states)
  1201. loss = None
  1202. if labels is not None:
  1203. loss_fct = CrossEntropyLoss()
  1204. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1205. if not return_dict:
  1206. output = (logits, outputs[2:-1])
  1207. return ((loss,) + output) if loss is not None else output
  1208. return TokenClassifierOutput(
  1209. loss=loss,
  1210. logits=logits,
  1211. hidden_states=outputs.hidden_states,
  1212. attentions=outputs.attentions,
  1213. )
  1214. @auto_docstring
  1215. class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
  1216. _tied_weights_keys = {
  1217. "encoder.embed_tokens.weight": "shared.weight",
  1218. "decoder.embed_tokens.weight": "shared.weight",
  1219. }
  1220. def __init__(self, config):
  1221. super().__init__(config)
  1222. self.model_dim = config.d_model
  1223. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1224. encoder_config = copy.deepcopy(config)
  1225. encoder_config.is_decoder = False
  1226. encoder_config.use_cache = False
  1227. self.encoder = UMT5Stack(encoder_config)
  1228. decoder_config = copy.deepcopy(config)
  1229. decoder_config.is_decoder = True
  1230. decoder_config.num_layers = config.num_decoder_layers
  1231. self.decoder = UMT5Stack(decoder_config)
  1232. self.num_labels = config.num_labels
  1233. self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
  1234. # Initialize weights and apply final processing
  1235. self.post_init()
  1236. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings
  1237. def get_input_embeddings(self):
  1238. return self.shared
  1239. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.set_input_embeddings
  1240. def set_input_embeddings(self, new_embeddings):
  1241. self.shared = new_embeddings
  1242. self.encoder.set_input_embeddings(new_embeddings)
  1243. self.decoder.set_input_embeddings(new_embeddings)
  1244. @auto_docstring
  1245. def forward(
  1246. self,
  1247. input_ids: torch.LongTensor | None = None,
  1248. attention_mask: torch.FloatTensor | None = None,
  1249. decoder_input_ids: torch.LongTensor | None = None,
  1250. decoder_attention_mask: torch.BoolTensor | None = None,
  1251. encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
  1252. start_positions: torch.LongTensor | None = None,
  1253. end_positions: torch.LongTensor | None = None,
  1254. inputs_embeds: torch.FloatTensor | None = None,
  1255. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1256. use_cache: bool | None = None,
  1257. output_attentions: bool | None = None,
  1258. output_hidden_states: bool | None = None,
  1259. return_dict: bool | None = None,
  1260. **kwargs,
  1261. ) -> tuple[torch.FloatTensor] | Seq2SeqQuestionAnsweringModelOutput:
  1262. r"""
  1263. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1264. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
  1265. you should be able to pad the inputs on both the right and the left.
  1266. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1267. [`PreTrainedTokenizer.__call__`] for detail.
  1268. [What are input IDs?](../glossary#input-ids)
  1269. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  1270. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1271. Indices of decoder input sequence tokens in the vocabulary.
  1272. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1273. [`PreTrainedTokenizer.__call__`] for details.
  1274. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1275. UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1276. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1277. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
  1278. Training](./umt5#training).
  1279. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1280. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1281. be used by default.
  1282. """
  1283. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1284. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1285. if start_positions is not None and end_positions is not None:
  1286. use_cache = False
  1287. # Copied from models.bart.modeling_bart.BartModel.forward
  1288. # different to other models, T5 automatically creates decoder_input_ids from
  1289. # input_ids if no decoder_input_ids are provided
  1290. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1291. if input_ids is None:
  1292. raise ValueError(
  1293. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1294. "passed, `input_ids` cannot be `None`. Please pass either "
  1295. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1296. )
  1297. decoder_input_ids = self._shift_right(input_ids)
  1298. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1299. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1300. # Encode if needed (training, first prediction pass)
  1301. if encoder_outputs is None:
  1302. encoder_outputs = self.encoder(
  1303. input_ids=input_ids,
  1304. attention_mask=attention_mask,
  1305. inputs_embeds=inputs_embeds,
  1306. output_attentions=output_attentions,
  1307. output_hidden_states=output_hidden_states,
  1308. return_dict=return_dict,
  1309. )
  1310. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1311. encoder_outputs = BaseModelOutput(
  1312. last_hidden_state=encoder_outputs[0],
  1313. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1314. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1315. )
  1316. hidden_states = encoder_outputs[0]
  1317. # Decode
  1318. decoder_outputs = self.decoder(
  1319. input_ids=decoder_input_ids,
  1320. attention_mask=decoder_attention_mask,
  1321. inputs_embeds=decoder_inputs_embeds,
  1322. past_key_values=None,
  1323. encoder_hidden_states=hidden_states,
  1324. encoder_attention_mask=attention_mask,
  1325. use_cache=use_cache,
  1326. output_attentions=output_attentions,
  1327. output_hidden_states=output_hidden_states,
  1328. return_dict=return_dict,
  1329. )
  1330. sequence_output = decoder_outputs[0]
  1331. logits = self.qa_outputs(sequence_output)
  1332. start_logits, end_logits = logits.split(1, dim=-1)
  1333. start_logits = start_logits.squeeze(-1).contiguous()
  1334. end_logits = end_logits.squeeze(-1).contiguous()
  1335. total_loss = None
  1336. if start_positions is not None and end_positions is not None:
  1337. # If we are on multi-GPU, split add a dimension
  1338. if len(start_positions.size()) > 1:
  1339. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  1340. if len(end_positions.size()) > 1:
  1341. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  1342. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1343. ignored_index = start_logits.size(1)
  1344. start_positions = start_positions.clamp(0, ignored_index)
  1345. end_positions = end_positions.clamp(0, ignored_index)
  1346. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1347. start_loss = loss_fct(start_logits, start_positions)
  1348. end_loss = loss_fct(end_logits, end_positions)
  1349. total_loss = (start_loss + end_loss) / 2
  1350. if not return_dict:
  1351. output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
  1352. return ((total_loss,) + output) if total_loss is not None else output
  1353. return Seq2SeqQuestionAnsweringModelOutput(
  1354. loss=total_loss,
  1355. start_logits=start_logits,
  1356. end_logits=end_logits,
  1357. past_key_values=decoder_outputs.past_key_values,
  1358. decoder_hidden_states=decoder_outputs.hidden_states,
  1359. decoder_attentions=decoder_outputs.attentions,
  1360. cross_attentions=decoder_outputs.cross_attentions,
  1361. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1362. encoder_hidden_states=encoder_outputs.hidden_states,
  1363. encoder_attentions=encoder_outputs.attentions,
  1364. )
# Public API of this module: the names exported by `from ... import *`.
# NOTE(review): transformers' lazy-import tooling appears to read this list textually; keep it a plain literal with
# one quoted name per line and no inline comments — confirm before reformatting.
__all__ = [
    "UMT5EncoderModel",
    "UMT5ForConditionalGeneration",
    "UMT5ForQuestionAnswering",
    "UMT5ForSequenceClassification",
    "UMT5ForTokenClassification",
    "UMT5Model",
    "UMT5PreTrainedModel",
]