| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601 |
- # Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- PyTorch XLM model.
- """
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- import numpy as np
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import gelu, get_activation
- from ...cache_utils import DynamicCache, EncoderDecoderCache
- from ...generation import GenerationMixin
- from ...modeling_outputs import (
- BaseModelOutput,
- MaskedLMOutput,
- MultipleChoiceModelOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import ModelOutput, auto_docstring, logging
- from .configuration_xlm import XLMConfig
- logger = logging.get_logger(__name__)
def create_sinusoidal_embeddings(n_pos, dim, out):
    """
    Fill `out` in place with fixed sinusoidal position encodings and return it.

    For position `pos`, even column `2i` receives `sin(pos / 10000^(2i / dim))` and odd column `2i + 1` receives
    `cos(pos / 10000^(2i / dim))`.

    Args:
        n_pos (`int`):
            Number of positions, i.e. the number of rows written.
        dim (`int`):
            Embedding dimension, i.e. the number of columns written.
        out (`torch.Tensor` of shape `(n_pos, dim)`):
            Tensor that is overwritten with the encodings.

    Returns:
        `torch.Tensor`: The same `out` object, detached and with `requires_grad` set to `False`.
    """
    # Build the whole (n_pos, dim) angle table with NumPy broadcasting rather than a nested Python loop.
    positions = np.arange(n_pos, dtype=np.float64)[:, None]
    # Columns 2i and 2i + 1 share the exponent 2i / dim.
    exponents = 2 * (np.arange(dim) // 2) / dim
    position_enc = positions / np.power(10000.0, exponents)[None, :]

    # Gradient tracking is switched off before the in-place writes below.
    out.requires_grad = False
    # The values go through float32 before being copied into `out`.
    out[:, 0::2] = torch.from_numpy(np.sin(position_enc[:, 0::2])).float()
    out[:, 1::2] = torch.from_numpy(np.cos(position_enc[:, 1::2])).float()
    out.detach_()
    return out
def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Build the hidden-state (padding) mask and the attention mask that goes with it.

    Args:
        slen (`int`):
            Sequence length of the batch.
        lengths (`torch.LongTensor` of shape `(batch_size,)`):
            Number of valid tokens in each sequence. Also supplies the device and the batch size.
        causal (`bool`):
            If truthy, the attention mask is lower-triangular with shape `(batch_size, slen, slen)`; otherwise the
            padding mask itself is returned as the attention mask.
        padding_mask (`torch.Tensor` of shape `(batch_size, slen)`, *optional*):
            Precomputed padding mask. When given, it is used as-is and `lengths` is not turned into a mask.

    Returns:
        `tuple`: `(mask, attn_mask)`.
    """
    positions = torch.arange(slen, dtype=torch.long, device=lengths.device)

    if padding_mask is None:
        assert lengths.max().item() <= slen
        # Position p of a sequence is valid when p < its length.
        mask = positions < lengths[:, None]
    else:
        mask = padding_mask

    batch_size = lengths.size(0)
    if causal:
        # Query position q (rows) may attend to key position k (columns) only when k <= q.
        key_positions = positions[None, None, :].repeat(batch_size, slen, 1)
        query_positions = positions[None, :, None]
        attn_mask = key_positions <= query_positions
    else:
        attn_mask = mask

    # sanity check
    assert mask.size() == (batch_size, slen)
    assert causal is False or attn_mask.size() == (batch_size, slen, slen)

    return mask, attn_mask
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of question answering models using a [`~modeling_utils.XLMSQuADHead`].
    """
)
class XLMSquadHeadOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
        Classification loss as the sum of start token, end token (and is_impossible if provided) classification
        losses.
    start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the top config.start_n_top start token possibilities (beam-search).
    start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Indices for the top config.start_n_top start token possibilities (beam-search).
    end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
        (beam-search).
    end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
    cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the `is_impossible` label of the answers.
    """

    # Every field defaults to `None`; `XLMSQuADHead.forward` fills either `loss` alone (labels given) or the five
    # beam-search fields (no labels), never both.

    # Training path only.
    loss: torch.FloatTensor | None = None

    # Inference path only.
    # NOTE: `XLMSQuADHead.forward` takes the two `*_log_probs` values from `softmax` outputs, so despite the field
    # names they hold probabilities rather than log-probabilities.
    start_top_log_probs: torch.FloatTensor | None = None
    start_top_index: torch.LongTensor | None = None
    end_top_log_probs: torch.FloatTensor | None = None
    end_top_index: torch.LongTensor | None = None
    # Raw output of the answer-class head (no sigmoid or log applied in `XLMSQuADHead.forward`).
    cls_logits: torch.FloatTensor | None = None
