| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619 |
- # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch T5 model."""
- import copy
- import math
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...generation import GenerationMixin
- from ...masking_utils import create_bidirectional_mask, create_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- Seq2SeqQuestionAnsweringModelOutput,
- Seq2SeqSequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...utils import DUMMY_INPUTS, DUMMY_MASK, auto_docstring, logging, torch_compilable_check
- from .configuration_t5 import T5Config
# Module-level logger: reports whether apex's FusedRMSNorm replaced T5LayerNorm and emits the one-time
# warnings issued by the attention and stack modules below.
logger = logging.get_logger(__name__)
class T5LayerNorm(nn.Module):
    """T5-style RMS layer normalization.

    Unlike `nn.LayerNorm`, the mean is not subtracted and there is no bias: every feature vector along the last
    dimension is divided by its root-mean-square and then multiplied by a learned per-feature scale (initialised to
    ones).
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.

        Args:
            hidden_size: length of the normalized (last) dimension, i.e. the size of the learned scale vector.
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Root Mean Square Layer Normalization (https://huggingface.co/papers/1910.07467): the statistic is the
        # plain mean of squares (no centering). It is always accumulated in float32 so that half-precision inputs
        # do not lose precision in the reduction.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = hidden_states * inv_rms

        # When the learned scale is stored in half precision, bring the activations back down to match it.
        half_precision_dtypes = (torch.float16, torch.bfloat16)
        if self.weight.dtype in half_precision_dtypes:
            normalized = normalized.to(self.weight.dtype)
        return self.weight * normalized
# Prefer apex's fused RMSNorm kernel when it is installed and importable; otherwise keep the pure-PyTorch
# T5LayerNorm defined above. Only the import itself sits inside the `try`.
try:
    from apex.normalization import FusedRMSNorm
except ImportError:
    # apex is not installed: keep using the normal T5LayerNorm
    pass
except Exception:
    # apex is present but failed to import for another reason
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
else:
    T5LayerNorm = FusedRMSNorm
    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
class T5DenseActDense(nn.Module):
    """Position-wise feed-forward network: ``wo(dropout(act(wi(x))))`` with bias-free projections.

    ``wi`` maps ``d_model -> d_ff`` and ``wo`` maps ``d_ff -> d_model``. The activation is looked up in ``ACT2FN``
    under ``config.dense_act_fn``.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        intermediate = self.dropout(self.act(self.wi(hidden_states)))

        # `wo` may be held in a different dtype than the activations (it is listed in `_keep_in_fp32_modules`),
        # so align the input with it -- unless the weight is not a plain tensor or is stored as int8 (quantized).
        out_weight = self.wo.weight
        needs_cast = (
            isinstance(out_weight, torch.Tensor)
            and intermediate.dtype != out_weight.dtype
            and out_weight.dtype != torch.int8
        )
        if needs_cast:
            intermediate = intermediate.to(out_weight.dtype)
        return self.wo(intermediate)
class T5DenseGatedActDense(nn.Module):
    """Gated feed-forward network: ``wo(dropout(act(wi_0(x)) * wi_1(x)))`` with bias-free projections.

    ``wi_0`` (gate branch, passed through the activation) and ``wi_1`` (linear branch) both map
    ``d_model -> d_ff``; ``wo`` maps back ``d_ff -> d_model``.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        gate = self.act(self.wi_0(hidden_states))
        gated = self.dropout(gate * self.wi_1(hidden_states))

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # The int8 guard covers users who force `_keep_in_fp32_modules` to be `None` with quantized weights.
        out_weight = self.wo.weight
        needs_cast = (
            isinstance(out_weight, torch.Tensor)
            and gated.dtype != out_weight.dtype
            and out_weight.dtype != torch.int8
        )
        if needs_cast:
            gated = gated.to(out_weight.dtype)
        return self.wo(gated)
class T5LayerFF(nn.Module):
    """Pre-norm residual feed-forward sub-layer: ``x + dropout(FF(layer_norm(x)))``.

    ``FF`` is the gated variant when ``config.is_gated_act`` is set, the plain one otherwise. The attribute keeps
    the historical name ``DenseReluDense`` regardless of the activation used.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        feed_forward_cls = T5DenseGatedActDense if config.is_gated_act else T5DenseActDense
        self.DenseReluDense = feed_forward_cls(config)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        residual = hidden_states
        ff_output = self.DenseReluDense(self.layer_norm(hidden_states))
        return residual + self.dropout(ff_output)
