| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821 |
- # Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch LongT5 model."""
- import copy
- import math
- from typing import Any
- import torch
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...generation import GenerationMixin
- from ...masking_utils import create_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...utils import (
- DUMMY_INPUTS,
- DUMMY_MASK,
- auto_docstring,
- is_torchdynamo_compiling,
- logging,
- )
- from .configuration_longt5 import LongT5Config
# Module-level logger for this file's warning/info messages.
logger = logging.get_logger(__name__)
# TODO: Update before the merge
# NOTE(review): nothing follows this TODO in the visible code; it looks stale — confirm and remove if so.
- def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor:
- """Pad a tensor so that a sequence length will be a multiple of `block_len`"""
- pad_len = -x.shape[dim] % block_len
- # Handle cases when an empty input sequence is given
- if not all(x.shape):
- new_shape = list(x.shape)
- new_shape[dim] += pad_len
- return torch.zeros(new_shape, dtype=x.dtype)
- pad = [(0, 0)] * x.ndim
- pad[dim] = (0, pad_len)
- pad = sum(pad[::-1], ())
- x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
- return x
- def _split_into_blocks(x: torch.Tensor, block_len: int, dim: int) -> torch.Tensor:
- """Split an input tensor into blocks of a given `block_len` along the given `dim`. If the dimension length
- is not a multiple of `block_len`, it will be padded first with selected `pad_value`.
- """
- # pad tensor to multiple of block_len
- if x.shape[dim] % block_len != 0:
- x = _pad_to_multiple(x, block_len, dim, pad_value=0)
- num_blocks = x.shape[dim] // block_len
- output_shape = x.shape[:dim] + (num_blocks, block_len) + x.shape[(dim + 1) :]
- # If 0 is in output_shape, we cannot apply reshape because of incompatibility with ONNX conversion
- if 0 in output_shape:
- return torch.empty(output_shape, dtype=x.dtype, device=x.device)
- return x.reshape(output_shape)
- def _concatenate_3_blocks(x: torch.Tensor, block_dim: int, sequence_dim: int, pad_value: int = 0) -> torch.Tensor:
- """Concatenate three consecutive blocks for each input block for local attentiont.
- For more information, see: https://huggingface.co/papers/2112.07916.
- """
- num_blocks = x.shape[block_dim]
- pad = [(0, 0)] * x.ndim
- pad[block_dim] = (1, 1)
- pad = sum(pad[::-1], ())
- # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
- x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
- blocks_list: list[torch.Tensor] = []
- for i in range(3):
- # We use indexing approach here:
- # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
- indices = [slice(0, None)] * x.ndim
- indices[block_dim] = slice(i, i + num_blocks)
- indices = tuple(indices)
- blocks_list.append(x[indices])
- # [batch_size, num_blocks, 3 * block_len, ...]
- return torch.cat(blocks_list, dim=sequence_dim)
- def _make_3block_relative_position_ids(block_len: int) -> torch.Tensor:
- """Makes 3-blocked relative position ids for local attention."""
- position_ids = torch.arange(3 * block_len, dtype=torch.int32)
- center_position_ids = position_ids[block_len:-block_len]
- # [block_len, 3 * block_len]
- relative_position_ids = position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
- return relative_position_ids
def _mask_local_attention_mask(local_attention_mask: torch.Tensor, block_len: int) -> torch.Tensor:
    """Restrict a blocked attention mask so each query only sees keys less than `block_len` positions away.

    With `block_len = local_radius + 1`, this keeps exactly the keys within `local_radius` of each query.
    """
    relative_position_ids = _make_3block_relative_position_ids(block_len)
    # [1, 1, block_len, 3 * block_len] -- broadcast over the leading dimensions of the mask
    within_radius = (relative_position_ids.abs() < block_len)[None, None]
    return torch.logical_and(local_attention_mask, within_radius.to(local_attention_mask.device))
def _get_local_attention_mask(attention_mask: torch.Tensor, block_len: int, device: torch.device) -> torch.Tensor:
    """Build the boolean mask used by blocked local attention from a [batch_size, seq_len] padding mask.

    A (query, key) pair is allowed when both tokens are unmasked and lie within the locality window.
    The result has shape [batch_size, 1, num_blocks, block_len, 3 * block_len] and is moved to `device`.
    """
    # [batch_size, num_blocks, block_len]
    blocked = _split_into_blocks(attention_mask, block_len, dim=1)
    # [batch_size, num_blocks, 3 * block_len]
    three_blocked = _concatenate_3_blocks(blocked, block_dim=1, sequence_dim=2)
    # queries [..., block_len, 1] AND keys [..., 1, 3 * block_len] -> [batch_size, num_blocks, block_len, 3 * block_len]
    pairwise = torch.logical_and(blocked.unsqueeze(-1), three_blocked.unsqueeze(-2))
    pairwise = _mask_local_attention_mask(pairwise, block_len)
    # add a singleton head axis -> [batch_size, 1, num_blocks, block_len, 3 * block_len]
    return pairwise.unsqueeze(1).to(device)