class XLMPoolerStartLogits(nn.Module):
    """
    Project every token's hidden state to a single SQuAD start-position logit.

    Args:
        config ([`XLMConfig`]):
            Model configuration; only its `hidden_size` is read here.
    """

    def __init__(self, config: XLMConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states: torch.FloatTensor, p_mask: torch.FloatTensor | None = None) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                Final hidden states of the model.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                1.0 marks a position that must not be selected (e.g. query tokens, PAD, SEP, CLS); 0.0 keeps the
                position.

        Returns:
            `torch.FloatTensor`: Start logits, with masked positions pushed to a very large negative value.
        """
        logits = self.dense(hidden_states).squeeze(-1)
        if p_mask is None:
            return logits

        # 1e30 is not representable in float16, so a value close to the float16 maximum is used for such masks.
        penalty = 65500 if p_mask.dtype == torch.float16 else 1e30
        return logits * (1 - p_mask) - penalty * p_mask
class XLMPoolerEndLogits(nn.Module):
    """
    Score every token as a SQuAD answer end, conditioned on a start-token representation.

    Args:
        config ([`XLMConfig`]):
            Model configuration; `hidden_size` and `layer_norm_eps` are read here.
    """

    def __init__(self, config: XLMConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        p_mask: torch.FloatTensor | None = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                Final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                Start-token hidden states, already laid out so they can be concatenated with `hidden_states` on
                the last dimension.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Index of the start token of each labeled span. When given, the start states are gathered from
                `hidden_states` and any `start_states` argument is ignored.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                1.0 marks a position that must not be selected (e.g. query tokens, PAD, SEP, CLS); 0.0 keeps the
                position.

        <Tip>

        At least one of `start_states` and `start_positions` must be provided; `start_positions` wins when both are
        set.

        </Tip>

        Returns:
            `torch.FloatTensor`: End logits, with masked positions pushed to a very large negative value.
        """
        assert start_states is not None or start_positions is not None, (
            "One of start_states, start_positions should be not None"
        )

        if start_positions is not None:
            seq_len, hidden_size = hidden_states.shape[-2:]
            # (bsz,) -> (bsz, 1, hsz) so that it can index the sequence dimension for every feature.
            gather_index = start_positions[:, None, None].expand(-1, -1, hidden_size)
            picked = hidden_states.gather(-2, gather_index)  # shape (bsz, 1, hsz)
            # Repeat the start-token state along the sequence so every position sees it.
            start_states = picked.expand(-1, seq_len, -1)  # shape (bsz, slen, hsz)

        features = torch.cat([hidden_states, start_states], dim=-1)
        projected = self.LayerNorm(self.activation(self.dense_0(features)))
        logits = self.dense_1(projected).squeeze(-1)

        if p_mask is None:
            return logits

        # 1e30 is not representable in float16, so a value close to the float16 maximum is used for such masks.
        penalty = 65500 if p_mask.dtype == torch.float16 else 1e30
        return logits * (1 - p_mask) - penalty * p_mask