class T5Attention(nn.Module):
    """Multi-head attention with T5's bucketed relative position bias.

    Serves as encoder self-attention, decoder self-attention (``config.is_decoder``) and decoder cross-attention
    (when ``key_value_states`` is passed to ``forward``). Raw dot-product scores are not rescaled by
    ``1/sqrt(d_kv)``. Instead an additive ``position_bias`` (relative bias plus attention mask) is added before
    the softmax.
    """

    def __init__(
        self,
        config: T5Config,
        has_relative_attention_bias=False,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.is_decoder = config.is_decoder
        # Only layers created with this flag own a `relative_attention_bias` embedding; others receive the bias
        # computed by an earlier layer through `position_bias`.
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        # Index of this layer inside the cache objects used by `forward`.
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            # One learned scalar per (bucket, head).
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.gradient_checkpointing = False

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an integer Tensor (``compute_bias`` passes ``torch.long``)
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            an integer Tensor (int64 for int64 input) with the same shape as relative_position, containing values in
            the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Split the buckets in two halves: the upper half is reserved for positive offsets.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Causal case: positive offsets (future tokens) collapse to 0.
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.
        # (log(0) for position 0 yields -inf here, but such entries are discarded by `torch.where` below.)
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
        """Compute binned relative position bias.

        Query ``i`` is placed at absolute position ``past_seen_tokens + i`` and key ``j`` at position ``j``.

        Returns:
            Tensor of shape ``(1, num_heads, query_length, key_length)``.
        """
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_values=None,
        output_attentions=False,
        **kwargs,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Args:
            hidden_states: query-side input of shape ``(batch_size, seq_length, dim)``.
            mask: optional additive mask. When ``position_bias`` is computed here, the mask is sliced to the key
                length and added into the returned bias.
            key_value_states: encoder states for cross-attention; ``None`` for self-attention.
            position_bias: additive bias (already including the mask) produced by an earlier layer. It is computed
                here when ``None``.
            past_key_values: optional cache object. When it is an ``EncoderDecoderCache``, its self- or
                cross-attention sub-cache is used depending on the attention type.
            output_attentions: whether to also return the attention probabilities.
            **kwargs: accepted and ignored (e.g. ``use_cache`` forwarded by ``T5LayerSelfAttention``).

        Returns:
            ``(attn_output, position_bias)``, plus ``attn_weights`` when ``output_attentions`` is true.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.key_value_proj_dim)

        past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
        # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
        past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        # (batch_size, n_heads, seq_length, key_value_proj_dim)
        query_states = self.q(hidden_states).view(hidden_shape).transpose(1, 2)

        # Check if an encoder-decoder cache is being used. Otherwise we'll get `DynamicCache`
        is_updated = False
        if isinstance(past_key_values, EncoderDecoderCache):
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_values = past_key_values.cross_attention_cache
            else:
                curr_past_key_values = past_key_values.self_attention_cache
        else:
            # NOTE(review): with a plain cache, cross-attention k/v are recomputed and written to the same cache
            # under `layer_idx` as self-attention -- confirm callers always pass an EncoderDecoderCache there.
            curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            kv_shape = (*current_states.shape[:-1], -1, self.key_value_proj_dim)
            key_states = self.k(current_states).view(kv_shape).transpose(1, 2)
            value_states = self.v(current_states).view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            if not self.has_relative_attention_bias:
                # Layers without their own relative bias start from an all-zero bias.
                position_bias = torch.zeros(
                    (1, query_states.shape[1], input_shape[1], key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    input_shape[1], key_length, device=scores.device, past_seen_tokens=past_seen_tokens
                )

            # Fold the additive mask into the bias once. Layers that receive this bias back do not re-add the mask.
            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        scores += position_bias
        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)

        # Merge heads back: (batch_size, seq_length, n_heads * key_value_proj_dim) -> output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
