modeling_longt5.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821
  1. # Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LongT5 model."""
  15. import copy
  16. import math
  17. from typing import Any
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...masking_utils import create_causal_mask
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. BaseModelOutputWithPastAndCrossAttentions,
  30. Seq2SeqLMOutput,
  31. Seq2SeqModelOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...utils import (
  35. DUMMY_INPUTS,
  36. DUMMY_MASK,
  37. auto_docstring,
  38. is_torchdynamo_compiling,
  39. logging,
  40. )
  41. from .configuration_longt5 import LongT5Config
  42. logger = logging.get_logger(__name__)
  43. # TODO: Update before the merge
  44. def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor:
  45. """Pad a tensor so that a sequence length will be a multiple of `block_len`"""
  46. pad_len = -x.shape[dim] % block_len
  47. # Handle cases when an empty input sequence is given
  48. if not all(x.shape):
  49. new_shape = list(x.shape)
  50. new_shape[dim] += pad_len
  51. return torch.zeros(new_shape, dtype=x.dtype)
  52. pad = [(0, 0)] * x.ndim
  53. pad[dim] = (0, pad_len)
  54. pad = sum(pad[::-1], ())
  55. x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
  56. return x
  57. def _split_into_blocks(x: torch.Tensor, block_len: int, dim: int) -> torch.Tensor:
  58. """Split an input tensor into blocks of a given `block_len` along the given `dim`. If the dimension length
  59. is not a multiple of `block_len`, it will be padded first with selected `pad_value`.
  60. """
  61. # pad tensor to multiple of block_len
  62. if x.shape[dim] % block_len != 0:
  63. x = _pad_to_multiple(x, block_len, dim, pad_value=0)
  64. num_blocks = x.shape[dim] // block_len
  65. output_shape = x.shape[:dim] + (num_blocks, block_len) + x.shape[(dim + 1) :]
  66. # If 0 is in output_shape, we cannot apply reshape because of incompatibility with ONNX conversion
  67. if 0 in output_shape:
  68. return torch.empty(output_shape, dtype=x.dtype, device=x.device)
  69. return x.reshape(output_shape)
  70. def _concatenate_3_blocks(x: torch.Tensor, block_dim: int, sequence_dim: int, pad_value: int = 0) -> torch.Tensor:
  71. """Concatenate three consecutive blocks for each input block for local attentiont.
  72. For more information, see: https://huggingface.co/papers/2112.07916.
  73. """
  74. num_blocks = x.shape[block_dim]
  75. pad = [(0, 0)] * x.ndim
  76. pad[block_dim] = (1, 1)
  77. pad = sum(pad[::-1], ())
  78. # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
  79. x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
  80. blocks_list: list[torch.Tensor] = []
  81. for i in range(3):
  82. # We use indexing approach here:
  83. # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
  84. indices = [slice(0, None)] * x.ndim
  85. indices[block_dim] = slice(i, i + num_blocks)
  86. indices = tuple(indices)
  87. blocks_list.append(x[indices])
  88. # [batch_size, num_blocks, 3 * block_len, ...]
  89. return torch.cat(blocks_list, dim=sequence_dim)
  90. def _make_3block_relative_position_ids(block_len: int) -> torch.Tensor:
  91. """Makes 3-blocked relative position ids for local attention."""
  92. position_ids = torch.arange(3 * block_len, dtype=torch.int32)
  93. center_position_ids = position_ids[block_len:-block_len]
  94. # [block_len, 3 * block_len]
  95. relative_position_ids = position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
  96. return relative_position_ids
  97. def _mask_local_attention_mask(local_attention_mask: torch.Tensor, block_len: int) -> torch.Tensor:
  98. """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius."""
  99. relative_position_ids = _make_3block_relative_position_ids(block_len)
  100. locality_mask = torch.abs(relative_position_ids) < block_len
  101. locality_mask = locality_mask[None, None, :, :]
  102. locality_mask = locality_mask.to(local_attention_mask.device)
  103. return torch.logical_and(local_attention_mask, locality_mask)
  104. def _get_local_attention_mask(attention_mask: torch.Tensor, block_len: int, device: torch.device) -> torch.Tensor:
  105. """Prepare attention mask to be applied for a local attention."""
  106. # [batch_size, num_blocks, block_len]
  107. _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, dim=1)
  108. # [batch_size, num_block, 3 * block_len]
  109. _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_dim=1, sequence_dim=2)
  110. _blocked_attention_mask = _blocked_attention_mask.unsqueeze(-1)
  111. _3blocked_attention_mask = _3blocked_attention_mask.unsqueeze(-2)
  112. # [batch_size, num_block, block_len, 3 * block_len]
  113. local_attention_mask = torch.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
  114. local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
  115. # [batch_size, 1, num_block, block_len, 3 * block_len]
  116. return local_attention_mask.unsqueeze(1).to(device)
  117. def _make_global_fixed_block_ids(
  118. attention_mask: torch.Tensor, global_block_size: int
  119. ) -> tuple[torch.Tensor, torch.Tensor]:
  120. """Obtain the "fixed block" global id corresponding to each input token.
  121. This implementation is a simplified version of the original Flaxformr implementation adopted from:
  122. https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.
  123. In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for
  124. the whole fixed block, are assigned to the preceding block.
  125. Padding tokens from the original sequence are represented by -1.
  126. """
  127. batch_size, seq_len = attention_mask.shape[:2]
  128. def handle_orphan_tokens(block_ids: torch.Tensor) -> torch.Tensor:
  129. block_ends = (torch.arange(seq_len) % global_block_size) == global_block_size - 1
  130. block_ends = block_ends.to(block_ids.device)
  131. true_block_ends = torch.logical_and(block_ends, block_ids >= 0)
  132. full_blocks = true_block_ends.sum(-1).unsqueeze(-1).type(block_ids.dtype) - 1
  133. block_ids = torch.where(block_ids < full_blocks, block_ids, full_blocks)
  134. return block_ids
  135. fixed_block_mask = torch.ones_like(attention_mask, device=attention_mask.device) / global_block_size
  136. fixed_block_mask = torch.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
  137. mask = torch.where(attention_mask != 0.0, 1.0, -1000.0).type(attention_mask.dtype)
  138. global_block_ids = torch.floor(mask + fixed_block_mask - 1.0).type(attention_mask.dtype)
  139. _global_block_ids_lower_bound = torch.tensor(-1, dtype=global_block_ids.dtype, device=global_block_ids.device)
  140. global_block_ids = torch.where(
  141. global_block_ids > _global_block_ids_lower_bound, global_block_ids, _global_block_ids_lower_bound
  142. )
  143. # set padding tokens to -1
  144. global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
  145. # [batch_size, seq_len]
  146. global_block_ids = handle_orphan_tokens(global_block_ids)
  147. num_globals = seq_len // global_block_size
  148. # [batch_size, seq_len // global_block_size]
  149. if num_globals > 0:
  150. _sequence_block_ids_max = torch.max(global_block_ids, dim=-1).values.repeat(num_globals, 1).transpose(0, 1)
  151. else:
  152. _sequence_block_ids_max = torch.zeros(
  153. batch_size, 0, dtype=global_block_ids.dtype, device=global_block_ids.device
  154. )
  155. global_segment_ids = torch.cumsum(torch.ones(batch_size, num_globals), dim=-1) - 1
  156. global_segment_ids = global_segment_ids.to(attention_mask.device)
  157. global_segment_ids = torch.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
  158. return global_block_ids.type(torch.int), global_segment_ids.type(torch.int)
  159. def _make_side_relative_position_ids(attention_mask: torch.Tensor, global_block_size: int) -> torch.Tensor:
  160. """Create the relative position tensor for local -> global attention."""
  161. block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
  162. global_seq_len = global_segment_ids.shape[-1]
  163. global_positions = torch.arange(global_seq_len, device=block_ids.device)
  164. side_relative_position = global_positions - block_ids[..., None]
  165. return side_relative_position.type(torch.int64)
  166. def _create_global_aggregates(
  167. hidden_states: torch.Tensor, block_ids: torch.Tensor, global_seq_len: int
  168. ) -> torch.Tensor:
  169. """Compute individual block aggregates by summing over individual blocks."""
  170. # (batch..., seq_len, global_seq_len))
  171. block_ids = block_ids.where(
  172. block_ids >= 0, torch.tensor(global_seq_len, dtype=block_ids.dtype, device=block_ids.device)
  173. )
  174. one_hot_block_ids = nn.functional.one_hot(block_ids.type(torch.int64), global_seq_len + 1)[:, :, :-1]
  175. return torch.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids.type(hidden_states.dtype))
  176. # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LongT5
  177. class LongT5LayerNorm(nn.Module):
  178. def __init__(self, hidden_size, eps=1e-6):
  179. """
  180. Construct a layernorm module in the LongT5 style. No bias and no subtraction of mean.
  181. """
  182. super().__init__()
  183. self.weight = nn.Parameter(torch.ones(hidden_size))
  184. self.variance_epsilon = eps
  185. def forward(self, hidden_states):
  186. # LongT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  187. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  188. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  189. # half-precision inputs is done in fp32
  190. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  191. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  192. # convert into half-precision if necessary
  193. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  194. hidden_states = hidden_states.to(self.weight.dtype)
  195. return self.weight * hidden_states
  196. try:
  197. from apex.normalization import FusedRMSNorm
  198. LongT5LayerNorm = FusedRMSNorm
  199. logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of LongT5LayerNorm")
  200. except ImportError:
  201. # using the normal LongT5LayerNorm
  202. pass
  203. except Exception:
  204. logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
  205. # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
  206. class LongT5DenseActDense(nn.Module):
  207. def __init__(self, config: LongT5Config):
  208. super().__init__()
  209. self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
  210. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  211. self.dropout = nn.Dropout(config.dropout_rate)
  212. self.act = ACT2FN[config.dense_act_fn]
  213. def forward(self, hidden_states):
  214. hidden_states = self.wi(hidden_states)
  215. hidden_states = self.act(hidden_states)
  216. hidden_states = self.dropout(hidden_states)
  217. if (
  218. isinstance(self.wo.weight, torch.Tensor)
  219. and hidden_states.dtype != self.wo.weight.dtype
  220. and self.wo.weight.dtype != torch.int8
  221. ):
  222. hidden_states = hidden_states.to(self.wo.weight.dtype)
  223. hidden_states = self.wo(hidden_states)
  224. return hidden_states
  225. class LongT5DenseGatedActDense(nn.Module):
  226. def __init__(self, config: LongT5Config):
  227. super().__init__()
  228. self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
  229. self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
  230. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  231. self.dropout = nn.Dropout(config.dropout_rate)
  232. self.act = ACT2FN[config.dense_act_fn]
  233. def forward(self, hidden_states):
  234. hidden_gelu = self.act(self.wi_0(hidden_states))
  235. hidden_linear = self.wi_1(hidden_states)
  236. hidden_states = hidden_gelu * hidden_linear
  237. hidden_states = self.dropout(hidden_states)
  238. hidden_states = self.wo(hidden_states)
  239. return hidden_states
  240. # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->LongT5
  241. class LongT5LayerFF(nn.Module):
  242. def __init__(self, config: LongT5Config):
  243. super().__init__()
  244. if config.is_gated_act:
  245. self.DenseReluDense = LongT5DenseGatedActDense(config)
  246. else:
  247. self.DenseReluDense = LongT5DenseActDense(config)
  248. self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  249. self.dropout = nn.Dropout(config.dropout_rate)
  250. def forward(self, hidden_states):
  251. forwarded_states = self.layer_norm(hidden_states)
  252. forwarded_states = self.DenseReluDense(forwarded_states)
  253. hidden_states = hidden_states + self.dropout(forwarded_states)
  254. return hidden_states
  255. # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5
  256. class LongT5Attention(nn.Module):
  257. def __init__(
  258. self,
  259. config: LongT5Config,
  260. has_relative_attention_bias=False,
  261. layer_idx: int | None = None,
  262. ):
  263. super().__init__()
  264. self.is_decoder = config.is_decoder
  265. self.has_relative_attention_bias = has_relative_attention_bias
  266. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  267. self.relative_attention_max_distance = config.relative_attention_max_distance
  268. self.d_model = config.d_model
  269. self.key_value_proj_dim = config.d_kv
  270. self.n_heads = config.num_heads
  271. self.dropout = config.dropout_rate
  272. self.inner_dim = self.n_heads * self.key_value_proj_dim
  273. self.layer_idx = layer_idx
  274. if layer_idx is None and self.is_decoder:
  275. logger.warning_once(
  276. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  277. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  278. "when creating this class."
  279. )
  280. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  281. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  282. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  283. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  284. if self.has_relative_attention_bias:
  285. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  286. self.gradient_checkpointing = False
  287. @staticmethod
  288. def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  289. """
  290. Adapted from Mesh Tensorflow:
  291. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  292. Translate relative position to a bucket number for relative attention. The relative position is defined as
  293. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  294. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  295. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  296. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  297. This should allow for more graceful generalization to longer sequences than the model has been trained on
  298. Args:
  299. relative_position: an int32 Tensor
  300. bidirectional: a boolean - whether the attention is bidirectional
  301. num_buckets: an integer
  302. max_distance: an integer
  303. Returns:
  304. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  305. """
  306. relative_buckets = 0
  307. if bidirectional:
  308. num_buckets //= 2
  309. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  310. relative_position = torch.abs(relative_position)
  311. else:
  312. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  313. # now relative_position is in the range [0, inf)
  314. # half of the buckets are for exact increments in positions
  315. max_exact = num_buckets // 2
  316. is_small = relative_position < max_exact
  317. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  318. relative_position_if_large = max_exact + (
  319. torch.log(relative_position.float() / max_exact)
  320. / math.log(max_distance / max_exact)
  321. * (num_buckets - max_exact)
  322. ).to(torch.long)
  323. relative_position_if_large = torch.min(
  324. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  325. )
  326. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  327. return relative_buckets
  328. def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
  329. """Compute binned relative position bias"""
  330. if device is None:
  331. device = self.relative_attention_bias.weight.device
  332. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
  333. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  334. relative_position = memory_position - context_position # shape (query_length, key_length)
  335. relative_position_bucket = self._relative_position_bucket(
  336. relative_position, # shape (query_length, key_length)
  337. bidirectional=(not self.is_decoder),
  338. num_buckets=self.relative_attention_num_buckets,
  339. max_distance=self.relative_attention_max_distance,
  340. )
  341. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  342. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  343. return values
  344. def forward(
  345. self,
  346. hidden_states,
  347. mask=None,
  348. key_value_states=None,
  349. position_bias=None,
  350. past_key_values=None,
  351. output_attentions=False,
  352. **kwargs,
  353. ):
  354. """
  355. Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
  356. """
  357. # Input is (batch_size, seq_length, dim)
  358. # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
  359. input_shape = hidden_states.shape[:-1]
  360. hidden_shape = (*input_shape, -1, self.key_value_proj_dim)
  361. past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
  362. # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
  363. past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens
  364. # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
  365. is_cross_attention = key_value_states is not None
  366. query_states = self.q(hidden_states).view(hidden_shape).transpose(1, 2)
  367. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  368. is_updated = False
  369. if isinstance(past_key_values, EncoderDecoderCache):
  370. is_updated = past_key_values.is_updated.get(self.layer_idx)
  371. if is_cross_attention:
  372. # after the first generated id, we can subsequently re-use all key/value_states from cache
  373. curr_past_key_values = past_key_values.cross_attention_cache
  374. else:
  375. curr_past_key_values = past_key_values.self_attention_cache
  376. else:
  377. curr_past_key_values = past_key_values
  378. current_states = key_value_states if is_cross_attention else hidden_states
  379. if is_cross_attention and past_key_values is not None and is_updated:
  380. # reuse k,v, cross_attentions
  381. key_states = curr_past_key_values.layers[self.layer_idx].keys
  382. value_states = curr_past_key_values.layers[self.layer_idx].values
  383. else:
  384. kv_shape = (*current_states.shape[:-1], -1, self.key_value_proj_dim)
  385. key_states = self.k(current_states).view(kv_shape).transpose(1, 2)
  386. value_states = self.v(current_states).view(kv_shape).transpose(1, 2)
  387. if past_key_values is not None:
  388. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  389. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  390. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  391. past_key_values.is_updated[self.layer_idx] = True
  392. # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  393. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  394. if position_bias is None:
  395. key_length = key_states.shape[-2]
  396. if not self.has_relative_attention_bias:
  397. position_bias = torch.zeros(
  398. (1, query_states.shape[1], input_shape[1], key_length), device=scores.device, dtype=scores.dtype
  399. )
  400. if self.gradient_checkpointing and self.training:
  401. position_bias.requires_grad = True
  402. else:
  403. position_bias = self.compute_bias(
  404. input_shape[1], key_length, device=scores.device, past_seen_tokens=past_seen_tokens
  405. )
  406. if mask is not None:
  407. causal_mask = mask[:, :, :, : key_states.shape[-2]]
  408. position_bias = position_bias + causal_mask
  409. position_bias_masked = position_bias
  410. scores += position_bias_masked
  411. # (batch_size, n_heads, seq_length, key_length)
  412. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  413. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  414. attn_output = torch.matmul(attn_weights, value_states)
  415. attn_output = attn_output.transpose(1, 2).contiguous()
  416. attn_output = attn_output.reshape(*input_shape, -1)
  417. attn_output = self.o(attn_output)
  418. outputs = (attn_output, position_bias)
  419. if output_attentions:
  420. outputs = outputs + (attn_weights,)
  421. return outputs
  422. class LongT5LocalAttention(nn.Module):
  423. def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
  424. super().__init__()
  425. self.is_decoder = config.is_decoder
  426. self.has_relative_attention_bias = has_relative_attention_bias
  427. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  428. self.relative_attention_max_distance = config.relative_attention_max_distance
  429. self.d_model = config.d_model
  430. self.key_value_proj_dim = config.d_kv
  431. self.n_heads = config.num_heads
  432. self.local_radius = config.local_radius
  433. self.block_len = self.local_radius + 1
  434. self.dropout = config.dropout_rate
  435. self.inner_dim = self.n_heads * self.key_value_proj_dim
  436. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  437. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  438. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  439. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  440. if self.has_relative_attention_bias:
  441. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  442. self.gradient_checkpointing = False
  443. @staticmethod
  444. # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
  445. def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  446. """
  447. Adapted from Mesh Tensorflow:
  448. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  449. Translate relative position to a bucket number for relative attention. The relative position is defined as
  450. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  451. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  452. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  453. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  454. This should allow for more graceful generalization to longer sequences than the model has been trained on
  455. Args:
  456. relative_position: an int32 Tensor
  457. bidirectional: a boolean - whether the attention is bidirectional
  458. num_buckets: an integer
  459. max_distance: an integer
  460. Returns:
  461. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  462. """
  463. relative_buckets = 0
  464. if bidirectional:
  465. num_buckets //= 2
  466. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  467. relative_position = torch.abs(relative_position)
  468. else:
  469. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  470. # now relative_position is in the range [0, inf)
  471. # half of the buckets are for exact increments in positions
  472. max_exact = num_buckets // 2
  473. is_small = relative_position < max_exact
  474. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  475. relative_position_if_large = max_exact + (
  476. torch.log(relative_position.float() / max_exact)
  477. / math.log(max_distance / max_exact)
  478. * (num_buckets - max_exact)
  479. ).to(torch.long)
  480. relative_position_if_large = torch.min(
  481. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  482. )
  483. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  484. return relative_buckets
  def compute_bias(self, block_length: int):
      """Compute the binned relative position bias of one query block against its 3-block key window.

      Args:
          block_length: number of positions in a block.

      Returns:
          Tensor of shape (1, 1, num_heads, block_length, 3 * block_length) looked up from
          `self.relative_attention_bias`.
      """
      bias_table = self.relative_attention_bias
      table_device = bias_table.weight.device
      # Weights on the "meta" device carry no data, so fall back to the default device for the indices.
      target_device = None if table_device.type == "meta" else table_device

      key_positions = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
      # Queries are the positions of the middle block of the window.
      query_positions = key_positions[block_length:-block_length]

      # (block_length, 3 * block_length)
      offsets = key_positions.unsqueeze(0) - query_positions.unsqueeze(1)
      bucket_ids = self._relative_position_bucket(
          offsets,
          bidirectional=(not self.is_decoder),
          num_buckets=self.relative_attention_num_buckets,
          max_distance=self.relative_attention_max_distance,
      )

      # (block_length, 3 * block_length, num_heads)
      per_head_bias = bias_table(bucket_ids)
      # (1, 1, num_heads, block_length, 3 * block_length)
      return per_head_bias.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
  def forward(
      self,
      hidden_states,
      mask=None,
      position_bias=None,
      output_attentions=False,
  ):
      """Run block-local self-attention over `hidden_states`.

      Args:
          hidden_states: input activations; the first two dimensions are read as (batch_size, seq_length).
          mask: optional attention mask. Entries > 0 are kept, all others receive a bias of -1e10. It is
              only consulted when `position_bias` is None (presumably already in blocked layout - confirm
              against the caller).
          position_bias: optional precomputed bias that is added to the scores as is.
          output_attentions: when true, the attention weights are appended to the returned tuple.

      Returns:
          `(attn_output, position_bias)` or `(attn_output, position_bias, attn_weights)`.
      """
      batch_size, seq_length = hidden_states.shape[:2]
      per_head_layout = (batch_size, -1, self.n_heads, self.key_value_proj_dim)

      # Project and expose the heads -> (batch_size, seq_length, n_heads, dim_per_head)
      queries = self.q(hidden_states).view(per_head_layout)
      keys = self.k(hidden_states).view(per_head_layout)
      values = self.v(hidden_states).view(per_head_layout)

      # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
      queries = _split_into_blocks(queries, self.block_len, dim=1)
      keys = _split_into_blocks(keys, self.block_len, dim=1)
      values = _split_into_blocks(values, self.block_len, dim=1)

      # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
      keys = _concatenate_3_blocks(keys, block_dim=1, sequence_dim=2)
      values = _concatenate_3_blocks(values, block_dim=1, sequence_dim=2)

      # (batch_size, num_block, n_heads, block_len, 3 * block_len)
      scores = torch.einsum("...qhd,...khd->...hqk", queries, keys)

      if position_bias is None:
          # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
          if self.has_relative_attention_bias:
              position_bias = self.compute_bias(self.block_len)
          else:
              position_bias = torch.zeros(
                  (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype
              )
              if self.gradient_checkpointing and self.training:
                  position_bias.requires_grad = True

          if mask is not None:
              # Replace masked positions with -1e10 (according to the original implementation)
              mask = torch.where(mask > 0, 0.0, -1e10)
              # Transpose so that the mask can be summed with the position bias
              position_bias = position_bias + mask.transpose(1, 2)

      # In-place on purpose: `scores` keeps its dtype whatever the dtype of the bias is.
      scores += position_bias

      # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
      attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
      attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
      attn_weights = attn_weights.type(values.dtype)

      # Weighted sum of the values, then merge blocks and heads -> (batch_size, -1, inner_dim)
      context = torch.einsum("...hqk,...khd->...qhd", attn_weights, values)
      context = context.contiguous().view(batch_size, -1, self.inner_dim)
      # Keep only the first `seq_length` positions of the merged sequence axis.
      attn_output = self.o(context[:, :seq_length, :])

      if output_attentions:
          return (attn_output, position_bias, attn_weights)
      return (attn_output, position_bias)
  class LongT5TransientGlobalAttention(nn.Module):
      """Block-local self-attention in which every query additionally attends to global aggregate tokens.

      The global ("side") keys and values are computed in `forward` from aggregates of the input created
      by `_create_global_aggregates`, and are concatenated to the local 3-block key/value window.
      """

      def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
          """Create the projections and store the hyper-parameters of the attention layer.

          Args:
              config: model configuration. The fields `is_decoder`, `relative_attention_num_buckets`,
                  `relative_attention_max_distance`, `d_model`, `d_kv`, `num_heads`, `local_radius`,
                  `global_block_size`, `dropout_rate` and `layer_norm_epsilon` are read from it.
              has_relative_attention_bias: when true, the learned bias tables `relative_attention_bias`
                  (local) and `global_relative_attention_bias` (side) are created.
          """
          super().__init__()
          self.is_decoder = config.is_decoder
          self.has_relative_attention_bias = has_relative_attention_bias
          self.relative_attention_num_buckets = config.relative_attention_num_buckets
          self.relative_attention_max_distance = config.relative_attention_max_distance
          self.d_model = config.d_model
          self.key_value_proj_dim = config.d_kv
          self.n_heads = config.num_heads
          self.local_radius = config.local_radius
          self.block_len = self.local_radius + 1
          self.global_block_size = config.global_block_size
          self.dropout = config.dropout_rate
          self.inner_dim = self.n_heads * self.key_value_proj_dim

          self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

          if self.has_relative_attention_bias:
              self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
          # `forward` reads this flag when it builds a zero position bias. Without a default the lookup
          # raises AttributeError, because `nn.Module` does not define it.
          self.gradient_checkpointing = False

          # Relative attention bias & Layer norm for global attention
          if self.has_relative_attention_bias:
              self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
          self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

      @staticmethod
      # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
      def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
          """
          Adapted from Mesh Tensorflow:
          https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

          Translate relative position to a bucket number for relative attention. The relative position is defined as
          memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
          position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
          small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
          positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
          This should allow for more graceful generalization to longer sequences than the model has been trained on

          Args:
              relative_position: an int32 Tensor
              bidirectional: a boolean - whether the attention is bidirectional
              num_buckets: an integer
              max_distance: an integer

          Returns:
              a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
          """
          relative_buckets = 0
          if bidirectional:
              num_buckets //= 2
              relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
              relative_position = torch.abs(relative_position)
          else:
              relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
          # now relative_position is in the range [0, inf)

          # half of the buckets are for exact increments in positions
          max_exact = num_buckets // 2
          is_small = relative_position < max_exact

          # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
          relative_position_if_large = max_exact + (
              torch.log(relative_position.float() / max_exact)
              / math.log(max_distance / max_exact)
              * (num_buckets - max_exact)
          ).to(torch.long)
          relative_position_if_large = torch.min(
              relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
          )

          relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
          return relative_buckets

      def compute_bias(self, block_length: int):
          """Compute binned relative position bias

          Args:
              block_length: number of positions in a block.

          Returns:
              Tensor of shape (1, 1, num_heads, block_length, 3 * block_length) looked up from
              `self.relative_attention_bias`.
          """
          # Weights on the "meta" device carry no data, so the indices go to the default device instead.
          target_device = (
              self.relative_attention_bias.weight.device
              if self.relative_attention_bias.weight.device.type != "meta"
              else None
          )
          memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
          context_position = memory_position[block_length:-block_length]

          # (block_length, 3 * block_length)
          relative_position = memory_position[None, :] - context_position[:, None]
          relative_position_bucket = self._relative_position_bucket(
              relative_position,  # (block_length, 3 * block_length)
              bidirectional=(not self.is_decoder),
              num_buckets=self.relative_attention_num_buckets,
              max_distance=self.relative_attention_max_distance,
          )
          # (block_length, 3 * block_length, num_heads)
          values = self.relative_attention_bias(relative_position_bucket)
          # (1, 1, num_heads, block_length, 3 * block_length)
          values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
          return values

      def compute_side_bias(self, mask: torch.Tensor, global_segment_ids: torch.Tensor) -> torch.Tensor:
          """Compute the bias of every input position towards every global (side) token.

          The result is the sum of a mask term (0.0 where `mask` equals the global segment id, -1e10
          elsewhere) and a learned term looked up from `self.global_relative_attention_bias`.

          NOTE: the learned table only exists when the layer was built with
          `has_relative_attention_bias=True`.

          Args:
              mask: attention mask, (batch_size, seq_len).
              global_segment_ids: segment ids of the global tokens, (batch_size, global_seq_len).

          Returns:
              Tensor of shape (batch_size, num_heads, seq_len, global_seq_len).
          """
          # (batch_size, 1, seq_len, global_seq_len)
          side_attention_mask = torch.eq(mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
          attention_side_bias = torch.where(side_attention_mask > 0, 0.0, -1e10)
          # (batch_size, seq_len, global_seq_len)
          side_relative_position = _make_side_relative_position_ids(mask, self.global_block_size)
          side_relative_position_bucket = self._relative_position_bucket(
              side_relative_position,
              bidirectional=(not self.is_decoder),
              num_buckets=self.relative_attention_num_buckets,
              max_distance=self.relative_attention_max_distance,
          )
          # (batch_size, seq_len, global_seq_len, num_heads)
          side_bias = self.global_relative_attention_bias(side_relative_position_bucket)

          # (batch_size, num_heads, seq_len, global_seq_len)
          side_bias = side_bias.permute([0, 3, 1, 2])
          # (batch_size, num_heads, seq_len, global_seq_len)
          attention_side_bias = attention_side_bias + side_bias
          return attention_side_bias

      def forward(
          self,
          hidden_states,
          mask=None,
          position_bias=None,
          output_attentions=False,
      ):
          """Run transient-global self-attention over `hidden_states`.

          Args:
              hidden_states: input activations; the first two dimensions are read as (batch_size, seq_length).
              mask: optional attention mask of shape (batch_size, seq_length). When omitted, an all-ones
                  mask on the device of `hidden_states` is used.
              position_bias: optional precomputed bias (local and side part already concatenated) that is
                  added to the scores as is.
              output_attentions: when true, the attention weights are appended to the returned tuple.

          Returns:
              `(attn_output, position_bias)` or `(attn_output, position_bias, attn_weights)`.
          """
          batch_size, seq_length = hidden_states.shape[:2]

          def shape(states):
              """projection"""
              return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)

          def unshape(states):
              """reshape"""
              return states.contiguous().view(batch_size, -1, self.inner_dim)

          # Prepare components for transient-global attention
          # Obtain block_ids and global_segment_ids
          # global_seq_len := seq_len // self.global_block_size
          # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
          # The fallback mask is created on the device of the activations so that the ids derived from
          # it can be combined with `hidden_states` and the module weights.
          block_ids, global_segment_ids = _make_global_fixed_block_ids(
              mask if mask is not None else torch.ones(hidden_states.shape[:-1], device=hidden_states.device),
              self.global_block_size,
          )
          # Create global inputs
          _global_seq_len = global_segment_ids.shape[-1]
          global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
          global_inputs = self.global_input_layer_norm(global_inputs)

          # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
          query_states = shape(self.q(hidden_states))
          key_states = shape(self.k(hidden_states))
          value_states = shape(self.v(hidden_states))
          # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head)
          side_key_states = shape(self.k(global_inputs))
          side_value_states = shape(self.v(global_inputs))

          # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
          query_states = _split_into_blocks(query_states, self.block_len, dim=1)
          key_states = _split_into_blocks(key_states, self.block_len, dim=1)
          value_states = _split_into_blocks(value_states, self.block_len, dim=1)

          # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
          key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
          value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)

          # Tile side inputs across local key/value blocks
          # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
          reps = [1] * (side_key_states.ndim + 1)
          reps[1] = key_states.shape[1]
          side_key_states = side_key_states.unsqueeze(1).repeat(reps)
          side_value_states = side_value_states.unsqueeze(1).repeat(reps)

          # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
          # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
          key_states = torch.cat([key_states, side_key_states], dim=2)
          value_states = torch.cat([value_states, side_value_states], dim=2)

          # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len + global_seq_len)
          scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)

          if mask is not None:
              # We need to adjust position bias shape to be sum with mask
              local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device)
              # Replace masked positions with -1e10 (according to the original implementation)
              local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10)
          else:
              local_attention_mask = None

          if position_bias is None:
              # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
              if not self.has_relative_attention_bias:
                  position_bias = torch.zeros(
                      (1, 1, self.n_heads, self.block_len, 3 * self.block_len),
                      device=scores.device,
                      dtype=scores.dtype,
                  )
                  if self.gradient_checkpointing and self.training:
                      position_bias.requires_grad = True
              else:
                  position_bias = self.compute_bias(self.block_len)

              if local_attention_mask is not None:
                  # (batch_size, 1, n_heads, block_len, 3 * block_len)
                  position_bias = position_bias + local_attention_mask.transpose(1, 2)
              position_bias = position_bias.type(scores.dtype)

              # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len)
              if mask is None:
                  # Same device as the activations, see the fallback mask above.
                  mask = torch.ones(batch_size, seq_length, device=hidden_states.device)
              # (batch_size, num_heads, seq_len, global_seq_len)
              side_position_bias = self.compute_side_bias(mask, global_segment_ids)
              # (batch_size, num_blocks, num_heads, block_len, global_seq_len)
              side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2)
              side_position_bias = side_position_bias.type(scores.dtype).to(scores.device)
              # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len)
              position_bias = torch.cat([position_bias, side_position_bias], dim=-1)

          scores += position_bias
          # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
          attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
          attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
          attn_weights = attn_weights.type(value_states.dtype)
          attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
          # Keep only the first `seq_length` positions of the merged sequence axis.
          attn_output = attn_output[:, :seq_length, :]
          attn_output = self.o(attn_output)

          outputs = (attn_output, position_bias)

          if output_attentions:
              outputs = outputs + (attn_weights,)
          return outputs
  # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
  class LongT5LayerSelfAttention(nn.Module):
      """Pre-norm self-attention sub-layer with dropout and a residual connection."""

      def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
          super().__init__()
          self.SelfAttention = LongT5Attention(
              config,
              has_relative_attention_bias=has_relative_attention_bias,
              layer_idx=layer_idx,
          )
          self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          position_bias=None,
          past_key_values=None,
          use_cache=False,
          output_attentions=False,
          **kwargs,
      ):
          """Normalize the input, attend, and add the (dropped-out) result back onto the input.

          Returns:
              Tuple whose first element is the updated hidden states, followed by every output of the
              attention module after its first one. Extra `kwargs` are accepted and ignored.
          """
          attn_result = self.SelfAttention(
              self.layer_norm(hidden_states),
              mask=attention_mask,
              position_bias=position_bias,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
          )
          residual_sum = hidden_states + self.dropout(attn_result[0])
          # add attentions if we output them
          return (residual_sum,) + attn_result[1:]
  class LongT5LayerLocalSelfAttention(nn.Module):
      """Local self attention used in encoder"""

      def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
          # `layer_idx` is accepted for signature parity with the other layer classes; it is not used here.
          super().__init__()
          self.LocalSelfAttention = LongT5LocalAttention(
              config,
              has_relative_attention_bias=has_relative_attention_bias,
          )
          self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          position_bias=None,
          output_attentions=False,
          **kwargs: Any,  # to accept past_key_values and use_cache kwargs
      ):
          """Normalize the input, attend locally, and add the (dropped-out) result back onto the input.

          Returns:
              Tuple whose first element is the updated hidden states, followed by every output of the
              attention module after its first one.
          """
          attn_result = self.LocalSelfAttention(
              self.layer_norm(hidden_states),
              mask=attention_mask,
              position_bias=position_bias,
              output_attentions=output_attentions,
          )
          residual_sum = hidden_states + self.dropout(attn_result[0])
          # add attentions if we output them
          return (residual_sum,) + attn_result[1:]
  class LongT5LayerTransientGlobalSelfAttention(nn.Module):
      """Transient-Global self attention used in encoder"""

      def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
          # `layer_idx` is accepted for signature parity with the other layer classes; it is not used here.
          super().__init__()
          self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention(
              config,
              has_relative_attention_bias=has_relative_attention_bias,
          )
          self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          position_bias=None,
          output_attentions=False,
          **kwargs: Any,  # to accept past_key_values and use_cache kwargs
      ):
          """Normalize the input, attend, and add the (dropped-out) result back onto the input.

          Returns:
              Tuple whose first element is the updated hidden states, followed by every output of the
              attention module after its first one.
          """
          attn_result = self.TransientGlobalSelfAttention(
              self.layer_norm(hidden_states),
              mask=attention_mask,
              position_bias=position_bias,
              output_attentions=output_attentions,
          )
          residual_sum = hidden_states + self.dropout(attn_result[0])
          # add attentions if we output them
          return (residual_sum,) + attn_result[1:]
  # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5
  class LongT5LayerCrossAttention(nn.Module):
      """Pre-norm cross-attention sub-layer with dropout and a residual connection."""

      def __init__(self, config, layer_idx: int | None = None):
          super().__init__()
          # Cross-attention never owns a relative attention bias table.
          self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
          self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(
          self,
          hidden_states,
          key_value_states,
          attention_mask=None,
          position_bias=None,
          past_key_values=None,
          output_attentions=False,
          **kwargs,
      ):
          """Normalize the input, attend to `key_value_states`, and add the result back onto the input.

          Returns:
              Tuple whose first element is the updated hidden states, followed by every output of the
              attention module after its first one. Extra `kwargs` are accepted and ignored.
          """
          attn_result = self.EncDecAttention(
              self.layer_norm(hidden_states),
              mask=attention_mask,
              key_value_states=key_value_states,
              position_bias=position_bias,
              past_key_values=past_key_values,
              output_attentions=output_attentions,
          )
          residual_sum = hidden_states + self.dropout(attn_result[0])
          # add attentions if we output them
          return (residual_sum,) + attn_result[1:]
  class LongT5Block(GradientCheckpointingLayer):
      """One block of the stack: self-attention, cross-attention (decoder only) and feed-forward."""

      def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
          """Assemble the sub-layers of the block.

          Raises:
              ValueError: for an encoder config whose `encoder_attention_type` is neither `local` nor
                  `transient-global`.
          """
          super().__init__()
          self.is_decoder = config.is_decoder

          # The decoder uses the regular self-attention layer; the encoder one of the two sparse variants.
          if config.is_decoder:
              attention_layer = LongT5LayerSelfAttention
          else:
              attention_type = config.encoder_attention_type
              if attention_type == "local":
                  attention_layer = LongT5LayerLocalSelfAttention
              elif attention_type == "transient-global":
                  attention_layer = LongT5LayerTransientGlobalSelfAttention
              else:
                  raise ValueError(
                      "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
                      f"but got {attention_type}."
                  )

          sublayers = [
              attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
          ]
          if self.is_decoder:
              sublayers.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx))
          sublayers.append(LongT5LayerFF(config))
          self.layer = nn.ModuleList(sublayers)

      @staticmethod
      def _clamp_fp16_overflow(hidden_states):
          """Clamp float16 activations that contain infinities to a finite range; return other inputs as is.

          clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
          """
          if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
              clamp_value = torch.finfo(hidden_states.dtype).max - 1000
              hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
          return hidden_states

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          position_bias=None,
          encoder_hidden_states=None,
          encoder_attention_mask=None,
          encoder_decoder_position_bias=None,
          past_key_values=None,
          use_cache=False,
          output_attentions=False,
          return_dict=True,
          **kwargs,
      ):
          """Apply the sub-layers of the block in order.

          Cross-attention runs only for a decoder block that is given `encoder_hidden_states`.
          `return_dict` and extra `kwargs` are accepted but not used in this method.

          Returns:
              hidden-states, (self-attention position bias), (self-attention weights),
              (cross-attention position bias), (cross-attention weights)
          """
          self_attn_result = self.layer[0](
              hidden_states,
              attention_mask=attention_mask,
              position_bias=position_bias,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
          )
          hidden_states = self._clamp_fp16_overflow(self_attn_result[0])
          # Keep self-attention outputs and relative position weights
          extra_outputs = self_attn_result[1:]

          if self.is_decoder and encoder_hidden_states is not None:
              cross_attn_result = self.layer[1](
                  hidden_states,
                  key_value_states=encoder_hidden_states,
                  attention_mask=encoder_attention_mask,
                  position_bias=encoder_decoder_position_bias,
                  past_key_values=past_key_values,
                  output_attentions=output_attentions,
              )
              hidden_states = self._clamp_fp16_overflow(cross_attn_result[0])
              # Keep cross-attention outputs and relative position weights
              extra_outputs = extra_outputs + cross_attn_result[1:]

          # Apply Feed Forward layer
          hidden_states = self._clamp_fp16_overflow(self.layer[-1](hidden_states))

          return (hidden_states,) + extra_outputs
  @auto_docstring
  class LongT5PreTrainedModel(PreTrainedModel):
      # Shared base of the LongT5 models: dummy inputs, weight initialization and decoder-input shifting.
      config: LongT5Config
      base_model_prefix = "transformer"
      supports_gradient_checkpointing = True
      _no_split_modules = ["LongT5Block"]
      _can_compile_fullgraph = False  # TODO: @raushan more involved due to local/global attn

      @property
      # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
      def dummy_inputs(self):
          """Small fixed inputs built from `DUMMY_INPUTS` and `DUMMY_MASK`."""
          token_ids = torch.tensor(DUMMY_INPUTS)
          token_mask = torch.tensor(DUMMY_MASK)
          return {
              "decoder_input_ids": token_ids,
              "input_ids": token_ids,
              "decoder_attention_mask": token_mask,
          }

      @torch.no_grad()
      def _init_weights(self, module):
          """Initialize the weights"""
          factor = self.config.initializer_factor  # Used for testing weights initialization

          def init_projection(linear, std):
              # Normal-initialize the weight and zero the bias when the layer has one.
              init.normal_(linear.weight, mean=0.0, std=std)
              if hasattr(linear, "bias") and linear.bias is not None:
                  init.zeros_(linear.bias)

          if isinstance(module, LongT5LayerNorm):
              init.constant_(module.weight, factor * 1.0)
              return

          if isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)):
              init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
              # An untied LM head gets its own initialization.
              if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                  init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0)
              return

          if isinstance(module, LongT5DenseActDense):
              init_projection(module.wi, factor * ((self.config.d_model) ** -0.5))
              init_projection(module.wo, factor * ((self.config.d_ff) ** -0.5))
              return

          if isinstance(module, LongT5DenseGatedActDense):
              init_projection(module.wi_0, factor * ((self.config.d_model) ** -0.5))
              init_projection(module.wi_1, factor * ((self.config.d_model) ** -0.5))
              init_projection(module.wo, factor * ((self.config.d_ff) ** -0.5))
              return

          if isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)):
              d_model = self.config.d_model
              key_value_proj_dim = self.config.d_kv
              n_heads = self.config.num_heads
              projection_stds = (
                  (module.q, factor * ((d_model * key_value_proj_dim) ** -0.5)),
                  (module.k, factor * (d_model**-0.5)),
                  (module.v, factor * (d_model**-0.5)),
                  (module.o, factor * ((n_heads * key_value_proj_dim) ** -0.5)),
              )
              for projection, std in projection_stds:
                  init.normal_(projection.weight, mean=0.0, std=std)
              if module.has_relative_attention_bias:
                  bias_std = factor * ((d_model) ** -0.5)
                  init.normal_(module.relative_attention_bias.weight, mean=0.0, std=bias_std)
                  if isinstance(module, LongT5TransientGlobalAttention):
                      init.normal_(module.global_relative_attention_bias.weight, mean=0.0, std=bias_std)

      # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
      def _shift_right(self, input_ids):
          """Shift `input_ids` one position to the right along the last axis.

          Position 0 is filled with `config.decoder_start_token_id`, the last input token is dropped and
          any -100 value in the result is replaced by `config.pad_token_id`.

          Raises:
              ValueError: if `decoder_start_token_id` or `pad_token_id` is not set in the config.
          """
          start_token_id = self.config.decoder_start_token_id
          padding_id = self.config.pad_token_id

          if start_token_id is None:
              raise ValueError(
                  "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the pad_token_id. "
                  "See LongT5 docs for more information."
              )

          shifted = input_ids.new_zeros(input_ids.shape)
          shifted[..., 0] = start_token_id
          shifted[..., 1:] = input_ids[..., :-1].clone()

          if padding_id is None:
              raise ValueError("self.model.config.pad_token_id has to be defined.")
          # replace possible -100 values in labels by `pad_token_id`
          shifted.masked_fill_(shifted == -100, padding_id)

          return shifted
  1036. class LongT5Stack(LongT5PreTrainedModel):
  def __init__(self, config):
      """Build the token embedding, the stack of `config.num_layers` blocks and the output layers.

      Only the first block (index 0) is created with a relative attention bias.
      """
      super().__init__(config)

      self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
      self.is_decoder = config.is_decoder

      # A block holds `local_radius + 1` positions
      self.local_radius = config.local_radius
      self.block_len = self.local_radius + 1

      self.block = nn.ModuleList(
          LongT5Block(config, has_relative_attention_bias=(block_idx == 0), layer_idx=block_idx)
          for block_idx in range(config.num_layers)
      )
      self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
      self.dropout = nn.Dropout(config.dropout_rate)

      self.gradient_checkpointing = False

      # Initialize weights and apply final processing
      self.post_init()
  # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
  def set_input_embeddings(self, new_embeddings):
      """Replace the token embedding module stored in `self.embed_tokens` with `new_embeddings`."""
      self.embed_tokens = new_embeddings
  1057. def forward(
  1058. self,
  1059. input_ids=None,
  1060. attention_mask=None,
  1061. encoder_hidden_states=None,
  1062. encoder_attention_mask=None,
  1063. inputs_embeds=None,
  1064. past_key_values=None,
  1065. use_cache=None,
  1066. output_attentions=None,
  1067. output_hidden_states=None,
  1068. return_dict=None,
  1069. **kwargs,
  1070. ):
  1071. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1072. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1073. output_hidden_states = (
  1074. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1075. )
  1076. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1077. if input_ids is not None and inputs_embeds is not None:
  1078. err_msg_prefix = "decoder_" if self.is_decoder else ""
  1079. raise ValueError(
  1080. f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
  1081. )
  1082. elif input_ids is not None:
  1083. input_shape = input_ids.size()
  1084. input_ids = input_ids.view(-1, input_shape[-1])
  1085. elif inputs_embeds is not None:
  1086. input_shape = inputs_embeds.size()[:-1]
  1087. else:
  1088. err_msg_prefix = "decoder_" if self.is_decoder else ""
  1089. raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
  1090. if self.gradient_checkpointing and self.training:
  1091. if use_cache:
  1092. logger.warning_once(
  1093. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1094. )
  1095. use_cache = False
  1096. if inputs_embeds is None:
  1097. assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
  1098. inputs_embeds = self.embed_tokens(input_ids)
  1099. batch_size, seq_length = input_shape
  1100. if self.is_decoder:
  1101. if use_cache and past_key_values is None:
  1102. if self.config.is_encoder_decoder:
  1103. past_key_values = EncoderDecoderCache(
  1104. DynamicCache(config=self.config), DynamicCache(config=self.config)
  1105. )
  1106. else:
  1107. past_key_values = DynamicCache(config=self.config)
  1108. elif not self.is_decoder:
  1109. # do not pass cache object down the line for encoder stack
  1110. # it messes indexing later in decoder-stack because cache object is modified in-place
  1111. past_key_values = None
  1112. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  1113. if attention_mask is None and not is_torchdynamo_compiling():
  1114. # required mask seq length can be calculated via length of past
  1115. mask_seq_length = past_key_values_length + seq_length
  1116. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  1117. if self.is_decoder:
  1118. causal_mask = create_causal_mask(
  1119. config=self.config,
  1120. inputs_embeds=inputs_embeds,
  1121. attention_mask=attention_mask,
  1122. past_key_values=past_key_values,
  1123. )
  1124. # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
  1125. elif self.config.encoder_attention_type == "local":
  1126. causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
  1127. else: # we need to use both local attention mask and standard extended mask for transient-global attention
  1128. causal_mask = attention_mask
  1129. # If a 2D or 3D attention mask is provided for the cross-attention
  1130. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  1131. if self.is_decoder and encoder_hidden_states is not None:
  1132. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  1133. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  1134. if encoder_attention_mask is None:
  1135. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
  1136. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  1137. else:
  1138. encoder_extended_attention_mask = None
  1139. all_hidden_states = () if output_hidden_states else None
  1140. all_attentions = () if output_attentions else None
  1141. all_cross_attentions = () if (output_attentions and self.is_decoder) else None
  1142. position_bias = None
  1143. encoder_decoder_position_bias = None
  1144. hidden_states = self.dropout(inputs_embeds)
  1145. for i, layer_module in enumerate(self.block):
  1146. if output_hidden_states:
  1147. all_hidden_states = all_hidden_states + (hidden_states,)
  1148. layer_outputs = layer_module(
  1149. hidden_states,
  1150. causal_mask,
  1151. position_bias,
  1152. encoder_hidden_states,
  1153. encoder_extended_attention_mask,
  1154. encoder_decoder_position_bias, # as a positional argument for gradient checkpointing
  1155. past_key_values=past_key_values,
  1156. use_cache=use_cache,
  1157. output_attentions=output_attentions,
  1158. return_dict=return_dict,
  1159. )
  1160. # layer_outputs is a tuple with:
  1161. # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  1162. hidden_states = layer_outputs[0]
  1163. # We share the position biases between the layers - the first layer store them
  1164. # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
  1165. # (cross-attention position bias), (cross-attention weights)
  1166. position_bias = layer_outputs[1]
  1167. if self.is_decoder and encoder_hidden_states is not None:
  1168. encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
  1169. if output_attentions:
  1170. all_attentions = all_attentions + (layer_outputs[2],)
  1171. if self.is_decoder:
  1172. all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
  1173. hidden_states = self.final_layer_norm(hidden_states)
  1174. hidden_states = self.dropout(hidden_states)
  1175. # Add last layer
  1176. if output_hidden_states:
  1177. all_hidden_states = all_hidden_states + (hidden_states,)
  1178. if not return_dict:
  1179. return tuple(
  1180. v
  1181. for v in [
  1182. hidden_states,
  1183. past_key_values,
  1184. all_hidden_states,
  1185. all_attentions,
  1186. all_cross_attentions,
  1187. ]
  1188. if v is not None
  1189. )
  1190. return BaseModelOutputWithPastAndCrossAttentions(
  1191. last_hidden_state=hidden_states,
  1192. past_key_values=past_key_values,
  1193. hidden_states=all_hidden_states,
  1194. attentions=all_attentions,
  1195. cross_attentions=all_cross_attentions,
  1196. )
  1197. @auto_docstring
  1198. class LongT5Model(LongT5PreTrainedModel):
  1199. _keys_to_ignore_on_load_unexpected = [
  1200. r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  1201. ]
  1202. _tied_weights_keys = {
  1203. "encoder.embed_tokens.weight": "shared.weight",
  1204. "decoder.embed_tokens.weight": "shared.weight",
  1205. }
  1206. def __init__(self, config: LongT5Config):
  1207. super().__init__(config)
  1208. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1209. encoder_config = copy.deepcopy(config)
  1210. encoder_config.is_decoder = False
  1211. encoder_config.use_cache = False
  1212. self.encoder = LongT5Stack(encoder_config)
  1213. decoder_config = copy.deepcopy(config)
  1214. decoder_config.is_decoder = True
  1215. decoder_config.num_layers = config.num_decoder_layers
  1216. self.decoder = LongT5Stack(decoder_config)
  1217. # Initialize weights and apply final processing
  1218. self.post_init()
  1219. def get_input_embeddings(self):
  1220. return self.shared
  1221. def set_input_embeddings(self, new_embeddings):
  1222. self.shared = new_embeddings
  1223. self.encoder.set_input_embeddings(new_embeddings)
  1224. self.decoder.set_input_embeddings(new_embeddings)
  1225. @auto_docstring
  1226. def forward(
  1227. self,
  1228. input_ids: torch.LongTensor | None = None,
  1229. attention_mask: torch.FloatTensor | None = None,
  1230. decoder_input_ids: torch.LongTensor | None = None,
  1231. decoder_attention_mask: torch.BoolTensor | None = None,
  1232. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  1233. past_key_values: Cache | None = None,
  1234. inputs_embeds: torch.Tensor | None = None,
  1235. decoder_inputs_embeds: torch.Tensor | None = None,
  1236. use_cache: bool | None = None,
  1237. output_attentions: bool | None = None,
  1238. output_hidden_states: bool | None = None,
  1239. return_dict: bool | None = None,
  1240. **kwargs,
  1241. ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput:
  1242. r"""
  1243. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1244. Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
  1245. you should be able to pad the inputs on both the right and the left.
  1246. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1247. [`PreTrainedTokenizer.__call__`] for detail.
  1248. [What are input IDs?](../glossary#input-ids)
  1249. To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
  1250. Training](./longt5#training).
  1251. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1252. Indices of decoder input sequence tokens in the vocabulary.
  1253. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1254. [`PreTrainedTokenizer.__call__`] for details.
  1255. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1256. LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1257. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1258. `past_key_values`).
  1259. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
  1260. Training](./longt5#training).
  1261. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1262. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1263. be used by default.
  1264. Example:
  1265. ```python
  1266. >>> from transformers import AutoTokenizer, LongT5Model
  1267. >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
  1268. >>> model = LongT5Model.from_pretrained("google/long-t5-local-base")
  1269. >>> # Let's try a very long encoder input.
  1270. >>> input_ids = tokenizer(
  1271. ... 100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1272. ... ).input_ids # Batch size 1
  1273. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1274. >>> # forward pass
  1275. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1276. >>> last_hidden_states = outputs.last_hidden_state
  1277. ```"""
  1278. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1279. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1280. # Encode if needed (training, first prediction pass)
  1281. if encoder_outputs is None:
  1282. encoder_outputs = self.encoder(
  1283. input_ids=input_ids,
  1284. attention_mask=attention_mask,
  1285. inputs_embeds=inputs_embeds,
  1286. output_attentions=output_attentions,
  1287. output_hidden_states=output_hidden_states,
  1288. return_dict=return_dict,
  1289. )
  1290. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1291. encoder_outputs = BaseModelOutput(
  1292. last_hidden_state=encoder_outputs[0],
  1293. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1294. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1295. )
  1296. hidden_states = encoder_outputs[0]
  1297. # Decode
  1298. decoder_outputs = self.decoder(
  1299. input_ids=decoder_input_ids,
  1300. attention_mask=decoder_attention_mask,
  1301. inputs_embeds=decoder_inputs_embeds,
  1302. past_key_values=past_key_values,
  1303. encoder_hidden_states=hidden_states,
  1304. encoder_attention_mask=attention_mask,
  1305. use_cache=use_cache,
  1306. output_attentions=output_attentions,
  1307. output_hidden_states=output_hidden_states,
  1308. return_dict=return_dict,
  1309. )
  1310. if not return_dict:
  1311. return decoder_outputs + encoder_outputs
  1312. return Seq2SeqModelOutput(
  1313. last_hidden_state=decoder_outputs.last_hidden_state,
  1314. past_key_values=decoder_outputs.past_key_values,
  1315. decoder_hidden_states=decoder_outputs.hidden_states,
  1316. decoder_attentions=decoder_outputs.attentions,
  1317. cross_attentions=decoder_outputs.cross_attentions,
  1318. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1319. encoder_hidden_states=encoder_outputs.hidden_states,
  1320. encoder_attentions=encoder_outputs.attentions,
  1321. )
  1322. @auto_docstring(
  1323. custom_intro="""
  1324. LONGT5 Model with a `language modeling` head on top.
  1325. """
  1326. )
  1327. class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
  1328. _keys_to_ignore_on_load_unexpected = [
  1329. r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  1330. ]
  1331. _tied_weights_keys = {
  1332. "encoder.embed_tokens.weight": "shared.weight",
  1333. "decoder.embed_tokens.weight": "shared.weight",
  1334. "lm_head.weight": "shared.weight",
  1335. }
  1336. def __init__(self, config: LongT5Config):
  1337. super().__init__(config)
  1338. self.model_dim = config.d_model
  1339. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1340. encoder_config = copy.deepcopy(config)
  1341. encoder_config.is_decoder = False
  1342. encoder_config.use_cache = False
  1343. self.encoder = LongT5Stack(encoder_config)
  1344. decoder_config = copy.deepcopy(config)
  1345. decoder_config.is_decoder = True
  1346. decoder_config.num_layers = config.num_decoder_layers
  1347. self.decoder = LongT5Stack(decoder_config)
  1348. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  1349. # Initialize weights and apply final processing
  1350. self.post_init()
  1351. def get_input_embeddings(self):
  1352. return self.shared
  1353. def set_input_embeddings(self, new_embeddings):
  1354. self.shared = new_embeddings
  1355. self.encoder.set_input_embeddings(new_embeddings)
  1356. self.decoder.set_input_embeddings(new_embeddings)
  1357. @auto_docstring
  1358. def forward(
  1359. self,
  1360. input_ids: torch.LongTensor | None = None,
  1361. attention_mask: torch.FloatTensor | None = None,
  1362. decoder_input_ids: torch.LongTensor | None = None,
  1363. decoder_attention_mask: torch.BoolTensor | None = None,
  1364. encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
  1365. past_key_values: Cache | None = None,
  1366. inputs_embeds: torch.FloatTensor | None = None,
  1367. decoder_inputs_embeds: torch.FloatTensor | None = None,
  1368. labels: torch.LongTensor | None = None,
  1369. use_cache: bool | None = None,
  1370. output_attentions: bool | None = None,
  1371. output_hidden_states: bool | None = None,
  1372. return_dict: bool | None = None,
  1373. **kwargs,
  1374. ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
  1375. r"""
  1376. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1377. Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
  1378. you should be able to pad the inputs on both the right and the left.
  1379. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1380. [`PreTrainedTokenizer.__call__`] for detail.
  1381. [What are input IDs?](../glossary#input-ids)
  1382. To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
  1383. Training](./longt5#training).
  1384. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1385. Indices of decoder input sequence tokens in the vocabulary.
  1386. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1387. [`PreTrainedTokenizer.__call__`] for details.
  1388. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1389. LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1390. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1391. `past_key_values`).
  1392. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
  1393. Training](./longt5#training).
  1394. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1395. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1396. be used by default.
  1397. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1398. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  1399. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  1400. labels in `[0, ..., config.vocab_size]`
  1401. Examples:
  1402. ```python
  1403. >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
  1404. >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
  1405. >>> model = LongT5ForConditionalGeneration.from_pretrained(
  1406. ... "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
  1407. ... )
  1408. >>> # Let's try a very long input.
  1409. >>> inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
  1410. >>> input_ids = inputs.input_ids
  1411. >>> outputs = model.generate(input_ids)
  1412. >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  1413. abstractthe aim of this article is to provide an overview of the literature on the role of dog
  1414. ```"""
  1415. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1416. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1417. # Encode if needed (training, first prediction pass)
  1418. if encoder_outputs is None:
  1419. # Convert encoder inputs in embeddings if needed
  1420. encoder_outputs = self.encoder(
  1421. input_ids=input_ids,
  1422. attention_mask=attention_mask,
  1423. inputs_embeds=inputs_embeds,
  1424. output_attentions=output_attentions,
  1425. output_hidden_states=output_hidden_states,
  1426. return_dict=return_dict,
  1427. )
  1428. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1429. encoder_outputs = BaseModelOutput(
  1430. last_hidden_state=encoder_outputs[0],
  1431. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1432. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1433. )
  1434. hidden_states = encoder_outputs[0]
  1435. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1436. # get decoder inputs from shifting lm labels to the right
  1437. decoder_input_ids = self._shift_right(labels)
  1438. # Decode
  1439. decoder_outputs = self.decoder(
  1440. input_ids=decoder_input_ids,
  1441. attention_mask=decoder_attention_mask,
  1442. inputs_embeds=decoder_inputs_embeds,
  1443. past_key_values=past_key_values,
  1444. encoder_hidden_states=hidden_states,
  1445. encoder_attention_mask=attention_mask,
  1446. use_cache=use_cache,
  1447. output_attentions=output_attentions,
  1448. output_hidden_states=output_hidden_states,
  1449. return_dict=return_dict,
  1450. )
  1451. sequence_output = decoder_outputs[0]
  1452. if self.config.tie_word_embeddings:
  1453. sequence_output = sequence_output * (self.model_dim**-0.5)
  1454. lm_logits = self.lm_head(sequence_output)
  1455. loss = None
  1456. if labels is not None:
  1457. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1458. labels = labels.to(lm_logits.device)
  1459. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  1460. if not return_dict:
  1461. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  1462. return ((loss,) + output) if loss is not None else output
  1463. return Seq2SeqLMOutput(
  1464. loss=loss,
  1465. logits=lm_logits,
  1466. past_key_values=decoder_outputs.past_key_values,
  1467. decoder_hidden_states=decoder_outputs.hidden_states,
  1468. decoder_attentions=decoder_outputs.attentions,
  1469. cross_attentions=decoder_outputs.cross_attentions,
  1470. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1471. encoder_hidden_states=encoder_outputs.hidden_states,
  1472. encoder_attentions=encoder_outputs.attentions,
  1473. )
  1474. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1475. return self._shift_right(labels)
  1476. @auto_docstring
  1477. class LongT5EncoderModel(LongT5PreTrainedModel):
  1478. _tied_weights_keys = {
  1479. "encoder.embed_tokens.weight": "shared.weight",
  1480. }
  1481. _keys_to_ignore_on_load_unexpected = [r"decoder"]
  1482. def __init__(self, config: LongT5Config):
  1483. super().__init__(config)
  1484. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1485. encoder_config = copy.deepcopy(config)
  1486. encoder_config.use_cache = False
  1487. self.encoder = LongT5Stack(encoder_config)
  1488. # Initialize weights and apply final processing
  1489. self.post_init()
  1490. def get_input_embeddings(self):
  1491. return self.shared
  1492. def set_input_embeddings(self, new_embeddings):
  1493. self.shared = new_embeddings
  1494. self.encoder.set_input_embeddings(new_embeddings)
  1495. @auto_docstring
  1496. def forward(
  1497. self,
  1498. input_ids: torch.LongTensor | None = None,
  1499. attention_mask: torch.FloatTensor | None = None,
  1500. inputs_embeds: torch.FloatTensor | None = None,
  1501. output_attentions: bool | None = None,
  1502. output_hidden_states: bool | None = None,
  1503. return_dict: bool | None = None,
  1504. **kwargs,
  1505. ) -> tuple[torch.FloatTensor] | BaseModelOutput:
  1506. r"""
  1507. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1508. Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
  1509. you should be able to pad the inputs on both the right and the left.
  1510. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1511. [`PreTrainedTokenizer.__call__`] for detail.
  1512. To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
  1513. Training](./longt5#training).
  1514. Example:
  1515. ```python
  1516. >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
  1517. >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
  1518. >>> model = LongT5EncoderModel.from_pretrained("google/long-t5-local-base")
  1519. >>> input_ids = tokenizer(
  1520. ... 100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt"
  1521. ... ).input_ids # Batch size 1
  1522. >>> outputs = model(input_ids=input_ids)
  1523. >>> last_hidden_states = outputs.last_hidden_state
  1524. ```"""
  1525. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1526. encoder_outputs = self.encoder(
  1527. input_ids=input_ids,
  1528. attention_mask=attention_mask,
  1529. inputs_embeds=inputs_embeds,
  1530. output_attentions=output_attentions,
  1531. output_hidden_states=output_hidden_states,
  1532. return_dict=return_dict,
  1533. )
  1534. return encoder_outputs
# Public names of this module, i.e. what a star-import of it exposes.
__all__ = ["LongT5EncoderModel", "LongT5ForConditionalGeneration", "LongT5Model", "LongT5PreTrainedModel"]