| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/roberta/modular_roberta.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_roberta.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Callable
- import torch
- import torch.nn as nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN, gelu
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...generation import GenerationMixin
- from ...masking_utils import create_bidirectional_mask, create_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPoolingAndCrossAttentions,
- CausalLMOutputWithCrossAttentions,
- MaskedLMOutput,
- MultipleChoiceModelOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import TransformersKwargs, auto_docstring, logging
- from ...utils.generic import can_return_tuple, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_roberta import RobertaConfig
# Module-level logger; used by the task heads below to emit configuration warnings.
logger = logging.get_logger(__name__)
- class RobertaEmbeddings(nn.Module):
- """Construct the embeddings from word, position and token_type embeddings."""
- def __init__(self, config):
- super().__init__()
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
- self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
- self.register_buffer(
- "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
- )
- self.register_buffer(
- "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
- )
- self.padding_idx = config.pad_token_id
- self.position_embeddings = nn.Embedding(
- config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
- )
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- token_type_ids: torch.LongTensor | None = None,
- position_ids: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- past_key_values_length: int = 0,
- ) -> torch.Tensor:
- if position_ids is None:
- if input_ids is not None:
- # Create the position ids from the input token ids. Any padded tokens remain padded.
- position_ids = self.create_position_ids_from_input_ids(
- input_ids, self.padding_idx, past_key_values_length
- )
- else:
- position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
- if input_ids is not None:
- input_shape = input_ids.size()
- else:
- input_shape = inputs_embeds.size()[:-1]
- batch_size, seq_length = input_shape
- # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
- # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
- # issue #5664
- if token_type_ids is None:
- if hasattr(self, "token_type_ids"):
- # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
- buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
- buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
- token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
- else:
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
- if inputs_embeds is None:
- inputs_embeds = self.word_embeddings(input_ids)
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
- embeddings = inputs_embeds + token_type_embeddings
- position_embeddings = self.position_embeddings(position_ids)
- embeddings = embeddings + position_embeddings
- embeddings = self.LayerNorm(embeddings)
- embeddings = self.dropout(embeddings)
- return embeddings
- @staticmethod
- def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
- """
- We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
- Args:
- inputs_embeds: torch.Tensor
- Returns: torch.Tensor
- """
- input_shape = inputs_embeds.size()[:-1]
- sequence_length = input_shape[1]
- position_ids = torch.arange(
- padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
- )
- return position_ids.unsqueeze(0).expand(input_shape)
- @staticmethod
- def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
- """
- Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
- are ignored. This is modified from fairseq's `utils.make_positions`.
- Args:
- x: torch.Tensor x:
- Returns: torch.Tensor
- """
- # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
- mask = input_ids.ne(padding_idx).int()
- incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
- return incremental_indices.long() + padding_idx
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference (eager) scaled dot-product attention written with plain PyTorch ops.

    Args:
        module: the calling module. Only its `training` flag is read, to decide whether dropout is active.
        query, key, value: projected tensors. `key` is transposed on dims 2 and 3, so 4-D inputs are expected
            (presumably (batch, heads, seq, head_dim) — confirm against callers).
        attention_mask: optional additive mask, summed onto the raw scores before the softmax.
        scaling: score multiplier. Falls back to `query.size(-1) ** -0.5` when None.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: accepted for interface compatibility with other attention backends; unused here.

    Returns:
        A tuple `(attn_output, attn_weights)`. `attn_output` has dims 1 and 2 swapped back and is made
        contiguous; `attn_weights` are the probabilities after dropout.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