class T5LayerSelfAttention(nn.Module):
    """Pre-norm residual self-attention sub-layer: ``x + dropout(SelfAttention(layer_norm(x)))``."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
        super().__init__()
        self.SelfAttention = T5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        """Run self-attention on the normalized input and add it back to the residual stream.

        Returns:
            Tuple whose first element is the updated hidden states, followed by everything the attention module
            returned after its output (position bias and, if requested, attention weights).
        """
        attn_out, *extras = self.SelfAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            position_bias=position_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        return (hidden_states + self.dropout(attn_out), *extras)
class T5LayerCrossAttention(nn.Module):
    """Pre-norm residual encoder-decoder attention sub-layer.

    Computes ``x + dropout(EncDecAttention(layer_norm(x), key_value_states))``. The attention module never owns a
    relative position bias.
    """

    def __init__(self, config, layer_idx: int | None = None):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        past_key_values=None,
        output_attentions=False,
        **kwargs,
    ):
        """Attend from the decoder states to ``key_value_states`` and add the result to the residual stream.

        Returns:
            Tuple whose first element is the updated hidden states, followed by everything the attention module
            returned after its output (position bias and, if requested, attention weights).
        """
        normed = self.layer_norm(hidden_states)
        cross_out, *extras = self.EncDecAttention(
            normed,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(cross_out)
        return (layer_output, *extras)
class T5Block(GradientCheckpointingLayer):
    """One T5 layer: self-attention, cross-attention (decoder only) and feed-forward.

    ``self.layer`` holds the sub-layers in order:
    ``[T5LayerSelfAttention, (T5LayerCrossAttention if decoder,) T5LayerFF]``.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
        super().__init__()
        self.is_decoder = config.is_decoder
        sublayers = [
            T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
        ]
        if self.is_decoder:
            sublayers.append(T5LayerCrossAttention(config, layer_idx=layer_idx))
        sublayers.append(T5LayerFF(config))
        self.layer = nn.ModuleList(sublayers)

    @staticmethod
    def _clamp_inf_fp16(hidden_states):
        """Clamp float16 activations into a finite range to enable fp16 training; other dtypes pass through.

        The bound is the float16 maximum, lowered by 1000 when any element is already ``inf``.
        """
        if hidden_states.dtype != torch.float16:
            return hidden_states
        fp16_max = torch.finfo(hidden_states.dtype).max
        clamp_value = torch.where(torch.isinf(hidden_states).any(), fp16_max - 1000, fp16_max)
        return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        **kwargs,
    ):
        """Apply the block's sub-layers in sequence.

        Cross-attention only runs for decoder blocks that receive ``encoder_hidden_states``. ``return_dict`` and
        ``**kwargs`` are accepted but not used here.

        Returns:
            hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position
            bias), (cross-attention weights) -- the bracketed entries only when produced.
        """
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = self._clamp_inf_fp16(self_attention_outputs[0])
        # Keep self-attention outputs and relative position weights
        attention_outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_inf_fp16(cross_attention_outputs[0])
            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[1:]

        # Feed-forward is always the last sub-layer.
        hidden_states = self._clamp_inf_fp16(self.layer[-1](hidden_states))

        return (hidden_states,) + attention_outputs