class XLMPoolerAnswerClass(nn.Module):
    """
    Predict the SQuAD 2.0 answerability logit from the start-token and classification-token hidden states.

    Args:
        config ([`XLMConfig`]):
            Model configuration; only its `hidden_size` is read here.
    """

    def __init__(self, config: XLMConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        cls_index: torch.LongTensor | None = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                Final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
                One start-token representation per sample. It is concatenated with the classification-token
                state, so it must have the same shape as that state.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Index of the start token of each labeled span. When given, the start state is gathered from
                `hidden_states` and any `start_states` argument is ignored.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Index of the classification token of each sample. If `None`, the last token is used.

        <Tip>

        At least one of `start_states` and `start_positions` must be provided; `start_positions` wins when both are
        set.

        </Tip>

        Returns:
            `torch.FloatTensor`: One answer-class logit per sample, shape `(batch_size,)`.
        """
        # The end position is deliberately not used, so each sample yields exactly one `cls_logits` value.
        hidden_size = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, (
            "One of start_states, start_positions should be not None"
        )

        def select_token(positions):
            # (bsz,) -> (bsz, 1, hsz) gather index, then drop the singleton sequence dim -> (bsz, hsz)
            index = positions[:, None, None].expand(-1, -1, hidden_size)
            return hidden_states.gather(-2, index).squeeze(-2)

        if start_positions is not None:
            start_states = select_token(start_positions)

        if cls_index is None:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
        else:
            cls_token_state = select_token(cls_index)

        pair = torch.cat([start_states, cls_token_state], dim=-1)
        return self.dense_1(self.activation(self.dense_0(pair))).squeeze(-1)
class XLMSQuADHead(nn.Module):
    r"""
    A SQuAD head inspired by XLNet.

    With labels it returns a training loss; without labels it returns the top start/end candidates found by a
    small beam search, plus an answerability logit.

    Args:
        config ([`XLMConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
            to use.
    """

    def __init__(self, config: XLMConfig):
        super().__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = XLMPoolerStartLogits(config)
        self.end_logits = XLMPoolerEndLogits(config)
        self.answer_class = XLMPoolerAnswerClass(config)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_positions: torch.LongTensor | None = None,
        end_positions: torch.LongTensor | None = None,
        cls_index: torch.LongTensor | None = None,
        is_impossible: torch.LongTensor | None = None,
        p_mask: torch.FloatTensor | None = None,
        return_dict: bool = False,
    ) -> XLMSquadHeadOutput | tuple[torch.FloatTensor]:
        r"""
        hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
            Final hidden states of the model on the sequence tokens.
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Positions of the first token for the labeled span.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Positions of the last token for the labeled span.
        cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
        is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Whether the question has a possible answer in the paragraph or not.
        p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
            Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
            should be masked.
        """
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, remove the trailing dimension added by batch splitting. This is done
            # out-of-place so the tensors owned by the caller keep their shape.
            start_positions, end_positions, cls_index, is_impossible = (
                t.squeeze(-1) if t is not None and t.dim() > 1 else t
                for t in (start_positions, end_positions, cls_index, is_impossible)
            )

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = BCEWithLogitsLoss()
                # BCE-with-logits needs a floating-point target, while `is_impossible` is documented as a
                # `LongTensor`; cast it to the logits dtype (a no-op when it already matches).
                cls_loss = loss_fct_cls(cls_logits, is_impossible.to(cls_logits.dtype))

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            return XLMSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)

        # during inference, compute the end logits based on beam search
        bsz, slen, hsz = hidden_states.size()
        # NOTE: these are softmax probabilities; the `log_probs` names are kept for interface compatibility.
        start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)

        start_top_log_probs, start_top_index = torch.topk(
            start_log_probs, self.start_n_top, dim=-1
        )  # shape (bsz, start_n_top)
        start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
        start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
        start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

        hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
            start_states
        )  # shape (bsz, slen, start_n_top, hsz)
        # Add a trailing dim so the mask broadcasts over the start candidates.
        p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
        end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
        end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

        end_top_log_probs, end_top_index = torch.topk(
            end_log_probs, self.end_n_top, dim=1
        )  # shape (bsz, end_n_top, start_n_top)
        end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
        end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

        # Probability-weighted average of the hidden states stands in for the start state at inference time.
        start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
        cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

        if not return_dict:
            return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)

        return XLMSquadHeadOutput(
            start_top_log_probs=start_top_log_probs,
            start_top_index=start_top_index,
            end_top_log_probs=end_top_log_probs,
            end_top_index=end_top_index,
            cls_logits=cls_logits,
        )
class XLMSequenceSummary(nn.Module):
    r"""
    Reduce a sequence of hidden states to one vector per sequence, then optionally project and activate it.

    Args:
        config ([`XLMConfig`]):
            The config used by the model. The following attributes are read when present:

            - **summary_type** (`str`) -- How the sequence is pooled (defaults to `"last"` when missing):

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Take the token at a supplied classification-token position (GPT/GPT-2)
                - `"attn"` -- Not implemented; raises `NotImplementedError`

            - **summary_use_proj** (`bool`) -- Add a linear projection after pooling.
            - **summary_proj_to_labels** (`bool`) -- If `True` and `config.num_labels > 0`, the projection outputs
              `config.num_labels` features, otherwise `config.hidden_size`.
            - **summary_activation** (`Optional[str]`) -- Name of the activation applied after the projection; a
              missing or empty value applies no activation.
            - **summary_first_dropout** (`float`) -- Dropout probability applied before the projection, if > 0.
            - **summary_last_dropout** (`float`) -- Dropout probability applied after the activation, if > 0.
    """

    def __init__(self, config: XLMConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        if getattr(config, "summary_use_proj", False):
            project_to_labels = getattr(config, "summary_proj_to_labels", False) and config.num_labels > 0
            num_classes = config.num_labels if project_to_labels else config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)
        else:
            self.summary = nn.Identity()

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

        if getattr(config, "summary_first_dropout", 0) > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)
        else:
            self.first_dropout = nn.Identity()

        if getattr(config, "summary_last_dropout", 0) > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)
        else:
            self.last_dropout = nn.Identity()

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
    ) -> torch.FloatTensor:
        """
        Pool `hidden_states` according to `summary_type`, then apply dropout, projection, activation and dropout.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Only used when `summary_type == "cls_index"`: position of the classification token. If `None`, the
                last token of the sequence is used.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        pooling = self.summary_type
        if pooling == "first":
            pooled = hidden_states[:, 0]
        elif pooling == "last":
            pooled = hidden_states[:, -1]
        elif pooling == "mean":
            pooled = hidden_states.mean(dim=1)
        elif pooling == "cls_index":
            if cls_index is None:
                # Default to the last position: an index tensor shaped like one sequence step.
                gather_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                gather_index = cls_index[..., None, None]
                target_shape = (-1,) * (gather_index.dim() - 1) + (hidden_states.size(-1),)
                gather_index = gather_index.expand(target_shape)
            # shape of gather_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            pooled = hidden_states.gather(-2, gather_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif pooling == "attn":
            raise NotImplementedError

        projected = self.summary(self.first_dropout(pooled))
        return self.last_dropout(self.activation(projected))
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with optional key/value caching.

    Args:
        n_heads (`int`): Number of attention heads; must divide `dim`.
        dim (`int`): Model dimension, used for the input and output of all four projections.
        config: Configuration object; only `attention_dropout` is read here.
        layer_idx (`int`, *optional*, defaults to 0): Index under which this layer stores its cache entries.
    """

    def __init__(self, n_heads, dim, config, layer_idx: int = 0):
        super().__init__()
        self.layer_id = layer_idx
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)

    def forward(
        self,
        input,
        mask,
        kv=None,
        cache=None,
        output_attentions=False,
        **kwargs,
    ):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        Args:
            input: Query-side hidden states of shape `(bs, qlen, dim)`.
            mask: `(bs, klen)` padding mask or `(bs, klen, klen)` mask; entries equal to 0 are blocked.
            kv: Source hidden states for cross-attention; `None` selects self-attention.
            cache: Optional cache object whose `update` method stores and returns the keys/values of this layer.
            output_attentions: If truthy, the attention weights are appended to the returned tuple.

        Returns:
            `tuple`: `(output,)` or `(output, weights)`.
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, _ = input.size()
        is_cross_attention = kv is not None
        mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1)

        def split_heads(projected):
            # (bs, len, dim) -> (bs, n_heads, len, head_dim)
            return projected.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_lin(input))

        if cache is not None:
            if isinstance(cache, EncoderDecoderCache):
                is_updated = cache.is_updated.get(self.layer_id)
                # Cross-attention and self-attention entries live in separate sub-caches.
                curr_past_key_values = (
                    cache.cross_attention_cache if is_cross_attention else cache.self_attention_cache
                )
            else:
                # NOTE(review): `is_updated` is only bound for an `EncoderDecoderCache`; cross-attention with any
                # other cache type would fail on the check below. Confirm that combination cannot occur.
                curr_past_key_values = cache

        current_states = kv if is_cross_attention else input
        if is_cross_attention and cache is not None and is_updated:
            # After the first generated id, the cross-attention keys/values are re-used from the cache.
            k = curr_past_key_values.key_cache[self.layer_id]
            v = curr_past_key_values.value_cache[self.layer_id]
        else:
            k = split_heads(self.k_lin(current_states))
            v = split_heads(self.v_lin(current_states))

            if cache is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                k, v = curr_past_key_values.update(k, v, self.layer_id)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    cache.is_updated[self.layer_id] = True

        q = q / math.sqrt(self.head_dim)  # (bs, n_heads, qlen, head_dim)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)

        # Blocked positions get the most negative finite value so that softmax gives them zero weight.
        blocked = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
        scores.masked_fill_(blocked, torch.finfo(scores.dtype).min)

        # Softmax runs in float32 and is cast back to the dtype of the scores.
        weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
        weights = nn.functional.dropout(weights, p=self.dropout, training=self.training)

        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, head_dim)
        context = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.head_dim)

        attended = self.out_lin(context)
        return (attended, weights) if output_attentions else (attended,)
class TransformerFFN(nn.Module):
    """
    Position-wise feed-forward block: linear -> activation -> linear -> dropout.

    Args:
        in_dim (`int`): Input feature size.
        dim_hidden (`int`): Size of the intermediate layer.
        out_dim (`int`): Output feature size.
        config: Configuration object; `dropout`, `gelu_activation` and `chunk_size_feed_forward` are read here.
    """

    def __init__(self, in_dim, dim_hidden, out_dim, config):
        super().__init__()
        # Chunking is done along dimension 1.
        self.seq_len_dim = 1
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.dropout = config.dropout
        # GELU when the config asks for it, ReLU otherwise.
        self.act = gelu if config.gelu_activation else nn.functional.relu
        self.lin1 = nn.Linear(in_dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, out_dim)

    def forward(self, input):
        """Run `ff_chunk` on `input` through `apply_chunking_to_forward`, chunked along `seq_len_dim`."""
        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

    def ff_chunk(self, input):
        """Apply the two linear layers with the activation in between, then dropout."""
        hidden = self.act(self.lin1(input))
        projected = self.lin2(hidden)
        return nn.functional.dropout(projected, p=self.dropout, training=self.training)
# Base class shared by the XLM models: declares the config type and the base-model prefix, and provides the
# dummy inputs and the weight-initialization hook.
@auto_docstring
class XLMPreTrainedModel(PreTrainedModel):
    config: XLMConfig
    base_model_prefix = "transformer"

    @property
    def dummy_inputs(self):
        """Return a small fixed batch (3 sequences of 5 tokens) of input ids, attention mask and language ids."""
        inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        # Language ids are only supplied when the config enables language embeddings for more than one language.
        if self.config.use_lang_emb and self.config.n_langs > 1:
            langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        else:
            langs_list = None
        return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights."""
        # The checks below are independent `if`s, so each one is evaluated for every module.
        if isinstance(module, nn.Embedding):
            # Embeddings are only re-drawn when `embed_init_std` is configured.
            if self.config is not None and self.config.embed_init_std is not None:
                init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
                # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
                # NOTE(review): nesting of this padding reset was reconstructed from de-indented source; confirm.
                if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                    init.zeros_(module.weight[module.padding_idx])
        if isinstance(module, nn.Linear):
            # Linear layers are only re-drawn when `init_std` is configured; their bias is then zeroed.
            if self.config is not None and self.config.init_std is not None:
                init.normal_(module.weight, mean=0, std=self.config.init_std)
                if module.bias is not None:
                    init.constant_(module.bias, 0.0)
        if isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        if isinstance(module, XLMModel):
            # With sinusoidal embeddings the position table is overwritten with the fixed sin/cos encodings.
            if self.config.sinusoidal_embeddings:
                init.copy_(
                    module.position_embeddings.weight,
                    create_sinusoidal_embeddings(
                        self.config.max_position_embeddings,
                        self.config.emb_dim,
                        out=torch.empty_like(module.position_embeddings.weight),
                    ),
                )
            # Reset `position_ids` to 0..N-1 with a leading dimension of size 1.
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of question answering models using a `XLMSQuADHead`.
    """
)
class XLMForQuestionAnsweringOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
        Classification loss as the sum of start token, end token (and is_impossible if provided) classification
        losses.
    start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the top config.start_n_top start token possibilities (beam-search).
    start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Indices for the top config.start_n_top start token possibilities (beam-search).
    end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
        (beam-search).
    end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
    cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the `is_impossible` label of the answers.
    """

    # Every field defaults to `None`, so a producer only sets the ones that apply to its code path.

    # Per the docstring above: set when both `start_positions` and `end_positions` are provided.
    loss: torch.FloatTensor | None = None

    # Per the docstring above: set when `start_positions` or `end_positions` is missing (beam-search outputs).
    start_top_log_probs: torch.FloatTensor | None = None
    start_top_index: torch.LongTensor | None = None
    end_top_log_probs: torch.FloatTensor | None = None
    end_top_index: torch.LongTensor | None = None
    cls_logits: torch.FloatTensor | None = None

    # Not described in the docstring above; presumably the per-layer hidden states and attention weights
    # forwarded from the base model. The producer is not visible here, so TODO confirm.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@auto_docstring
class XLMModel(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # encoder / decoder: only the encoder variant is implemented, decoder configs are rejected up front
        self.is_encoder = config.is_encoder
        self.is_decoder = not config.is_encoder
        if self.is_decoder:
            raise NotImplementedError("Currently XLM can only be used as an encoder")
        self.causal = config.causal

        # dictionary / languages
        self.n_langs = config.n_langs
        self.use_lang_emb = config.use_lang_emb
        self.n_words = config.n_words
        self.eos_index = config.eos_index
        self.pad_index = config.pad_index

        # model parameters
        self.dim = config.emb_dim
        self.hidden_dim = self.dim * 4  # the feed-forward inner width is always 4x the embedding width
        self.n_heads = config.n_heads
        self.n_layers = config.n_layers
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"

        # embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
        # the language table only exists for multilingual configs that opt into language embeddings
        if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)

        # transformer layers: four parallel lists, entry `i` of each list forms layer `i`
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()

        for i in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config, layer_idx=i))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))

        # default position ids `[[0, 1, ..., max_position_embeddings - 1]]`, not saved in the state dict
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        """Replace the word-embedding module with `new_embeddings`."""
        self.embeddings = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        langs: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        lengths: torch.Tensor | None = None,
        cache: dict[str, torch.Tensor] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,  # Dummy kwargs for now
    ) -> tuple | BaseModelOutput:
        r"""
        langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
            languages ids which can be obtained from the language names by using two conversion mappings provided in
            the configuration of the model (only provided for multilingual models). More precisely, the *language name
            to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
            *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).

            See usage examples detailed in the [multilingual documentation](../multilingual).
        lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Length of each sentence that can be used to avoid performing attention on padding token indices. You can
            also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
            `[0, ..., input_ids.size(-1)]`. When omitted, lengths are counted from the non-padding entries of
            `input_ids`, or set to the full sequence length when only `inputs_embeds` is given.
        cache (`EncoderDecoderCache`, *optional*):
            Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
            decoding. A fresh, empty cache is created when none is passed.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Fail with a clear message instead of an AttributeError on `None.size()` further down.
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            bs, slen = input_ids.size()
        else:
            bs, slen = inputs_embeds.size()[:-1]

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if cache is None:
            cache = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

        if lengths is None:
            if input_ids is not None:
                lengths = (input_ids != self.pad_index).sum(dim=1).long()
            else:
                lengths = torch.full((bs,), slen, device=device, dtype=torch.long)

        # check inputs
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)

        # position_ids
        if position_ids is None:
            position_ids = self.position_ids[:, :slen]
        else:
            assert position_ids.size() == (bs, slen)

        # langs
        if langs is not None:
            assert langs.size() == (bs, slen)

        # do not recompute cached elements: keep only the trailing positions the cache has not seen yet
        # (`cache` is never None at this point, it is created above when missing)
        if cache is not None and input_ids is not None:
            _slen = slen - cache.get_seq_length()
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
        if langs is not None and self.use_lang_emb and self.n_langs > 1:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
            # token types are looked up in the word-embedding table; no dedicated table is created in `__init__`
            tensor = tensor + self.embeddings(token_type_ids)
        tensor = self.layer_norm_emb(tensor)
        tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)
        # multiply by the mask from `get_masks` so masked-out positions carry zeros
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None
        for i in range(self.n_layers):
            if output_hidden_states:
                hidden_states = hidden_states + (tensor,)

            # self attention
            attn_outputs = self.attentions[i](
                tensor,
                attn_mask,
                cache=cache,
                output_attentions=output_attentions,
            )
            attn = attn_outputs[0]
            if output_attentions:
                attentions = attentions + (attn_outputs[1],)
            attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # FFN
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # Add last hidden state
        if output_hidden_states:
            hidden_states = hidden_states + (tensor,)

        if not return_dict:
            # tuple form drops the entries that were not requested
            return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
        return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
class XLMPredLayer(nn.Module):
    """
    Prediction layer (cross_entropy or adaptive_softmax).

    With `config.asm` set to `False` the projection is a linear layer from `config.emb_dim` to `config.n_words`;
    otherwise it is an `nn.AdaptiveLogSoftmaxWithLoss` configured from `config.asm_cutoffs` and
    `config.asm_div_value`.
    """

    def __init__(self, config):
        super().__init__()
        self.asm = config.asm
        self.n_words = config.n_words
        self.pad_index = config.pad_index
        dim = config.emb_dim

        if config.asm is False:
            self.proj = nn.Linear(dim, config.n_words, bias=True)
        else:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
                n_classes=config.n_words,
                cutoffs=config.asm_cutoffs,
                div_value=config.asm_div_value,
                head_bias=True,  # default is False
            )

    def forward(self, x, y=None):
        """
        Compute the scores, and the loss when targets are given.

        Args:
            x: Features whose last dimension is `config.emb_dim`; any number of leading dimensions is accepted.
            y: Optional target class indices with one entry per feature vector in `x`.

        Returns:
            `(scores,)` when `y` is `None`, otherwise `(loss, scores)`. `scores` keeps the leading dimensions of
            `x` and has the vocabulary as its last dimension: raw logits for the linear head, log-probabilities
            for the adaptive-softmax head.
        """
        if self.asm is False:
            scores = self.proj(x)
            if y is None:
                return (scores,)
            loss = nn.functional.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
            return (loss, scores)

        # `AdaptiveLogSoftmaxWithLoss` only handles 2-D input, so collapse every leading dimension
        # before calling it and restore them on the returned scores.
        flat_x = x.reshape(-1, x.size(-1))
        scores = self.proj.log_prob(flat_x).view(*x.shape[:-1], -1)
        if y is None:
            return (scores,)
        _, loss = self.proj(flat_x, y.reshape(-1))
        return (loss, scores)
@auto_docstring(
    custom_intro="""
    The XLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class XLMWithLMHeadModel(XLMPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"pred_layer.proj.weight": "transformer.embeddings.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the projection module of the prediction layer."""
        return self.pred_layer.proj

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the projection module of the prediction layer."""
        self.pred_layer.proj = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, is_first_iteration=False, **kwargs):
        # Overwritten -- this model uses config options to prepare inputs
        mask_id = self.config.mask_token_id
        language = self.config.lang_id

        # One mask token is appended to every row of the batch.
        n_rows = input_ids.shape[0]
        mask_column = torch.full((n_rows, 1), mask_id, dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat([input_ids, mask_column], dim=1)

        langs = None if language is None else torch.full_like(input_ids, language)
        model_inputs = {"input_ids": input_ids, "langs": langs}

        # They are calculated on the fly on XLMModel.forward()
        for recomputed in ("token_type_ids", "attention_mask", "position_ids"):
            kwargs.pop(recomputed, None)

        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`); keys already set above win.
        for name, value in kwargs.items():
            model_inputs.setdefault(name, value)
        return model_inputs

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        langs: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        lengths: torch.Tensor | None = None,
        cache: dict[str, torch.Tensor] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | MaskedLMOutput:
        r"""
        langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Language id of every input token, aligned with the input sequence. The ids come from the
            configuration of multilingual checkpoints: `model.config.lang2id` maps a language name (string) to
            its id (int) and `model.config.id2lang` maps an id back to its name.

            See usage examples detailed in the [multilingual documentation](../multilingual).
        lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of tokens in each sequence, used to avoid attending to padding positions. `attention_mask`
            gives the same result; this argument is kept for compatibility. Values lie in
            `[0, ..., input_ids.size(-1)]`.
        cache (`EncoderDecoderCache`, *optional*):
            Precomputed key/value states that can be used to speed up sequential decoding.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Target token ids for the language modeling loss. The prediction layer compares the scores with
            `labels` position by position; no shifting is applied in this model. With the default
            (non-adaptive) head, positions labelled `-100` are skipped by the cross-entropy loss and the
            remaining labels must lie in `[0, ..., config.vocab_size - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        encoder_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        # Only compute necessary logits: an int keeps the trailing positions (0 keeps all of them),
        # anything else is used directly as the index along the sequence dimension.
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        sequence_output = encoder_outputs[0]
        # (loss, logits) or (logits,) depending on if labels are provided.
        head_outputs = self.pred_layer(sequence_output[:, kept_positions, :], labels)

        if not return_dict:
            return head_outputs + encoder_outputs[1:]

        if labels is None:
            loss, logits = None, head_outputs[0]
        else:
            loss, logits = head_outputs[0], head_outputs[1]
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    XLM Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g.
    for GLUE tasks.
    """
)
class XLMForSequenceClassification(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.transformer = XLMModel(config)
        self.sequence_summary = XLMSequenceSummary(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        langs: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        lengths: torch.Tensor | None = None,
        cache: dict[str, torch.Tensor] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SequenceClassifierOutput:
        r"""
        langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Language id of every input token, aligned with the input sequence. The ids come from the
            configuration of multilingual checkpoints: `model.config.lang2id` maps a language name (string) to
            its id (int) and `model.config.id2lang` maps an id back to its name.

            See usage examples detailed in the [multilingual documentation](../multilingual).
        lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of tokens in each sequence, used to avoid attending to padding positions. `attention_mask`
            gives the same result; this argument is kept for compatibility. Values lie in
            `[0, ..., input_ids.size(-1)]`.
        cache (`EncoderDecoderCache`, *optional*):
            Precomputed key/value states that can be used to speed up sequential decoding.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Targets for the sequence classification/regression loss. Class indices should be in `[0, ...,
            config.num_labels - 1]`. When `config.problem_type` is unset it is inferred on the first call with
            labels: `config.num_labels == 1` selects a regression loss (Mean-Square loss), integer labels with
            `config.num_labels > 1` select a classification loss (Cross-Entropy), and any other labels select
            a multi-label loss (binary cross-entropy with logits).
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        encoder_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.sequence_summary(encoder_outputs[0])

        loss = None
        if labels is not None:
            # Infer the problem type once and remember it on the config for later calls.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the loss leads only when it was computed.
        as_tuple = (logits,) + encoder_outputs[1:]
        return as_tuple if loss is None else (loss,) + as_tuple
@auto_docstring(
    custom_intro="""
    XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """
)
class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.transformer = XLMModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        langs: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        lengths: torch.Tensor | None = None,
        cache: dict[str, torch.Tensor] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | QuestionAnsweringModelOutput:
        r"""
        langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Language id of every input token, aligned with the input sequence. The ids come from the
            configuration of multilingual checkpoints: `model.config.lang2id` maps a language name (string) to
            its id (int) and `model.config.id2lang` maps an id back to its name.

            See usage examples detailed in the [multilingual documentation](../multilingual).
        lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of tokens in each sequence, used to avoid attending to padding positions. `attention_mask`
            gives the same result; this argument is kept for compatibility. Values lie in
            `[0, ..., input_ids.size(-1)]`.
        cache (`EncoderDecoderCache`, *optional*):
            Precomputed key/value states that can be used to speed up sequential decoding.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        encoder_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The head emits one channel per span boundary; peel them apart into two per-token score maps.
        span_logits = self.qa_outputs(encoder_outputs[0])
        start_logits, end_logits = (
            channel.squeeze(-1).contiguous() for channel in span_logits.split(1, dim=-1)
        )

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions outside the model inputs are clamped onto the index just past the sequence,
            # which the loss is told to ignore.
            out_of_range = start_logits.size(1)
            start_positions = start_positions.clamp(0, out_of_range)
            end_positions = end_positions.clamp(0, out_of_range)

            criterion = CrossEntropyLoss(ignore_index=out_of_range)
            start_loss = criterion(start_logits, start_positions)
            end_loss = criterion(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the loss leads only when it was computed.
        as_tuple = (start_logits, end_logits) + encoder_outputs[1:]
        return as_tuple if total_loss is None else (total_loss,) + as_tuple
@auto_docstring
class XLMForQuestionAnswering(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.transformer = XLMModel(config)
        self.qa_outputs = XLMSQuADHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        langs: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        lengths: torch.Tensor | None = None,
        cache: dict[str, torch.Tensor] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        is_impossible: torch.Tensor | None = None,
        cls_index: torch.Tensor | None = None,
        p_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | XLMForQuestionAnsweringOutput:
        r"""
        langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Language id of every input token, aligned with the input sequence. The ids come from the
            configuration of multilingual checkpoints: `model.config.lang2id` maps a language name (string) to
            its id (int) and `model.config.id2lang` maps an id back to its name.

            See usage examples detailed in the [multilingual documentation](../multilingual).
        lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of tokens in each sequence, used to avoid attending to padding positions. `attention_mask`
            gives the same result; this argument is kept for compatibility. Values lie in
            `[0, ..., input_ids.size(-1)]`.
        cache (`EncoderDecoderCache`, *optional*):
            Precomputed key/value states that can be used to speed up sequential decoding.
        is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels telling whether a question has an answer or no answer (SQuAD 2.0).
        cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Position (index) of the classification token used as input for computing the plausibility of the
            answer.
        p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means the token is
            masked, 0.0 means it is not.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, XLMForQuestionAnswering
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048")
        >>> model = XLMForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048")

        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
        ...     0
        ... )  # Batch size 1
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])

        >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        >>> loss = outputs.loss
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict

        encoder_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The SQuAD head receives the encoder's first output together with every span-related input.
        head_outputs = self.qa_outputs(
            encoder_outputs[0],
            start_positions=start_positions,
            end_positions=end_positions,
            cls_index=cls_index,
            is_impossible=is_impossible,
            p_mask=p_mask,
            return_dict=return_dict,
        )

        if not return_dict:
            # Tuple form: head outputs first, then whatever the encoder returned after its first entry.
            return head_outputs + encoder_outputs[1:]

        return XLMForQuestionAnsweringOutput(
            loss=head_outputs.loss,
            start_top_log_probs=head_outputs.start_top_log_probs,
            start_top_index=head_outputs.start_top_index,
            end_top_log_probs=head_outputs.end_top_log_probs,
            end_top_index=head_outputs.end_top_index,
            cls_logits=head_outputs.cls_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring
class XLMForTokenClassification(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLMModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        langs: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        lengths: torch.Tensor | None = None,
        cache: dict[str, torch.Tensor] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | TokenClassifierOutput:
        r"""
        langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Language id of every input token, aligned with the input sequence. The ids come from the
            configuration of multilingual checkpoints: `model.config.lang2id` maps a language name (string) to
            its id (int) and `model.config.id2lang` maps an id back to its name.

            See usage examples detailed in the [multilingual documentation](../multilingual).
        lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of tokens in each sequence, used to avoid attending to padding positions. `attention_mask`
            gives the same result; this argument is kept for compatibility. Values lie in
            `[0, ..., input_ids.size(-1)]`.
        cache (`EncoderDecoderCache`, *optional*):
            Precomputed key/value states that can be used to speed up sequential decoding.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Targets for the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        encoder_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token logits: dropout on the encoder's first output, then the linear classifier.
        logits = self.classifier(self.dropout(encoder_outputs[0]))

        loss = None
        if labels is not None:
            # Flatten batch and sequence so every token is one classification example.
            criterion = CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the loss leads only when it was computed.
        as_tuple = (logits,) + encoder_outputs[1:]
        return as_tuple if loss is None else (loss,) + as_tuple
@auto_docstring
class XLMForMultipleChoice(XLMPreTrainedModel):
    # Multiple-choice head on top of `XLMModel`: every (example, choice) pair is encoded as its own
    # sequence, summarized by `XLMSequenceSummary`, projected to one scalar, and the scalars are
    # regrouped into one row of `num_choices` scores per example.

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.transformer = XLMModel(config)
        self.sequence_summary = XLMSequenceSummary(config)
        # The projection consumes a `config.num_labels`-wide vector and emits a single score.
        self.logits_proj = nn.Linear(config.num_labels, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        langs: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        lengths: torch.Tensor | None = None,
        cache: dict[str, torch.Tensor] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | MultipleChoiceModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        langs (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
            languages ids which can be obtained from the language names by using two conversion mappings provided in
            the configuration of the model (only provided for multilingual models). More precisely, the *language name
            to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
            *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).

            See usage examples detailed in the [multilingual documentation](../multilingual).
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Length of each sentence that can be used to avoid performing attention on padding token indices. You can
            also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
            `[0, ..., input_ids.size(-1)]`.
        cache (`dict[str, torch.FloatTensor]`, *optional*):
            Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
            decoding.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # Fall back to the configured default only when the caller did not decide.
        if return_dict is None:
            return_dict = self.config.return_dict

        # The choice count comes from dim 1 of whichever primary input was supplied. Token ids are
        # then collapsed so that each (example, choice) pair becomes an independent row.
        if input_ids is not None:
            num_choices = input_ids.shape[1]
            input_ids = input_ids.view(-1, input_ids.size(-1))
        else:
            num_choices = inputs_embeds.shape[1]

        # Collapse the leading dims of every optional per-token tensor the same way.
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        if position_ids is not None:
            position_ids = position_ids.view(-1, position_ids.size(-1))
        if langs is not None:
            langs = langs.view(-1, langs.size(-1))
        if inputs_embeds is not None:
            # Embeddings keep their two trailing dims; only the leading dims are merged.
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        # `lengths` is not reshaped like the tensors above, so it is dropped with a warning.
        if lengths is not None:
            logger.warning(
                "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
                "attention mask instead."
            )
            lengths = None

        encoder_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Summarize each flattened sequence, score it, then regroup the scores into rows of
        # `num_choices` columns.
        summary = self.sequence_summary(encoder_outputs[0])
        reshaped_logits = self.logits_proj(summary).view(-1, num_choices)

        # Cross-entropy over the choice scores, computed only when labels are given.
        loss = None if labels is None else CrossEntropyLoss()(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the regrouped logits replace the encoder's first element, and the loss is
        # prepended only when it was computed.
        tuple_output = (reshaped_logits,) + encoder_outputs[1:]
        return tuple_output if loss is None else (loss,) + tuple_output
# Public names of this module: the list consulted by `from <module> import *`.
__all__ = [
    "XLMForMultipleChoice",
    "XLMForQuestionAnswering",
    "XLMForQuestionAnsweringSimple",
    "XLMForSequenceClassification",
    "XLMForTokenClassification",
    "XLMModel",
    "XLMPreTrainedModel",
    "XLMWithLMHeadModel",
]
|