class RobertaSelfAttention(nn.Module):
    """Multi-head self-attention whose kernel is chosen from `config._attn_implementation`."""

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()
        head_count = config.num_attention_heads
        divisible = config.hidden_size % head_count == 0
        if not divisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config

        self.num_attention_heads = head_count
        self.attention_head_size = int(config.hidden_size / head_count)
        self.all_head_size = head_count * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        # Submodules are created in this fixed order so random initialization stays reproducible.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.is_decoder = config.is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Project `hidden_states` to query/key/value heads, optionally extend the cache, and attend.

        Returns a tuple of the merged-head attention output and the attention weights reported by the backend.
        """
        batch_dims = hidden_states.shape[:-1]
        per_head_shape = (*batch_dims, -1, self.attention_head_size)

        query_states = self.query(hidden_states).view(*per_head_shape).transpose(1, 2)
        key_states = self.key(hidden_states).view(*per_head_shape).transpose(1, 2)
        value_states = self.value(hidden_states).view(*per_head_shape).transpose(1, 2)

        if past_key_values is not None:
            # Encoder-decoder caches wrap the self-attention cache. Any other cache object (e.g. for decoder-only
            # use) is updated directly.
            self_cache = (
                past_key_values.self_attention_cache
                if isinstance(past_key_values, EncoderDecoderCache)
                else past_key_values
            )
            # Append the new keys/values and get back the full sequences for reuse in auto-regressive decoding.
            key_states, value_states = self_cache.update(key_states, value_states, self.layer_idx)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.dropout.p if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )
        merged = attn_output.reshape(*batch_dims, -1).contiguous()
        return merged, attn_weights
class RobertaCrossAttention(nn.Module):
    """Multi-head attention from decoder states (queries) onto encoder states (keys/values)."""

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()
        head_count = config.num_attention_heads
        divisible = config.hidden_size % head_count == 0
        if not divisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config

        self.num_attention_heads = head_count
        self.attention_head_size = int(config.hidden_size / head_count)
        self.all_head_size = head_count * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        # Submodules are created in this fixed order so random initialization stays reproducible.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.is_causal = is_causal
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Attend from `hidden_states` onto `encoder_hidden_states`.

        When the cross-attention cache for this layer is already filled, the cached keys/values are reused and
        `encoder_hidden_states` is not read. Otherwise keys/values are projected from it and, if a cache is given,
        stored there and the layer is marked as updated.
        """
        batch_dims = hidden_states.shape[:-1]
        query_states = self.query(hidden_states).view((*batch_dims, -1, self.attention_head_size)).transpose(1, 2)

        if past_key_values is not None and past_key_values.is_updated.get(self.layer_idx):
            # Encoder states do not change between decoding steps, so the cached projections are reused.
            cached_layer = past_key_values.cross_attention_cache.layers[self.layer_idx]
            key_states = cached_layer.keys
            value_states = cached_layer.values
        else:
            encoder_head_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
            key_states = self.key(encoder_hidden_states).view(encoder_head_shape).transpose(1, 2)
            value_states = self.value(encoder_hidden_states).view(encoder_head_shape).transpose(1, 2)

            if past_key_values is not None:
                key_states, value_states = past_key_values.cross_attention_cache.update(
                    key_states, value_states, self.layer_idx
                )
                # Remember that this layer's cross-attention cache is filled, so later calls take the branch above.
                past_key_values.is_updated[self.layer_idx] = True

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.dropout.p if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )
        merged = attn_output.reshape(*batch_dims, -1).contiguous()
        return merged, attn_weights
class RobertaSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> LayerNorm over the residual sum."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, add the residual `input_tensor`, and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class RobertaAttention(nn.Module):
    """Attention sub-layer: self- or cross-attention followed by the residual output projection."""

    def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
        super().__init__()
        self.is_cross_attention = is_cross_attention
        if is_cross_attention:
            self.self = RobertaCrossAttention(config, is_causal=is_causal, layer_idx=layer_idx)
        else:
            self.self = RobertaSelfAttention(config, is_causal=is_causal, layer_idx=layer_idx)
        self.output = RobertaSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Run the inner attention and the residual output block.

        Cross-attention uses `encoder_attention_mask`; self-attention uses `attention_mask`.
        Returns `(output, attention_weights)`.
        """
        mask = encoder_attention_mask if self.is_cross_attention else attention_mask
        context, attn_weights = self.self(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=mask,
            past_key_values=past_key_values,
            **kwargs,
        )
        return self.output(context, hidden_states), attn_weights
class RobertaIntermediate(nn.Module):
    """Feed-forward expansion: dense projection to `intermediate_size` followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in the activation registry; anything else is used as the callable itself.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand `hidden_states` and apply the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
class RobertaOutput(nn.Module):
    """Feed-forward contraction: dense back to `hidden_size`, dropout, then LayerNorm over the residual sum."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Contract `hidden_states`, add the residual `input_tensor`, and normalize the sum."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
class RobertaLayer(GradientCheckpointingLayer):
    """One transformer block: self-attention, optional cross-attention (decoders only) and a chunked feed-forward."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = RobertaAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention and not self.is_decoder:
            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
        if self.add_cross_attention:
            self.crossattention = RobertaAttention(
                config,
                is_causal=False,
                layer_idx=layer_idx,
                is_cross_attention=True,
            )
        self.intermediate = RobertaIntermediate(config)
        self.output = RobertaOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Apply self-attention, then cross-attention when this is a decoder and encoder states are given, then the
        feed-forward block. Returns the layer's output hidden states; attention weights are discarded here.
        """
        attention_output, _ = self.attention(
            hidden_states,
            attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )
            # The positional `None` fills the self-attention mask slot; cross-attention reads encoder_attention_mask.
            attention_output, _ = self.crossattention(
                attention_output,
                None,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_values=past_key_values,
                **kwargs,
            )

        return apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

    def feed_forward_chunk(self, attention_output):
        """Feed-forward on one chunk: expand, contract, and add `attention_output` as the residual."""
        return self.output(self.intermediate(attention_output), attention_output)
@auto_docstring
class RobertaPreTrainedModel(PreTrainedModel):
    # Shared base for all RoBERTa models: config class, checkpoint prefix, supported attention backends, and
    # which submodules feed the recorded `hidden_states` / `attentions` / `cross_attentions` outputs.
    config_class = RobertaConfig
    base_model_prefix = "roberta"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": RobertaLayer,
        "attentions": RobertaSelfAttention,
        "cross_attentions": RobertaCrossAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the generic initialization, then reset RoBERTa-specific parameters and buffers."""
        super()._init_weights(module)
        if isinstance(module, RobertaEmbeddings):
            # Rebuild the position-index row and zero the default token-type row.
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
            init.zeros_(module.token_type_ids)
        elif isinstance(module, RobertaLMHead):
            init.zeros_(module.bias)
class RobertaEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` RobertaLayer blocks applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            RobertaLayer(config, layer_idx=layer_index) for layer_index in range(config.num_hidden_layers)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
        """Thread `hidden_states` through every layer. The cache is returned only when `use_cache` is truthy."""
        for block in self.layer:
            hidden_states = block(
                hidden_states,
                attention_mask,
                # Passed positionally so gradient checkpointing sees it as an argument.
                encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                **kwargs,
            )

        returned_cache = past_key_values if use_cache else None
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=returned_cache,
        )
class RobertaPooler(nn.Module):
    """Pools a sequence by passing its first token's hidden state through dense + tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        leading_token = hidden_states[:, 0]
        return self.activation(self.dense(leading_token))
@auto_docstring(
    custom_intro="""
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """
)
class RobertaModel(RobertaPreTrainedModel):
    _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        self.embeddings = RobertaEmbeddings(config)
        self.encoder = RobertaEncoder(config)

        self.pooler = RobertaPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Caching only applies in decoder mode; encoder-only configurations never cache.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        # Lazily create a cache: an encoder-decoder pair when cross-attention is in play, a single dynamic cache otherwise.
        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )

        # Number of already-cached tokens; used to offset automatically created position ids.
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_ids=position_ids,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )

    def _create_attention_masks(
        self,
        attention_mask,
        encoder_attention_mask,
        embedding_output,
        encoder_hidden_states,
        past_key_values,
    ):
        """
        Build the self-attention mask (causal for decoders, bidirectional otherwise) and, when an
        `encoder_attention_mask` is given, the bidirectional cross-attention mask.
        """
        if self.config.is_decoder:
            attention_mask = create_causal_mask(
                config=self.config,
                inputs_embeds=embedding_output,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
        else:
            attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=embedding_output,
                attention_mask=attention_mask,
            )

        if encoder_attention_mask is not None:
            encoder_attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=embedding_output,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        return attention_mask, encoder_attention_mask
@auto_docstring(
    custom_intro="""
    RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.
    """
)
class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin):
    # Tied parameters: the LM decoder weight mirrors the input word embeddings and the decoder bias mirrors
    # `lm_head.bias` (presumably resolved by the PreTrainedModel tying machinery — confirm there).
    _tied_weights_keys = {
        "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
        "lm_head.decoder.bias": "lm_head.bias",
    }

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            # Without `is_decoder`, RobertaModel builds bidirectional masks and disables caching, so the model
            # would not behave as a left-to-right LM.
            logger.warning("If you want to use `RobertaForCausalLM` as a standalone, add `is_decoder=True`.")

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
            >= 2. All the value in this tensor should be always < type_vocab_size.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
        >>> config = AutoConfig.from_pretrained("FacebookAI/roberta-base")
        >>> config.is_decoder = True
        >>> model = RobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        # Training with labels never needs the cache.
        if labels is not None:
            use_cache = False

        outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            return_dict=True,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used directly as an index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
@auto_docstring
class RobertaForMaskedLM(RobertaPreTrainedModel):
    # Parameters that share storage with another parameter: the output projection weight with the input word
    # embeddings, and the decoder bias with `lm_head.bias`.
    # NOTE(review): presumably consumed by PreTrainedModel's weight-tying logic (key = tied param, value = source)
    # — confirm against the base class.
    _tied_weights_keys = {
        "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
        "lm_head.decoder.bias": "lm_head.bias",
    }

    def __init__(self, config):
        super().__init__(config)

        # Masked LM needs bi-directional self-attention, which requires `config.is_decoder=False`.
        if config.is_decoder:
            logger.warning(
                "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # No pooler: the LM head consumes the full per-token sequence output.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # The vocabulary projection of the LM head acts as the output embedding matrix.
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | MaskedLMOutput:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
            >= 2. All values in this tensor must be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
        """
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=True,
            **kwargs,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # move labels to correct device
            labels = labels.to(prediction_scores.device)
            # CrossEntropyLoss's default `ignore_index` (-100) is what makes `-100` labels ignored.
            loss_fct = CrossEntropyLoss()
            # Flatten every token position into one row of vocabulary scores.
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class RobertaLMHead(nn.Module):
    """Masked-language-modeling head for RoBERTa.

    Transforms hidden states of width ``config.hidden_size`` with dense -> GELU -> LayerNorm, then projects
    them to ``config.vocab_size`` scores per position with a linear decoder.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(width, config.vocab_size)
        # NOTE(review): `self.bias` is never read in `forward`; the owning model's `_tied_weights_keys`
        # ("lm_head.decoder.bias" -> "lm_head.bias") appears to tie it to `decoder.bias` — confirm.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        transformed = self.layer_norm(gelu(self.dense(features)))
        # project back to size of vocabulary with bias
        return self.decoder(transformed)
@auto_docstring(
    custom_intro="""
    RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
class RobertaForSequenceClassification(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # No pooler: the classification head itself reads the first token of the sequence output.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = RobertaClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
            >= 2. All values in this tensor must be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. When `config.problem_type` is unset it is inferred (and stored on the config):
            `config.num_labels == 1` gives a regression loss (Mean-Square loss); `config.num_labels > 1` with
            integer labels gives a classification loss (Cross-Entropy); anything else gives a multi-label loss
            (binary cross-entropy with logits).
        """
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        logits = self.classifier(outputs[0])

        loss = None
        if labels is not None:
            # move labels to correct device
            labels = labels.to(logits.device)

            # Infer the task type once from the label count and dtype, then cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem_type = self.config.problem_type
            if problem_type == "regression":
                if self.num_labels == 1:
                    # Single target: squeeze away size-1 dims so logits and labels line up.
                    loss = MSELoss()(logits.squeeze(), labels.squeeze())
                else:
                    loss = MSELoss()(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
class RobertaForMultipleChoice(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Built with the default pooling configuration: `forward` classifies the pooled output (`outputs[1]`).
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One score per (example, choice) pair; scores are regrouped per example in `forward`.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. One of `input_ids` or `inputs_embeds` must be
            provided.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
            >= 2. All values in this tensor must be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        """
        # Fail with a clear message instead of an AttributeError on `None.shape` below.
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice dimension into the batch dimension so the encoder sees ordinary
        # (batch_size * num_choices, sequence_length[, hidden_size]) inputs.
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.roberta(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            inputs_embeds=flat_inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Regroup the flat per-choice scores into one row of `num_choices` scores per example.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            # move labels to correct device
            labels = labels.to(reshaped_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
class RobertaForTokenClassification(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # No pooler: every token position is classified from the sequence output.
        self.roberta = RobertaModel(config, add_pooling_layer=False)

        # A dedicated classifier dropout rate wins; otherwise reuse the encoder's hidden dropout rate.
        if config.classifier_dropout is not None:
            dropout_rate = config.classifier_dropout
        else:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
            >= 2. All values in this tensor must be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        token_states = self.dropout(outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            # move labels to correct device
            labels = labels.to(logits.device)
            # Score every token position independently: one row of `num_labels` logits per position.
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class RobertaClassificationHead(nn.Module):
    """Sentence-level classification head.

    Scores the hidden state of the first (``<s>``) token: dropout -> dense -> tanh -> dropout -> projection to
    ``config.num_labels`` outputs.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        dropout_prob = config.classifier_dropout
        if dropout_prob is None:
            # Fall back to the encoder's hidden dropout when no classifier-specific rate is configured.
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        first_token = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        activated = torch.tanh(self.dense(self.dropout(first_token)))
        return self.out_proj(self.dropout(activated))
@auto_docstring
class RobertaForQuestionAnswering(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # No pooler: span boundaries are predicted per token from the sequence output.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Per-token scores; `forward` unpacks them into exactly two (start, end), so `config.num_labels` must be 2.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        end_positions: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
            >= 2. All values in this tensor must be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        """
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # Split the last dim into per-token span-start and span-end scores.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU setups the positions can carry an extra trailing dimension; squeeze it away.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # move labels to correct device, as the other task heads do, so the loss works under model parallelism
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            # Positions past the sequence end are clamped to `ignored_index` (the sequence length), which the loss
            # skips; negative positions are clamped to 0.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public names exported by `from <module> import *`: the base model, its pre-trained base class, and every
# task-specific head, listed alphabetically.
__all__ = [
    "RobertaForCausalLM",
    "RobertaForMaskedLM",
    "RobertaForMultipleChoice",
    "RobertaForQuestionAnswering",
    "RobertaForSequenceClassification",
    "RobertaForTokenClassification",
    "RobertaModel",
    "RobertaPreTrainedModel",
]
|