def _make_global_fixed_block_ids(
    attention_mask: torch.Tensor, global_block_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Obtain the "fixed block" global id corresponding to each input token.

    This implementation is a simplified version of the original Flaxformer implementation adopted from:
    https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.

    In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make
    up a whole fixed block, are assigned to the preceding block.

    Padding tokens from the original sequence are represented by -1.

    Args:
        attention_mask: mask whose first two dimensions are `(batch_size, seq_len)`; non-zero entries mark real
            tokens. NOTE(review): the padding arithmetic below assumes 0/1 values — confirm callers never pass
            other magnitudes.
        global_block_size: number of consecutive tokens that form one global block.

    Returns:
        A tuple `(global_block_ids, global_segment_ids)`, both cast to `torch.int`:
        - `global_block_ids`: same shape as `attention_mask`; the global block index of each token (-1 for padding).
        - `global_segment_ids`: shape `(batch_size, seq_len // global_block_size)`; 1 for every global block whose
          index is <= the largest block id in that row, else 0.
    """
    batch_size, seq_len = attention_mask.shape[:2]

    def handle_orphan_tokens(block_ids: torch.Tensor) -> torch.Tensor:
        # Positions that close a complete fixed block (last token of each block of size global_block_size).
        block_ends = (torch.arange(seq_len) % global_block_size) == global_block_size - 1
        block_ends = block_ends.to(block_ids.device)
        # Only count block ends that are real (non-padding) tokens.
        true_block_ends = torch.logical_and(block_ends, block_ids >= 0)
        # Index of the last complete block per row.
        full_blocks = true_block_ends.sum(-1).unsqueeze(-1).type(block_ids.dtype) - 1
        # Clip ids so trailing tokens of an incomplete block join the last complete block.
        block_ids = torch.where(block_ids < full_blocks, block_ids, full_blocks)
        return block_ids

    # Exclusive cumulative sum: position i gets approximately i / global_block_size.
    fixed_block_mask = torch.ones_like(attention_mask, device=attention_mask.device) / global_block_size
    fixed_block_mask = torch.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
    # 1.0 for real tokens, -1000.0 for padding so the floor below lands far under -1 before clamping.
    mask = torch.where(attention_mask != 0.0, 1.0, -1000.0).type(attention_mask.dtype)
    # Real tokens: floor(1 + i / global_block_size - 1) = floor(i / global_block_size).
    global_block_ids = torch.floor(mask + fixed_block_mask - 1.0).type(attention_mask.dtype)
    _global_block_ids_lower_bound = torch.tensor(-1, dtype=global_block_ids.dtype, device=global_block_ids.device)
    # Clamp from below at -1.
    global_block_ids = torch.where(
        global_block_ids > _global_block_ids_lower_bound, global_block_ids, _global_block_ids_lower_bound
    )
    # set padding tokens to -1
    global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
    # [batch_size, seq_len]
    global_block_ids = handle_orphan_tokens(global_block_ids)
    num_globals = seq_len // global_block_size
    # [batch_size, seq_len // global_block_size]
    if num_globals > 0:
        # Per-row maximum block id, repeated once per global block.
        _sequence_block_ids_max = torch.max(global_block_ids, dim=-1).values.repeat(num_globals, 1).transpose(0, 1)
    else:
        _sequence_block_ids_max = torch.zeros(
            batch_size, 0, dtype=global_block_ids.dtype, device=global_block_ids.device
        )
    # 0, 1, ..., num_globals - 1 for every row.
    global_segment_ids = torch.cumsum(torch.ones(batch_size, num_globals), dim=-1) - 1
    global_segment_ids = global_segment_ids.to(attention_mask.device)
    # 1 for global blocks that exist in the (unpadded) row, 0 otherwise.
    global_segment_ids = torch.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
    return global_block_ids.type(torch.int), global_segment_ids.type(torch.int)
def _make_side_relative_position_ids(attention_mask: torch.Tensor, global_block_size: int) -> torch.Tensor:
    """Create the relative position tensor for local -> global attention.

    Entry [..., i, g] equals `g - block_id(i)`: the offset of global block `g` from the fixed block that token `i`
    belongs to (padding tokens carry block id -1). The result is int64.
    """
    token_block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
    num_global_blocks = global_segment_ids.shape[-1]
    global_positions = torch.arange(num_global_blocks, device=token_block_ids.device)
    relative = global_positions - token_block_ids.unsqueeze(-1)
    return relative.type(torch.int64)
- def _create_global_aggregates(
- hidden_states: torch.Tensor, block_ids: torch.Tensor, global_seq_len: int
- ) -> torch.Tensor:
- """Compute individual block aggregates by summing over individual blocks."""
- # (batch..., seq_len, global_seq_len))
- block_ids = block_ids.where(
- block_ids >= 0, torch.tensor(global_seq_len, dtype=block_ids.dtype, device=block_ids.device)
- )
- one_hot_block_ids = nn.functional.one_hot(block_ids.type(torch.int64), global_seq_len + 1)[:, :, :-1]
- return torch.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids.type(hidden_states.dtype))
- # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LongT5
class LongT5LayerNorm(nn.Module):
    """RMS-style layer norm: scales by the root mean square of the last dimension, with no mean subtraction and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the LongT5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Root Mean Square Layer Normalization (https://huggingface.co/papers/1910.07467): the second moment is
        # computed without centering. It is accumulated in float32 so half-precision inputs stay stable.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)
        # Return to the weight's half-precision dtype when the module itself runs in fp16/bf16.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normalized = normalized.to(self.weight.dtype)
        return self.weight * normalized
try:
    from apex.normalization import FusedRMSNorm
except ImportError:
    # apex is not installed: keep the pure-PyTorch LongT5LayerNorm.
    pass
except Exception:
    # apex is present but failed to import; fall back rather than failing at module import time.
    logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
else:
    # Import succeeded: rebind the module-level name so later lookups of LongT5LayerNorm get apex's FusedRMSNorm.
    LongT5LayerNorm = FusedRMSNorm
    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of LongT5LayerNorm")
- # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
class LongT5DenseActDense(nn.Module):
    """Feed-forward block: `wi` -> activation -> dropout -> `wo`, with bias-free projections."""

    def __init__(self, config: LongT5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.dropout(self.act(self.wi(hidden_states)))
        out_weight = self.wo.weight
        # If `wo` is stored in another dtype than the activations, cast the activations to match it.
        # The cast is skipped for int8 weights.
        # NOTE(review): presumably int8 weights come from quantized layers that manage dtypes themselves — confirm.
        needs_cast = (
            isinstance(out_weight, torch.Tensor)
            and hidden_states.dtype != out_weight.dtype
            and out_weight.dtype != torch.int8
        )
        if needs_cast:
            hidden_states = hidden_states.to(out_weight.dtype)
        return self.wo(hidden_states)
class LongT5DenseGatedActDense(nn.Module):
    """Gated feed-forward block: act(wi_0(x)) * wi_1(x) -> dropout -> wo, with bias-free projections."""

    def __init__(self, config: LongT5Config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        # Same guard as LongT5DenseActDense: if `wo` is stored in a different (non-int8) dtype than the
        # activations, for example kept in float32 while the rest runs in half precision, cast the activations
        # first. Otherwise the linear layer fails with a dtype mismatch. When the dtypes already agree this is a no-op.
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)
        hidden_states = self.wo(hidden_states)
        return hidden_states
- # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->LongT5
class LongT5LayerFF(nn.Module):
    """Pre-norm residual feed-forward sub-layer: `x + dropout(FF(layer_norm(x)))`."""

    def __init__(self, config: LongT5Config):
        super().__init__()
        feed_forward_cls = LongT5DenseGatedActDense if config.is_gated_act else LongT5DenseActDense
        self.DenseReluDense = feed_forward_cls(config)
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        residual = hidden_states
        transformed = self.DenseReluDense(self.layer_norm(hidden_states))
        return residual + self.dropout(transformed)
- # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5
class LongT5Attention(nn.Module):
    """Multi-head attention over the full sequence with T5-style bucketed relative position bias.

    Works as self-attention, or as cross-attention when `key_value_states` is passed to `forward`. Only layers built
    with `has_relative_attention_bias=True` own a `relative_attention_bias` embedding. Other layers use the
    `position_bias` given to `forward`, or zeros when none is given.
    """

    def __init__(
        self,
        config: LongT5Config,
        has_relative_attention_bias=False,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        # Index of this layer in the key/value cache; decoder layers need it when caching is used.
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.gradient_checkpointing = False

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, positive relative positions are clamped to 0. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an integer Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing torch.long values in the range
            [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # half of the buckets encode positive offsets, the other half non-positive ones
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
        """Compute binned relative position bias.

        Query positions are offset by `past_seen_tokens` so that cached decoding sees the correct distances.
        Returns a tensor of shape (1, num_heads, query_length, key_length).
        """
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_values=None,
        output_attentions=False,
        **kwargs,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns a tuple `(attn_output, position_bias)`, followed by the attention weights when
        `output_attentions=True`. When `position_bias` is computed here, `mask` is already added to the returned
        bias; presumably this lets callers forward it to later layers as-is.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.key_value_proj_dim)

        past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
        # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
        past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states).view(hidden_shape).transpose(1, 2)

        # Check if an encoder-decoder cache is used; otherwise `past_key_values` is a plain cache (e.g. `DynamicCache`)
        is_updated = False
        if isinstance(past_key_values, EncoderDecoderCache):
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_values = past_key_values.cross_attention_cache
            else:
                curr_past_key_values = past_key_values.self_attention_cache
        else:
            curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            kv_shape = (*current_states.shape[:-1], -1, self.key_value_proj_dim)
            key_states = self.k(current_states).view(kv_shape).transpose(1, 2)
            value_states = self.v(current_states).view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, query_states.shape[1], input_shape[1], key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    input_shape[1], key_length, device=scores.device, past_seen_tokens=past_seen_tokens
                )
            # Fold the (possibly causal) mask into the bias; it is returned with the bias below.
            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        scores += position_bias
        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
- class LongT5LocalAttention(nn.Module):
- def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
- super().__init__()
- self.is_decoder = config.is_decoder
- self.has_relative_attention_bias = has_relative_attention_bias
- self.relative_attention_num_buckets = config.relative_attention_num_buckets
- self.relative_attention_max_distance = config.relative_attention_max_distance
- self.d_model = config.d_model
- self.key_value_proj_dim = config.d_kv
- self.n_heads = config.num_heads
- self.local_radius = config.local_radius
- self.block_len = self.local_radius + 1
- self.dropout = config.dropout_rate
- self.inner_dim = self.n_heads * self.key_value_proj_dim
- self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
- self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
- self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
- self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
- if self.has_relative_attention_bias:
- self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
- self.gradient_checkpointing = False
- @staticmethod
- # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
- def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
- """
- Adapted from Mesh Tensorflow:
- https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
- Translate relative position to a bucket number for relative attention. The relative position is defined as
- memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
- position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
- small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
- positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
- This should allow for more graceful generalization to longer sequences than the model has been trained on
- Args:
- relative_position: an int32 Tensor
- bidirectional: a boolean - whether the attention is bidirectional
- num_buckets: an integer
- max_distance: an integer
- Returns:
- a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
- """
- relative_buckets = 0
- if bidirectional:
- num_buckets //= 2
- relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
- relative_position = torch.abs(relative_position)
- else:
- relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
- # now relative_position is in the range [0, inf)
- # half of the buckets are for exact increments in positions
- max_exact = num_buckets // 2
- is_small = relative_position < max_exact
- # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
- relative_position_if_large = max_exact + (
- torch.log(relative_position.float() / max_exact)
- / math.log(max_distance / max_exact)
- * (num_buckets - max_exact)
- ).to(torch.long)
- relative_position_if_large = torch.min(
- relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
- )
- relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
- return relative_buckets
def compute_bias(self, block_length: int):
    """Build the binned relative position bias shared by every local block.

    Every query of the middle block of a ``3 * block_length`` window is paired with every key of that
    window; the signed offsets (key - query) are bucketed and looked up in ``self.relative_attention_bias``.

    Args:
        block_length: number of tokens per local block.

    Returns:
        Tensor of shape ``(1, 1, num_heads, block_length, 3 * block_length)``.
    """
    bias_weight = self.relative_attention_bias.weight
    # Nothing can be materialised on the meta device, so fall back to the default device there.
    device = None if bias_weight.device.type == "meta" else bias_weight.device

    window = torch.arange(3 * block_length, dtype=torch.long, device=device)
    centre = window[block_length:-block_length]
    # signed offsets key - query: (block_length, 3 * block_length)
    offsets = window.unsqueeze(0) - centre.unsqueeze(1)

    buckets = self._relative_position_bucket(
        offsets,
        bidirectional=not self.is_decoder,
        num_buckets=self.relative_attention_num_buckets,
        max_distance=self.relative_attention_max_distance,
    )
    # (block_length, 3 * block_length, num_heads) -> (1, 1, num_heads, block_length, 3 * block_length)
    return self.relative_attention_bias(buckets).permute(2, 0, 1)[None, None, ...]
def forward(
    self,
    hidden_states,
    mask=None,
    position_bias=None,
    output_attentions=False,
):
    """Blockwise local self-attention.

    The sequence is split into blocks of ``self.block_len`` tokens; each query block attends to the
    concatenation of its own and its neighbouring blocks (``3 * block_len`` keys).

    Args:
        hidden_states: input whose first two dims are ``(batch_size, seq_length)``.
        mask: optional local attention mask; entries ``> 0`` are kept and all others receive ``-1e10``.
            Its layout must broadcast against ``position_bias`` after ``transpose(1, 2)``
            (presumably produced by ``_get_local_attention_mask`` — TODO confirm).
        position_bias: a previously returned bias to reuse; when ``None`` it is computed here (and the mask,
            if any, is folded into it).
        output_attentions: whether to also return the attention probabilities.

    Returns:
        ``(attn_output, position_bias)``, plus ``attn_weights`` when ``output_attentions`` is true.
    """
    batch_size, seq_length = hidden_states.shape[:2]

    def shape(states):
        """projection"""
        return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)

    def unshape(states):
        """reshape"""
        return states.contiguous().view(batch_size, -1, self.inner_dim)

    # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
    query_states = shape(self.q(hidden_states))
    key_states = shape(self.k(hidden_states))
    value_states = shape(self.v(hidden_states))

    # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
    query_states = _split_into_blocks(query_states, self.block_len, dim=1)
    key_states = _split_into_blocks(key_states, self.block_len, dim=1)
    value_states = _split_into_blocks(value_states, self.block_len, dim=1)

    # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
    key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
    value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)

    # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len)
    scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)

    if position_bias is None:
        # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
        if not self.has_relative_attention_bias:
            position_bias = torch.zeros(
                (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype
            )
            # Read the flag defensively: if it was never assigned on this module, plain attribute access on an
            # nn.Module raises AttributeError instead of meaning "disabled".
            if getattr(self, "gradient_checkpointing", False) and self.training:
                position_bias.requires_grad = True
        else:
            position_bias = self.compute_bias(self.block_len)

        if mask is not None:
            # Replace masked positions with -1e10 (according to the original implementation)
            mask = torch.where(mask > 0, 0.0, -1e10)
            # We need to adjust position bias shape to be sum with mask
            position_bias = position_bias + mask.transpose(1, 2)

    scores += position_bias
    # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
    # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_weights = attn_weights.type(value_states.dtype)
    attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
    # Drop any positions beyond the original sequence length introduced by block splitting.
    attn_output = attn_output[:, :seq_length, :]
    attn_output = self.o(attn_output)

    outputs = (
        attn_output,
        position_bias,
    )
    if output_attentions:
        outputs = outputs + (attn_weights,)
    return outputs
class LongT5TransientGlobalAttention(nn.Module):
    """Transient-global self-attention used by the LongT5 encoder.

    Each token attends to a local window of ``3 * block_len`` keys (blockwise, as in local attention) plus a
    set of "side"/global keys and values computed on the fly from per-block aggregates of the input
    (blocks of ``global_block_size`` tokens, see ``_create_global_aggregates``).
    """

    def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.local_radius = config.local_radius
        self.block_len = self.local_radius + 1
        self.global_block_size = config.global_block_size
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
            # Relative attention bias for global (side) attention
            self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        # Layer norm applied to the global aggregates
        self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

        # `forward` reads this flag in its bias-less branch; it must exist or attribute lookup fails.
        self.gradient_checkpointing = False

    @staticmethod
    # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias for the local window.

        Returns a tensor of shape ``(1, 1, num_heads, block_length, 3 * block_length)``.
        """
        target_device = (
            self.relative_attention_bias.weight.device
            if self.relative_attention_bias.weight.device.type != "meta"
            else None
        )
        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
        context_position = memory_position[block_length:-block_length]

        # (block_length, 3 * block_length)
        relative_position = memory_position[None, :] - context_position[:, None]
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # (block_length, 3 * block_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # (block_length, 3 * block_length, num_heads)
        values = self.relative_attention_bias(relative_position_bucket)
        # (1, 1, num_heads, block_length, 3 * block_length)
        values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
        return values

    def compute_side_bias(self, mask: torch.Tensor, global_segment_ids: torch.Tensor) -> torch.Tensor:
        """Bias for attending from every token to the global (side) tokens.

        Combines a masking term (``0`` where ``mask`` equals the global segment id, ``-1e10`` elsewhere) with a
        learned relative-position term when this layer owns relative attention biases.

        Returns:
            ``(batch_size, num_heads, seq_len, global_seq_len)`` with learned biases, or
            ``(batch_size, 1, seq_len, global_seq_len)`` (broadcastable over heads) without them.
        """
        # (batch_size, 1, seq_len, global_seq_len)
        side_attention_mask = torch.eq(mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
        attention_side_bias = torch.where(side_attention_mask > 0, 0.0, -1e10)
        if not self.has_relative_attention_bias:
            # `global_relative_attention_bias` only exists when has_relative_attention_bias is set,
            # so a bias-less layer can only contribute the masking term.
            return attention_side_bias
        # (batch_size, seq_len, global_seq_len)
        side_relative_position = _make_side_relative_position_ids(mask, self.global_block_size)
        side_relative_position_bucket = self._relative_position_bucket(
            side_relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # (batch_size, seq_len, global_seq_len, num_heads)
        side_bias = self.global_relative_attention_bias(side_relative_position_bucket)

        # (batch_size, num_heads, seq_len, global_seq_len)
        side_bias = side_bias.permute([0, 3, 1, 2])
        # (batch_size, num_heads, seq_len, global_seq_len)
        attention_side_bias = attention_side_bias + side_bias
        return attention_side_bias

    def forward(
        self,
        hidden_states,
        mask=None,
        position_bias=None,
        output_attentions=False,
    ):
        """Transient-global self-attention.

        Args:
            hidden_states: input whose first two dims are ``(batch_size, seq_length)``.
            mask: optional ``(batch_size, seq_length)`` attention mask (assumed 1 = keep, 0 = padding — TODO
                confirm); an all-ones mask on ``hidden_states.device`` is used when ``None``.
            position_bias: a previously returned bias (local + side, already masked) to reuse; computed here
                when ``None``.
            output_attentions: whether to also return the attention probabilities.

        Returns:
            ``(attn_output, position_bias)``, plus ``attn_weights`` when ``output_attentions`` is true.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)

        def unshape(states):
            """reshape"""
            return states.contiguous().view(batch_size, -1, self.inner_dim)

        # Prepare components for transient-global attention
        # Obtain block_ids and global_segment_ids
        # global_seq_len := seq_len // self.global_block_size
        # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
        block_ids, global_segment_ids = _make_global_fixed_block_ids(
            # Build the default mask on the input's device so downstream ids match hidden_states.
            mask if mask is not None else torch.ones(hidden_states.shape[:-1], device=hidden_states.device),
            self.global_block_size,
        )
        # Create global inputs
        _global_seq_len = global_segment_ids.shape[-1]
        global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
        global_inputs = self.global_input_layer_norm(global_inputs)

        # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
        query_states = shape(self.q(hidden_states))
        key_states = shape(self.k(hidden_states))
        value_states = shape(self.v(hidden_states))
        # Get global/side key/value states  shape: (batch_size, global_seq_len, n_heads, dim_per_head)
        side_key_states = shape(self.k(global_inputs))
        side_value_states = shape(self.v(global_inputs))

        # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
        query_states = _split_into_blocks(query_states, self.block_len, dim=1)
        key_states = _split_into_blocks(key_states, self.block_len, dim=1)
        value_states = _split_into_blocks(value_states, self.block_len, dim=1)

        # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
        key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
        value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)

        # Tile side inputs across local key/value blocks
        # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
        reps = [1] * (side_key_states.ndim + 1)
        reps[1] = key_states.shape[1]
        side_key_states = side_key_states.unsqueeze(1).repeat(reps)
        side_value_states = side_value_states.unsqueeze(1).repeat(reps)

        # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
        # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
        key_states = torch.cat([key_states, side_key_states], dim=2)
        value_states = torch.cat([value_states, side_value_states], dim=2)

        # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len + global_seq_len)
        scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)

        if mask is not None:
            # We need to adjust position bias shape to be sum with mask
            local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device)
            # Replace masked positions with -1e10 (according to the original implementation)
            local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10)
        else:
            local_attention_mask = None

        if position_bias is None:
            # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, 1, self.n_heads, self.block_len, 3 * self.block_len),
                    device=scores.device,
                    dtype=scores.dtype,
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(self.block_len)

            if local_attention_mask is not None:
                # (batch_size, 1, n_heads, block_len, 3 * block_len)
                position_bias = position_bias + local_attention_mask.transpose(1, 2)
            position_bias = position_bias.type(scores.dtype)

            # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len)
            if mask is None:
                mask = torch.ones(batch_size, seq_length, device=hidden_states.device)
            # (batch_size, num_heads, seq_len, global_seq_len)
            side_position_bias = self.compute_side_bias(mask, global_segment_ids)
            # (batch_size, num_blocks, num_heads, block_len, global_seq_len)
            side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2)
            side_position_bias = side_position_bias.type(scores.dtype).to(scores.device)

            # torch.cat does not broadcast: without a local mask `position_bias` keeps leading dims (1, 1, ...)
            # while the side bias is per batch/block (and may have a single head), so align leading dims first.
            leading = torch.broadcast_shapes(position_bias.shape[:-1], side_position_bias.shape[:-1])
            # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len)
            position_bias = torch.cat(
                [
                    position_bias.expand(*leading, position_bias.shape[-1]),
                    side_position_bias.expand(*leading, side_position_bias.shape[-1]),
                ],
                dim=-1,
            )

        scores += position_bias
        # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_weights = attn_weights.type(value_states.dtype)
        attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
        # Drop any positions beyond the original sequence length introduced by block splitting.
        attn_output = attn_output[:, :seq_length, :]
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)
        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
- # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
class LongT5LayerSelfAttention(nn.Module):
    """Pre-norm self-attention sub-layer with dropout and a residual connection.

    `LongT5Block` uses this class as its first sub-layer when `config.is_decoder` is true.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
        super().__init__()
        # `layer_idx` is forwarded to `LongT5Attention` (presumably for cache indexing — TODO confirm).
        self.SelfAttention = LongT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,  # accepted but not used
    ):
        """Apply layer norm, self-attention, dropout and the residual add.

        Returns:
            `(hidden_states, *rest)` where `rest` is everything the attention module returned after its
            first element (presumably the position bias and, optionally, the attention weights).
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # The residual uses the un-normalised input (pre-norm layout).
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs
class LongT5LayerLocalSelfAttention(nn.Module):
    """Encoder sub-layer: layer norm -> `LongT5LocalAttention` -> dropout, added back onto the input."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
        super().__init__()
        # `layer_idx` is accepted for signature parity with the other self-attention layers; it is unused here.
        self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        **kwargs: Any,  # to accept past_key_values and use_cache kwargs
    ):
        """Return `(new_hidden_states, *extras)`, where `extras` are the attention module's trailing outputs."""
        attn_result = self.LocalSelfAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        attn_out, *extras = attn_result
        residual_sum = hidden_states + self.dropout(attn_out)
        return (residual_sum, *extras)