class T5ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Computes ``out_proj(dropout(tanh(dense(dropout(x)))))``, mapping ``d_model`` features to ``config.num_labels``
    logits. The same dropout module is applied before and after the hidden projection.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.out_proj = nn.Linear(config.d_model, config.num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        features = self.dropout(hidden_states)
        features = torch.tanh(self.dense(features))
        features = self.dropout(features)
        return self.out_proj(features)
@auto_docstring
class T5PreTrainedModel(PreTrainedModel):
    # Shared base for the T5 models: config class, weight prefix, device-map / precision hints and weight init.
    config: T5Config
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _can_compile_fullgraph = True
    # A T5Block must not be split across devices.
    _no_split_modules = ["T5Block"]
    # Output projections of the feed-forward layers are kept in float32 (see the dtype casts in the FF modules).
    _keep_in_fp32_modules = ["wo"]

    @property
    def dummy_inputs(self):
        """Small fixed batch built from the library's DUMMY_INPUTS / DUMMY_MASK constants."""
        ids = torch.tensor(DUMMY_INPUTS)
        mask = torch.tensor(DUMMY_MASK)
        return {
            "decoder_input_ids": ids,
            "input_ids": ids,
            "decoder_attention_mask": mask,
        }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization

        def zero_bias(linear):
            # Layers built with `bias=False` expose `bias` as None; only real biases are zeroed.
            if getattr(linear, "bias", None) is not None:
                init.zeros_(linear.bias)

        if isinstance(module, T5LayerNorm):
            init.constant_(module.weight, factor * 1.0)
        elif isinstance(
            module,
            (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering),
        ):
            init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
            # A tied lm_head shares `shared.weight`, so it is only initialised separately when untied.
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0)
            if hasattr(module, "qa_outputs"):
                init.normal_(module.qa_outputs.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
                init.zeros_(module.qa_outputs.bias)
        elif isinstance(module, T5ForTokenClassification):
            if hasattr(module, "classifier"):
                init.normal_(module.classifier.weight, mean=0.0, std=factor * 1.0)
                init.zeros_(module.classifier.bias)
        elif isinstance(module, T5ClassificationHead):
            head_std = factor * ((self.config.d_model) ** -0.5)
            for linear in (module.dense, module.out_proj):
                init.normal_(linear.weight, mean=0.0, std=head_std)
                zero_bias(linear)
        elif isinstance(module, T5DenseActDense):
            in_std = factor * ((self.config.d_model) ** -0.5)
            out_std = factor * ((self.config.d_ff) ** -0.5)
            for linear, std in ((module.wi, in_std), (module.wo, out_std)):
                init.normal_(linear.weight, mean=0.0, std=std)
                zero_bias(linear)
        elif isinstance(module, T5DenseGatedActDense):
            in_std = factor * ((self.config.d_model) ** -0.5)
            out_std = factor * ((self.config.d_ff) ** -0.5)
            for linear, std in ((module.wi_0, in_std), (module.wi_1, in_std), (module.wo, out_std)):
                init.normal_(linear.weight, mean=0.0, std=std)
                zero_bias(linear)
        elif isinstance(module, T5Attention):
            d_model = self.config.d_model
            d_kv = self.config.d_kv
            n_heads = self.config.num_heads
            model_std = factor * (d_model**-0.5)
            # q's std also folds in d_kv -- presumably compensating for T5Attention not scaling scores by
            # 1/sqrt(d_kv); confirm against the Mesh TensorFlow reference before changing.
            init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * d_kv) ** -0.5))
            init.normal_(module.k.weight, mean=0.0, std=model_std)
            init.normal_(module.v.weight, mean=0.0, std=model_std)
            init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * d_kv) ** -0.5))
            if module.has_relative_attention_bias:
                init.normal_(module.relative_attention_bias.weight, mean=0.0, std=model_std)

    def _shift_right(self, input_ids):
        """Build decoder inputs by shifting ``input_ids`` one step right along the last axis.

        Position 0 becomes ``config.decoder_start_token_id``, the last input token is dropped, and any ``-100``
        (ignored-label marker) in the result is replaced by ``config.pad_token_id``.

        Raises:
            ValueError: if ``decoder_start_token_id`` or ``pad_token_id`` is not configured.
        """
        start_token_id = self.config.decoder_start_token_id
        pad_id = self.config.pad_token_id

        if start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )

        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 1:] = input_ids[..., :-1].clone()
        shifted[..., 0] = start_token_id

        if pad_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted.masked_fill_(shifted == -100, pad_id)
        return shifted