class LongT5LayerTransientGlobalSelfAttention(nn.Module):
    """Encoder sub-layer: layer norm -> `LongT5TransientGlobalAttention` -> dropout, added back onto the input."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
        super().__init__()
        # `layer_idx` is accepted for signature parity with the other self-attention layers; it is unused here.
        self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention(
            config, has_relative_attention_bias=has_relative_attention_bias
        )
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        **kwargs: Any,  # to accept past_key_values and use_cache kwargs
    ):
        """Return `(new_hidden_states, *extras)`, where `extras` are the attention module's trailing outputs."""
        attn_result = self.TransientGlobalSelfAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        attn_out, *extras = attn_result
        residual_sum = hidden_states + self.dropout(attn_out)
        return (residual_sum, *extras)
- # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5
class LongT5LayerCrossAttention(nn.Module):
    """Pre-norm encoder-decoder (cross) attention sub-layer with dropout and a residual connection.

    The wrapped `LongT5Attention` is always built without relative attention bias.
    """

    def __init__(self, config, layer_idx: int | None = None):
        super().__init__()
        self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        past_key_values=None,
        output_attentions=False,
        **kwargs,  # accepted but not used
    ):
        """Attend from `hidden_states` (after layer norm) to `key_value_states` and add the result back.

        Returns:
            `(layer_output, *rest)` where `rest` is everything the attention module returned after its
            first element (presumably the position bias and, optionally, the attention weights).
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
        )
        # The residual uses the un-normalised input (pre-norm layout).
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs
- class LongT5Block(GradientCheckpointingLayer):
- def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
- super().__init__()
- self.is_decoder = config.is_decoder
- if config.is_decoder:
- attention_layer = LongT5LayerSelfAttention
- elif config.encoder_attention_type == "local":
- attention_layer = LongT5LayerLocalSelfAttention
- elif config.encoder_attention_type == "transient-global":
- attention_layer = LongT5LayerTransientGlobalSelfAttention
- else:
- raise ValueError(
- "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
- f"but got {config.encoder_attention_type}."
- )
- self.layer = nn.ModuleList()
- self.layer.append(
- attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
- )
- if self.is_decoder:
- self.layer.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx))
- self.layer.append(LongT5LayerFF(config))
- def forward(
- self,
- hidden_states,
- attention_mask=None,
- position_bias=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- encoder_decoder_position_bias=None,
- past_key_values=None,
- use_cache=False,
- output_attentions=False,
- return_dict=True,
- **kwargs,
- ):
- self_attention_outputs = self.layer[0](
- hidden_states,
- attention_mask=attention_mask,
- position_bias=position_bias,
- past_key_values=past_key_values,
- use_cache=use_cache,
- output_attentions=output_attentions,
- )
- hidden_states = self_attention_outputs[0]
- attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
- # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
- if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
- do_cross_attention = self.is_decoder and encoder_hidden_states is not None
- if do_cross_attention:
- cross_attention_outputs = self.layer[1](
- hidden_states,
- key_value_states=encoder_hidden_states,
- attention_mask=encoder_attention_mask,
- position_bias=encoder_decoder_position_bias,
- past_key_values=past_key_values,
- output_attentions=output_attentions,
- )
- hidden_states = cross_attention_outputs[0]
- # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
- if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
- # Keep cross-attention outputs and relative position weights
- attention_outputs = attention_outputs + cross_attention_outputs[1:]
- # Apply Feed Forward layer
- hidden_states = self.layer[-1](hidden_states)
- # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
- if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
- return (
- (hidden_states,) + attention_outputs
- ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
@auto_docstring
class LongT5PreTrainedModel(PreTrainedModel):
    config: LongT5Config
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LongT5Block"]
    _can_compile_fullgraph = False  # TODO: @raushan more involved due to local/global attn

    @property
    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights of `module` in place.

        Standard deviations are `config.initializer_factor` times the inverse square root of the relevant
        fan-in. The order of the random-init calls is significant (it fixes the RNG stream) and is kept stable.
        """
        factor = self.config.initializer_factor  # Used for testing weights initialization

        def normal_by_fan_in(weight, fan_in):
            init.normal_(weight, mean=0.0, std=factor * (fan_in**-0.5))

        def init_dense(linear, fan_in):
            # Weight first, then its optional bias.
            normal_by_fan_in(linear.weight, fan_in)
            if getattr(linear, "bias", None) is not None:
                init.zeros_(linear.bias)

        if isinstance(module, LongT5LayerNorm):
            init.constant_(module.weight, factor * 1.0)
            return

        if isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)):
            init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0)
            return

        if isinstance(module, LongT5DenseActDense):
            init_dense(module.wi, self.config.d_model)
            init_dense(module.wo, self.config.d_ff)
            return

        if isinstance(module, LongT5DenseGatedActDense):
            init_dense(module.wi_0, self.config.d_model)
            init_dense(module.wi_1, self.config.d_model)
            init_dense(module.wo, self.config.d_ff)
            return

        if isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)):
            d_model = self.config.d_model
            d_kv = self.config.d_kv
            n_heads = self.config.num_heads
            normal_by_fan_in(module.q.weight, d_model * d_kv)
            normal_by_fan_in(module.k.weight, d_model)
            normal_by_fan_in(module.v.weight, d_model)
            normal_by_fan_in(module.o.weight, n_heads * d_kv)
            if module.has_relative_attention_bias:
                normal_by_fan_in(module.relative_attention_bias.weight, d_model)
                if isinstance(module, LongT5TransientGlobalAttention):
                    normal_by_fan_in(module.global_relative_attention_bias.weight, d_model)

    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the pad_token_id. "
                "See LongT5 docs for more information."
            )

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
class LongT5Stack(LongT5PreTrainedModel):
    """A stack of `LongT5Block`s acting as encoder or decoder depending on `config.is_decoder`.

    Only the first block owns relative attention biases; its position bias is reused by the later blocks.
    """

    def __init__(self, config):
        super().__init__(config)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.is_decoder = config.is_decoder

        self.local_radius = config.local_radius
        self.block_len = self.local_radius + 1

        self.block = nn.ModuleList(
            [
                LongT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
                for i in range(config.num_layers)
            ]
        )
        self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Embed the inputs, run every block, then apply the final layer norm and dropout.

        Exactly one of `input_ids` / `inputs_embeds` must be given. `use_cache`, `output_attentions`,
        `output_hidden_states` and `return_dict` default to the config values when `None`.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` if `return_dict`, otherwise a tuple of the non-`None`
            values among (last hidden state, cache, all hidden states, self-attentions, cross-attentions).

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if self.is_decoder:
            if use_cache and past_key_values is None:
                if self.config.is_encoder_decoder:
                    past_key_values = EncoderDecoderCache(
                        DynamicCache(config=self.config), DynamicCache(config=self.config)
                    )
                else:
                    past_key_values = DynamicCache(config=self.config)
        else:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.is_decoder:
            causal_mask = create_causal_mask(
                config=self.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
        # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
        elif self.config.encoder_attention_type == "local":
            causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
        else:  # we need to use both local attention mask and standard extended mask for transient-global attention
            causal_mask = attention_mask

        # Cross-attention only runs in the decoder and only when encoder states are supplied.
        has_cross_attention = self.is_decoder and encoder_hidden_states is not None

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if has_cross_attention:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for layer_module in self.block:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                causal_mask,
                position_bias,
                encoder_hidden_states,
                encoder_extended_attention_mask,
                encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                return_dict=return_dict,
            )

            # layer_outputs is a tuple with:
            # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            hidden_states = layer_outputs[0]

            # We share the position biases between the layers - the first layer stores them
            position_bias = layer_outputs[1]
            if has_cross_attention:
                encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
                # Cross-attention weights only exist when the block actually ran cross-attention; indexing [4]
                # for a decoder called without encoder states would raise IndexError.
                if has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
# Bare LongT5 encoder-decoder: returns the decoder's last hidden states (plus optional extras)
# with no task-specific head. A single token-embedding matrix (`self.shared`) is tied to both
# stacks' `embed_tokens` through `_tied_weights_keys`.
@auto_docstring
class LongT5Model(LongT5PreTrainedModel):
    # Checkpoint key pattern(s) tolerated as "unexpected" when loading weights: the decoder's first
    # cross-attention relative-position bias.
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    # Both stacks' token embeddings are tied to the shared embedding matrix.
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
    }

    def __init__(self, config: LongT5Config):
        """Create the shared embedding and the encoder/decoder `LongT5Stack`s described by `config`."""
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Each stack gets its own deep copy so per-stack overrides do not leak into `config`.
        # The encoder is never a decoder and never uses the KV cache.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        self.encoder = LongT5Stack(encoder_config)

        # The decoder depth comes from `num_decoder_layers`, which may differ from the encoder depth.
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = LongT5Stack(decoder_config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the shared token-embedding module."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared token embedding and hand the same module to the encoder and decoder."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.Tensor | None = None,
        decoder_inputs_embeds: torch.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
            you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [LONGT5
            Training](./longt5#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
            Training](./longt5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LongT5Model

        >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
        >>> model = LongT5Model.from_pretrained("google/long-t5-local-base")

        >>> # Let's try a very long encoder input.
        >>> input_ids = tokenizer(
        ...     100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1

        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        # NOTE(review): `output_attentions` / `output_hidden_states` are forwarded as-is (possibly None);
        # presumably the stacks resolve None against the config — confirm in `LongT5Stack.forward`.

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Pre-computed encoder outputs were given in another form (e.g. a tuple); wrap them so the
            # attribute access in the returned Seq2SeqModelOutput works.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode. The decoder cross-attends over the encoder states, and the encoder padding mask
        # (`attention_mask`) is reused as the cross-attention mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Tuple layout: decoder outputs first, then the encoder outputs.
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
# LongT5 encoder-decoder with a vocabulary projection (`lm_head`) on top of the decoder, used for
# seq2seq training (via `labels`) and generation (via `GenerationMixin`).
@auto_docstring(
    custom_intro="""
    LONGT5 Model with a `language modeling` head on top.
    """
)
class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
    # Checkpoint key pattern(s) tolerated as "unexpected" when loading weights: the decoder's first
    # cross-attention relative-position bias.
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    # Both stacks' token embeddings and the LM head are tied to the shared embedding matrix.
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
        "lm_head.weight": "shared.weight",
    }

    def __init__(self, config: LongT5Config):
        """Create the shared embedding, the encoder/decoder `LongT5Stack`s and the bias-free LM head."""
        super().__init__(config)
        # Kept for the d_model ** -0.5 output rescaling applied in `forward`.
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Each stack gets its own deep copy so per-stack overrides do not leak into `config`.
        # The encoder is never a decoder and never uses the KV cache.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        self.encoder = LongT5Stack(encoder_config)

        # The decoder depth comes from `num_decoder_layers`, which may differ from the encoder depth.
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = LongT5Stack(decoder_config)

        # Projects decoder states (d_model) onto vocabulary logits.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the shared token-embedding module."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared token embedding and hand the same module to the encoder and decoder."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
            you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [LONGT5
            Training](./longt5#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
            Training](./longt5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        labels (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Target token ids for computing the language modeling (cross-entropy) loss. Indices should be in
            `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only
            computed for labels in `[0, ..., config.vocab_size - 1]`. When neither `decoder_input_ids` nor
            `decoder_inputs_embeds` is given, the decoder inputs are built by shifting `labels` to the right.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
        >>> model = LongT5ForConditionalGeneration.from_pretrained(
        ...     "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
        ... )

        >>> # Let's try a very long input.
        >>> inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
        >>> input_ids = inputs.input_ids

        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        abstractthe aim of this article is to provide an overview of the literature on the role of dog
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Pre-computed encoder outputs were given in another form (e.g. a tuple); wrap them so the
            # attribute access in the returned Seq2SeqLMOutput works.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode. The decoder cross-attends over the encoder states, and the encoder padding mask
        # (`attention_mask`) is reused as the cross-attention mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # When the LM head shares its weights with the input embeddings, the decoder output is
            # scaled by d_model ** -0.5 before the vocabulary projection.
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # Put the labels on the logits' device so the loss can be computed wherever they were created.
            labels = labels.to(lm_logits.device)
            # Flatten batch and time: logits -> (N, vocab_size), labels -> (N,). `reshape` (unlike `view`)
            # also accepts non-contiguous label tensors, e.g. slices of a larger batch.
            loss = loss_fct(lm_logits.reshape(-1, lm_logits.size(-1)), labels.reshape(-1))

        if not return_dict:
            # Tuple layout: logits, remaining decoder outputs, then the encoder outputs (loss first if computed).
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Build decoder input ids from `labels` by shifting them one position to the right."""
        return self._shift_right(labels)
# Encoder-only LongT5: runs the encoder stack and returns its outputs directly. The shared
# embedding is tied to the encoder's `embed_tokens` through `_tied_weights_keys`.
@auto_docstring
class LongT5EncoderModel(LongT5PreTrainedModel):
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
    }
    # Checkpoint keys matching "decoder" (e.g. from a full encoder-decoder checkpoint) are tolerated
    # as unexpected, since this model has no decoder.
    _keys_to_ignore_on_load_unexpected = [r"decoder"]

    def __init__(self, config: LongT5Config):
        """Create the shared embedding and a single encoder `LongT5Stack`."""
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Use a private copy of the config so disabling the cache does not mutate the caller's config.
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        self.encoder = LongT5Stack(encoder_config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the shared token-embedding module."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared token embedding and hand the same module to the encoder."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor] | BaseModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
            you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more on how to prepare `input_ids` for pretraining take a look at [LONGT5
            Training](./longt5#training).

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LongT5EncoderModel

        >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
        >>> model = LongT5EncoderModel.from_pretrained("google/long-t5-local-base")
        >>> input_ids = tokenizer(
        ...     100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # The encoder stack's output (tuple or model-output object, per `return_dict`) is returned unchanged.
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs
# Public API of this module (names exported by `from ... import *`).
__all__ = [
    "LongT5EncoderModel",
    "LongT5ForConditionalGeneration",
    "LongT5Model",
    "LongT5PreTrainedModel",
]
|