- class T5Stack(T5PreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
- self.is_decoder = config.is_decoder
- self.block = nn.ModuleList(
- [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
- )
- self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
- self.dropout = nn.Dropout(config.dropout_rate)
- # Initialize weights and apply final processing
- self.post_init()
- self.gradient_checkpointing = False
- def set_input_embeddings(self, new_embeddings):
- self.embed_tokens = new_embeddings
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    inputs_embeds=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    """Run the stack of T5 blocks over token ids or pre-computed embeddings.

    The same code serves as encoder (``self.is_decoder`` False) and decoder (``self.is_decoder``
    True); only the decoder uses a cache and cross-attends to ``encoder_hidden_states``.

    Args:
        input_ids: Token indices; mutually exclusive with ``inputs_embeds``. Reshaped to
            ``(-1, input_ids.size(-1))`` before being embedded with ``self.embed_tokens``.
        attention_mask: Padding mask, converted into a causal mask (decoder) or a
            bidirectional mask (encoder) by the mask helpers.
        encoder_hidden_states: Encoder output used for cross-attention (decoder only).
        encoder_attention_mask: Padding mask for ``encoder_hidden_states`` (decoder only).
        inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
        past_key_values: Cache object. For the decoder it is created on demand when
            ``use_cache`` is set; for the encoder it is always replaced by ``None``.
        use_cache, output_attentions, output_hidden_states, return_dict: When ``None``,
            fall back to the attribute of the same name on ``self.config``.
        **kwargs: Accepted and ignored by this method.

    Returns:
        ``BaseModelOutputWithPastAndCrossAttentions``; when ``return_dict`` is falsy, a tuple
        of the non-``None`` values among (last hidden state, cache, all hidden states,
        self-attention weights, cross-attention weights), in that order.

    Raises:
        ValueError: if both or neither of ``input_ids``/``inputs_embeds`` are given, if no
            token embedding module is set, or if ``use_cache=True`` on a non-decoder stack.
    """
    # Unset flags fall back to the stack's config.
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    # Exactly one of `input_ids` / `inputs_embeds` must be provided.
    if input_ids is not None and inputs_embeds is not None:
        err_msg_prefix = "decoder_" if self.is_decoder else ""
        raise ValueError(
            f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
        )
    elif input_ids is not None:
        input_shape = input_ids.size()
        # Flatten any leading dimensions into a single batch dimension.
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        err_msg_prefix = "decoder_" if self.is_decoder else ""
        raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

    # Caching is switched off when training with gradient checkpointing.
    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    if inputs_embeds is None:
        if self.embed_tokens is None:
            raise ValueError("You have to initialize the model with valid token embeddings")
        inputs_embeds = self.embed_tokens(input_ids)

    # The unpacking only succeeds for a 2-D input shape; neither name is used later in this method.
    batch_size, seq_length = input_shape

    if use_cache is True:
        if not self.is_decoder:
            raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

    if self.is_decoder:
        # Create a cache on demand. For encoder-decoder configs it is an EncoderDecoderCache
        # wrapping two DynamicCaches; only its `self_attention_cache` feeds the causal mask below.
        if use_cache and past_key_values is None:
            if self.config.is_encoder_decoder:
                past_key_values = EncoderDecoderCache(
                    DynamicCache(config=self.config), DynamicCache(config=self.config)
                )
            else:
                past_key_values = DynamicCache(config=self.config)
    elif not self.is_decoder:
        # do not pass cache object down the line for encoder stack
        # it messes indexing later in decoder-stack because cache object is modified in-place
        past_key_values = None

    # Decoder: causal self-attention mask (built with the self-attention cache, if any).
    # Encoder: bidirectional mask derived from the padding mask.
    if self.config.is_decoder:
        attention_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values.self_attention_cache
            if isinstance(past_key_values, EncoderDecoderCache)
            else past_key_values,
        )
    else:
        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

    # Cross-attention mask over the encoder sequence; only built for a decoder given encoder states.
    encoder_extended_attention_mask = None
    if self.is_decoder and encoder_hidden_states is not None:
        encoder_extended_attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=encoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
        )

    # Accumulators stay None when the corresponding output was not requested.
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and self.is_decoder) else None
    # Filled from the first layer's outputs and passed to every subsequent layer.
    position_bias = None
    encoder_decoder_position_bias = None

    hidden_states = self.dropout(inputs_embeds)

    for layer_module in self.block:
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        layer_outputs = layer_module(
            hidden_states,
            attention_mask,
            position_bias,
            encoder_hidden_states,
            encoder_extended_attention_mask,
            encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        hidden_states = layer_outputs[0]

        # We share the position biases between the layers - the first layer store them
        # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights)
        position_bias = layer_outputs[1]
        if self.is_decoder and encoder_hidden_states is not None:
            # When self-attention weights are returned they occupy slot 2, shifting this bias to slot 3.
            encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[2],)
            if self.is_decoder:
                # NOTE(review): assumes slot 4 exists for every decoder layer, including decoder use
                # without `encoder_hidden_states` — confirm against the layer's return layout.
                all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

    hidden_states = self.final_layer_norm(hidden_states)
    hidden_states = self.dropout(hidden_states)

    # Add last layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        # Tuple form drops every output that was not produced.
        return tuple(
            v
            for v in [
                hidden_states,
                past_key_values,
                all_hidden_states,
                all_attentions,
                all_cross_attentions,
            ]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
        hidden_states=all_hidden_states,
        attentions=all_attentions,
        cross_attentions=all_cross_attentions,
    )
@auto_docstring
class T5Model(T5PreTrainedModel):
    # Checkpoint key that is tolerated (ignored) if present when loading weights.
    _keys_to_ignore_on_load_unexpected = [
        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    # Encoder and decoder token embeddings are tied to the shared embedding matrix.
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
    }

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Each stack gets its own deep copy so the flags set below neither leak into
        # `config` nor into the other stack.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        self.encoder = T5Stack(encoder_config)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding shared by the encoder and decoder."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared token embedding and propagate it to both stacks."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.Tensor | None = None,
        decoder_inputs_embeds: torch.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, T5Model

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        >>> model = T5Model.from_pretrained("google-t5/t5-small")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

        >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
        >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
        >>> decoder_input_ids = model._shift_right(decoder_input_ids)

        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Encode if needed (training, first prediction pass); callers may pass precomputed
        # `encoder_outputs` to skip the encoder.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Wrap a user-supplied tuple so the attribute access in the return below works.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode; the encoder padding mask is reused as the cross-attention mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Tuple layout: all decoder outputs followed by all encoder outputs.
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    T5 Model with a `language modeling` head on top.
    """
)
class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
    # Checkpoint key that is tolerated (ignored) if present when loading weights.
    _keys_to_ignore_on_load_unexpected = [
        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    # The LM head and both stacks' token embeddings are tied to the shared embedding matrix.
    _tied_weights_keys = {
        "lm_head.weight": "shared.weight",
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
    }

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Each stack gets its own deep copy so the flags set below neither leak into
        # `config` nor into the other stack.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        self.encoder = T5Stack(encoder_config)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding shared by the encoder and decoder."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared token embedding and propagate it to both stacks."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.BoolTensor | None = None,
        encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor] | Seq2SeqLMOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        labels (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Labels for computing the language modeling (cross-entropy) loss over the decoder's vocabulary logits.
            Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored
            (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`. If neither
            `decoder_input_ids` nor `decoder_inputs_embeds` is given, the decoder inputs are created by shifting
            `labels` one position to the right.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, T5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Encode if needed (training, first prediction pass); during generation callers
        # may pass precomputed `encoder_outputs` to skip the encoder.
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Wrap a user-supplied tuple so the attribute access in the return below works.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode; the encoder padding mask is reused as the cross-attention mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Optional rescale by d_model**-0.5 before the vocabulary projection.
        if self.config.scale_decoder_outputs:
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            # Token-level cross-entropy; positions labelled -100 do not contribute.
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            # Tuple layout: (loss?), logits, remaining decoder outputs, all encoder outputs.
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Build decoder inputs from `labels` by shifting them one position to the right."""
        return self._shift_right(labels)
@auto_docstring
class T5EncoderModel(T5PreTrainedModel):
    # Only the encoder embedding is tied; any decoder weights in a checkpoint are ignored on load.
    _tied_weights_keys = {"encoder.embed_tokens.weight": "shared.weight"}
    _keys_to_ignore_on_load_unexpected = [r"decoder"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # The encoder stack is built from `config` itself (no copy), so these two flags are
        # written onto the caller's config object as well.
        config.use_cache = False
        config.is_encoder_decoder = False
        self.encoder = T5Stack(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module held by this model."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Install `new_embeddings` on this model and forward it to the encoder stack."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor] | BaseModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).

        Example:

        ```python
        >>> from transformers import AutoTokenizer, T5EncoderModel

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict

        # All work is delegated to the encoder stack; its result is handed back unchanged.
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return encoded
@auto_docstring(
    custom_intro="""
    T5 model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """
)
class T5ForSequenceClassification(T5PreTrainedModel):
    # Checkpoint key that is tolerated (ignored) if present when loading weights.
    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.transformer = T5Model(config)
        self.classification_head = T5ClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | Seq2SeqSequenceClassifierOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left. Required: the sentence representation is
            taken from the decoder output at the positions of `config.eos_token_id` in `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss).
            If `config.num_labels > 1` and the labels are integers, a classification loss is computed
            (Cross-Entropy); otherwise a multi-label loss is computed (Binary Cross-Entropy with logits).
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        # No cache is needed when computing a loss.
        if labels is not None:
            use_cache = False

        if input_ids is None and inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
        # decoder_input_ids from input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        # Pooling below reads the <eos> positions from `input_ids`; without it the forward
        # pass would run to completion and then fail on `None.eq(...)`. Fail early instead.
        if input_ids is None:
            raise ValueError(
                f"{self.__class__.__name__} pools the decoder output at the `eos_token_id` positions of "
                "`input_ids`, so `input_ids` must be provided."
            )

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        # Select the decoder states at the <eos> positions; every row must contain the same
        # number of <eos> tokens so the selection can be reshaped back to a batch.
        eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)

        torch_compilable_check(
            torch.unique_consecutive(eos_mask.sum(1)).numel() == 1,
            "All examples must have the same number of <eos> tokens.",
        )
        batch_size, _, hidden_size = sequence_output.shape
        # Keep the state at the last <eos> of each example as the sentence representation.
        sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from num_labels and label dtype; stored on the config.
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
@auto_docstring
class T5ForTokenClassification(T5PreTrainedModel):
    def __init__(self, config: T5Config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # NOTE: T5EncoderModel's constructor sets `use_cache`/`is_encoder_decoder` to False on the
        # config object it receives, which is this model's config as well.
        self.transformer = T5EncoderModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token logits from the (dropped-out) last encoder hidden state.
        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Flatten the encoder's remaining outputs (presumably hidden states and attentions, when
            # requested) next to the logits. The previous `(logits, outputs[2:-1])` nested a tuple
            # inside the output and its slice dropped the hidden states.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
- @auto_docstring
- class T5ForQuestionAnswering(T5PreTrainedModel):
- _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
- _tied_weights_keys = {
- "encoder.embed_tokens.weight": "shared.weight",
- "decoder.embed_tokens.weight": "shared.weight",
- }
- def __init__(self, config: T5Config):
- super().__init__(config)
- self.model_dim = config.d_model
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
- encoder_config = copy.deepcopy(config)
- encoder_config.is_decoder = False
- encoder_config.use_cache = False
- self.encoder = T5Stack(encoder_config)
- decoder_config = copy.deepcopy(config)
- decoder_config.is_decoder = True
- decoder_config.num_layers = config.num_decoder_layers
- self.decoder = T5Stack(decoder_config)
- self.num_labels = config.num_labels
- self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
- # Initialize weights and apply final processing
- self.post_init()
- def get_input_embeddings(self):
- return self.shared
- def set_input_embeddings(self, new_embeddings):
- self.shared = new_embeddings
- self.encoder.set_input_embeddings(new_embeddings)
- self.decoder.set_input_embeddings(new_embeddings)
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- attention_mask: torch.FloatTensor | None = None,
- decoder_input_ids: torch.LongTensor | None = None,
- decoder_attention_mask: torch.BoolTensor | None = None,
- encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
- start_positions: torch.LongTensor | None = None,
- end_positions: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- decoder_inputs_embeds: torch.FloatTensor | None = None,
- use_cache: bool | None = None,
- output_attentions: bool | None = None,
- output_hidden_states: bool | None = None,
- return_dict: bool | None = None,
- **kwargs,
- ) -> tuple[torch.FloatTensor] | Seq2SeqQuestionAnsweringModelOutput:
- r"""
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
- should be able to pad the inputs on both the right and the left.
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for detail.
- [What are input IDs?](../glossary#input-ids)
- To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
- decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
- Indices of decoder input sequence tokens in the vocabulary.
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- [What are decoder input IDs?](../glossary#decoder-input-ids)
- T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
- is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
- To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
- Training](./t5#training).
- decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
- Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
- be used by default.
- """
- return_dict = return_dict if return_dict is not None else self.config.return_dict
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- if start_positions is not None and end_positions is not None:
- use_cache = False
- # Copied from models.bart.modeling_bart.BartModel.forward
- # different to other models, T5 automatically creates decoder_input_ids from
- # input_ids if no decoder_input_ids are provided
- if decoder_input_ids is None and decoder_inputs_embeds is None:
- if input_ids is None:
- raise ValueError(
- "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
- "passed, `input_ids` cannot be `None`. Please pass either "
- "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
- )
- decoder_input_ids = self._shift_right(input_ids)
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.return_dict
- # Encode if needed (training, first prediction pass)
- if encoder_outputs is None:
- encoder_outputs = self.encoder(
- input_ids=input_ids,
- attention_mask=attention_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
- encoder_outputs = BaseModelOutput(
- last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
- )
- hidden_states = encoder_outputs[0]
- # Decode
- decoder_outputs = self.decoder(
- input_ids=decoder_input_ids,
- attention_mask=decoder_attention_mask,
- inputs_embeds=decoder_inputs_embeds,
- past_key_values=None,
- encoder_hidden_states=hidden_states,
- encoder_attention_mask=attention_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- sequence_output = decoder_outputs[0]
- logits = self.qa_outputs(sequence_output)
- start_logits, end_logits = logits.split(1, dim=-1)
- start_logits = start_logits.squeeze(-1).contiguous()
- end_logits = end_logits.squeeze(-1).contiguous()
- total_loss = None
- if start_positions is not None and end_positions is not None:
- # If we are on multi-GPU, split add a dimension
- if len(start_positions.size()) > 1:
- start_positions = start_positions.squeeze(-1).to(start_logits.device)
- if len(end_positions.size()) > 1:
- end_positions = end_positions.squeeze(-1).to(end_logits.device)
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
- ignored_index = start_logits.size(1)
- start_positions = start_positions.clamp(0, ignored_index)
- end_positions = end_positions.clamp(0, ignored_index)
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
- start_loss = loss_fct(start_logits, start_positions)
- end_loss = loss_fct(end_logits, end_positions)
- total_loss = (start_loss + end_loss) / 2
- if not return_dict:
- output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
- return ((total_loss,) + output) if total_loss is not None else output
- return Seq2SeqQuestionAnsweringModelOutput(
- loss=total_loss,
- start_logits=start_logits,
- end_logits=end_logits,
- past_key_values=decoder_outputs.past_key_values,
- decoder_hidden_states=decoder_outputs.hidden_states,
- decoder_attentions=decoder_outputs.attentions,
- cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
- encoder_hidden_states=encoder_outputs.hidden_states,
- encoder_attentions=encoder_outputs.attentions,
- )
- __all__ = [
- "T5EncoderModel",
- "T5ForConditionalGeneration",
- "T5Model",
- "T5PreTrainedModel",
- "T5ForQuestionAnswering",
- "T5ForSequenceClassification",
- "T5ForTokenClassification",
- ]
|