| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171 |
- # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- PyTorch XLNet model.
- """
- from collections.abc import Callable
- from dataclasses import dataclass
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN, get_activation
- from ...generation import GenerationMixin
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import ModelOutput, auto_docstring, logging
- from .configuration_xlnet import XLNetConfig
# Module-level logger for this file, obtained through the library's own logging utilities.
logger = logging.get_logger(__name__)
class XLNetRelativeAttention(nn.Module):
    """
    Relative positional multi-head attention used inside each XLNet layer.

    `forward` runs in one of two modes:

    - single-stream (`g is None`): relative attention of the content stream `h` over `[mems; h]`;
    - two-stream (`g` given): the content stream is computed as above, and in addition the query stream `g`
      attends to the same content keys/values using its own mask `attn_mask_g`.

    The projection weights `q`, `k`, `v`, `o` and `r` have shape `(d_model, n_head, d_head)` and are applied with
    `torch.einsum`. Every parameter is allocated with `torch.FloatTensor(...)`, which leaves it uninitialized;
    `XLNetPreTrainedModel._init_weights` is what fills them.

    Args:
        config: Object providing `d_model`, `n_head`, `d_head`, `layer_norm_eps` and `dropout`.

    Raises:
        ValueError: If `config.d_model` is not a multiple of `config.n_head`.
    """

    def __init__(self, config):
        super().__init__()

        if config.d_model % config.n_head != 0:
            raise ValueError(
                f"The hidden size ({config.d_model}) is not a multiple of the number of attention "
                f"heads ({config.n_head})"
            )

        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        self.scale = 1 / (config.d_head**0.5)

        # Query / key / value / output / positional-key projections, each (d_model, n_head, d_head).
        self.q = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.k = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.v = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.o = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.r = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))

        # Per-head biases added to the query for the position (r), segment (s) and content (w) score terms.
        self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_s_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        # Two segment embeddings ("s" axis of size 2 in the segment einsum).
        self.seg_embed = nn.Parameter(torch.FloatTensor(2, self.n_head, self.d_head))

        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    @staticmethod
    def rel_shift(x, klen=-1):
        """
        Perform the relative shift on a 4D score tensor, shifting along its first two axes.

        Swaps the first two axes via `reshape`, drops the first row, reshapes back and keeps the first `klen`
        entries of dim 1. `klen` must be passed explicitly: the default `-1` is not a valid `torch.arange` bound.
        """
        x_size = x.shape

        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        # x = x[:, 0:klen, :, :]
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))

        return x

    @staticmethod
    def rel_shift_bnij(x, klen=-1):
        """
        Relative shift for scores in `bnij` layout: the same trick as `rel_shift`, applied to the last two axes.

        The result keeps the first `klen` entries of the last axis. `klen` must be passed explicitly.
        """
        x_size = x.shape

        x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
        x = x[:, :, 1:, :]
        x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
        # Note: the tensor-slice form was faster in my testing than torch.index_select
        #       However, tracing doesn't like the nature of the slice, and if klen changes
        #       during the run then it'll fail, whereas index_select will be fine.
        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
        # x = x[:, :, :, :klen]

        return x

    def rel_attn_core(
        self,
        q_head,
        k_head_h,
        v_head_h,
        k_head_r,
        seg_mat=None,
        attn_mask=None,
        output_attentions=False,
    ):
        """
        Core relative positional attention operations.

        The score is the sum of three terms, scaled by `1 / sqrt(d_head)`:

        - `ac`: content term;
        - `bd`: relative-position term, realigned with `rel_shift_bnij`;
        - `ef`: segment term, added only when `seg_mat` is given.

        `attn_mask` (in `ijbn` layout) is subtracted with a large factor, and the result is softmaxed over the key
        axis and used to mix `v_head_h`.

        Returns:
            `attn_vec` in `ibnd` layout, or `(attn_vec, attn_prob)` when `output_attentions` is set, with
            `attn_prob` in `ijbn` layout.
        """
        # content based attention score
        ac = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_w_bias, k_head_h)

        # position based attention score
        bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift_bnij(bd, klen=ac.shape[3])

        # segment based attention score
        if seg_mat is None:
            ef = 0
        else:
            ef = torch.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
            ef = torch.einsum("ijbs,ibns->bnij", seg_mat, ef)

        # merge attention scores and perform masking
        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            # Non-zero mask entries push the score far down. 1e30 overflows float16, so 65500 (just below the
            # float16 maximum) is used for half-precision masks.
            if attn_mask.dtype == torch.float16:
                attn_score = attn_score - 65500 * torch.einsum("ijbn->bnij", attn_mask)
            else:
                attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)

        # attention probability
        attn_prob = nn.functional.softmax(attn_score, dim=3)
        attn_prob = self.dropout(attn_prob)

        # attention output
        attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)

        if output_attentions:
            return attn_vec, torch.einsum("bnij->ijbn", attn_prob)

        return attn_vec

    def post_attention(self, h, attn_vec, residual=True):
        """Project `attn_vec` back to `d_model`, apply dropout, optionally add `h` as a residual, then layer-norm."""
        # post-attention projection (back to `d_model`)
        attn_out = torch.einsum("ibnd,hnd->ibh", attn_vec, self.o)

        attn_out = self.dropout(attn_out)
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)

        return output

    def forward(
        self,
        h,
        g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems=None,
        target_mapping=None,
        output_attentions=False,
    ):
        """
        Run single-stream (`g is None`) or two-stream relative attention.

        Args:
            h: Content stream, in the `ibh` einsum layout used throughout this module.
            g: Query stream, or `None` for single-stream attention.
            attn_mask_h / attn_mask_g: Additive masks (`ijbn` layout) for the content / query stream.
            r: Relative positional embeddings. They are cast to the dtype of the `r` projection before use.
            seg_mat: Segment matrix (`ijbs` layout), or `None` to skip the segment term.
            mems: Cached states concatenated in front of `h` along dim 0 for keys/values (used only if `dim() > 1`).
            target_mapping: Optional `mlb` mapping applied to the query-stream queries before attention and undone
                on the attention output afterwards.
            output_attentions: Whether to also return attention probabilities.

        Returns:
            `(output_h, output_g)`, where `output_g` is `None` in single-stream mode. When `output_attentions` is
            set, a third element is appended: `(attn_prob_h, attn_prob_g)` in two-stream mode, or `attn_prob_h`
            alone in single-stream mode.
        """
        # Keys/values attend over the cached memory followed by the current content stream.
        if mems is not None and mems.dim() > 1:
            cat = torch.cat([mems, h], dim=0)
        else:
            cat = h

        # Heads shared by both streams.
        k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
        v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
        # Cast `r` to the projection dtype in BOTH modes so that half/double weights work with float32 positional
        # embeddings (previously only the single-stream path did this, and einsum does not promote dtypes).
        k_head_r = torch.einsum("ibh,hnd->ibnd", r.type(self.r.dtype), self.r)

        # Content stream (h).
        q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
        attn_vec_h = self.rel_attn_core(
            q_head_h,
            k_head_h,
            v_head_h,
            k_head_r,
            seg_mat=seg_mat,
            attn_mask=attn_mask_h,
            output_attentions=output_attentions,
        )
        if output_attentions:
            attn_vec_h, attn_prob_h = attn_vec_h
        output_h = self.post_attention(h, attn_vec_h)

        if g is None:
            outputs = (output_h, None)
            if output_attentions:
                outputs = outputs + (attn_prob_h,)
            return outputs

        # Query stream (g): queries come from g, keys/values from the content stream.
        q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
        if target_mapping is not None:
            # Map the query rows through target_mapping (m -> l) before attending.
            q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)

        attn_vec_g = self.rel_attn_core(
            q_head_g,
            k_head_h,
            v_head_h,
            k_head_r,
            seg_mat=seg_mat,
            attn_mask=attn_mask_g,
            output_attentions=output_attentions,
        )
        if output_attentions:
            attn_vec_g, attn_prob_g = attn_vec_g

        if target_mapping is not None:
            # Map the attended vectors back (l -> m).
            attn_vec_g = torch.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)

        output_g = self.post_attention(g, attn_vec_g)

        outputs = (output_h, output_g)
        if output_attentions:
            outputs = outputs + ((attn_prob_h, attn_prob_g),)
        return outputs
class XLNetFeedForward(nn.Module):
    """
    Position-wise feed-forward sub-layer of an XLNet block.

    Computes `layer_norm(inp + dropout(layer_2(dropout(act(layer_1(inp))))))`.

    Args:
        config: Object providing `d_model`, `d_inner`, `layer_norm_eps`, `dropout` and `ff_activation`. The last is
            either a key into `ACT2FN` or a callable used directly.
    """

    def __init__(self, config):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.layer_1 = nn.Linear(config.d_model, config.d_inner)
        self.layer_2 = nn.Linear(config.d_inner, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        activation = config.ff_activation
        # Strings are looked up in the activation registry; anything else is taken as the callable itself.
        self.activation_function = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, inp):
        hidden = self.dropout(self.activation_function(self.layer_1(inp)))
        hidden = self.dropout(self.layer_2(hidden))
        # Residual connection followed by post-layer normalization.
        return self.layer_norm(hidden + inp)
class XLNetLayer(nn.Module):
    """
    One XLNet block: relative (single- or two-stream) attention followed by the feed-forward sub-layer.

    The feed-forward step is run through `apply_chunking_to_forward` with `chunk_size_feed_forward` taken from the
    config.
    """

    def __init__(self, config):
        super().__init__()
        self.rel_attn = XLNetRelativeAttention(config)
        self.ff = XLNetFeedForward(config)
        self.dropout = nn.Dropout(config.dropout)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Axis handed to apply_chunking_to_forward as the chunking dimension.
        self.seq_len_dim = 1

    def forward(
        self,
        output_h,
        output_g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems=None,
        target_mapping=None,
        output_attentions=False,
    ):
        """
        Apply attention, then the feed-forward sub-layer to each non-`None` stream.

        Returns:
            `(output_h, output_g)` followed by whatever extra outputs (attention probabilities) the attention module
            returned.
        """
        attn_outputs = self.rel_attn(
            output_h,
            output_g,
            attn_mask_h,
            attn_mask_g,
            r,
            seg_mat,
            mems=mems,
            target_mapping=target_mapping,
            output_attentions=output_attentions,
        )
        new_h, new_g = attn_outputs[:2]
        extras = attn_outputs[2:]  # non-empty only when attention probabilities were requested

        if new_g is not None:
            new_g = apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, new_g)
        new_h = apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, new_h)

        return (new_h, new_g) + extras

    def ff_chunk(self, output_x):
        """Run the feed-forward sub-layer on a single chunk."""
        return self.ff(output_x)
- # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->XLNet
class XLNetPoolerStartLogits(nn.Module):
    """
    Compute SQuAD start logits from sequence hidden states.

    Args:
        config ([`XLNetConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model.
    """

    def __init__(self, config: XLNetConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states: torch.FloatTensor, p_mask: torch.FloatTensor | None = None) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
                should be masked.

        Returns:
            `torch.FloatTensor`: The start logits for SQuAD.
        """
        logits = self.dense(hidden_states).squeeze(-1)
        if p_mask is None:
            return logits

        # Masked positions get a large negative logit; 1e30 overflows float16, so a smaller penalty is used there.
        penalty = 65500 if p_mask.dtype == torch.float16 else 1e30
        return logits * (1 - p_mask) - penalty * p_mask
- # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->XLNet
class XLNetPoolerEndLogits(nn.Module):
    """
    Compute SQuAD end logits from sequence hidden states.

    Args:
        config ([`XLNetConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
            to use.
    """

    def __init__(self, config: XLNetConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        p_mask: torch.FloatTensor | None = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                The hidden states of the first tokens for the labeled span.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                The position of the first token for the labeled span.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
                should be masked.

        <Tip>

        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
        `start_states`.

        </Tip>

        Returns:
            `torch.FloatTensor`: The end logits for SQuAD.
        """
        assert start_states is not None or start_positions is not None, (
            "One of start_states, start_positions should be not None"
        )
        if start_positions is not None:
            seq_len, hidden_size = hidden_states.shape[-2:]
            index = start_positions[:, None, None].expand(-1, -1, hidden_size)  # shape (bsz, 1, hsz)
            # Gather the start-token state and broadcast it along the sequence: shape (bsz, slen, hsz).
            start_states = hidden_states.gather(-2, index).expand(-1, seq_len, -1)

        features = torch.cat([hidden_states, start_states], dim=-1)
        logits = self.dense_1(self.LayerNorm(self.activation(self.dense_0(features)))).squeeze(-1)

        if p_mask is not None:
            # 1e30 overflows float16, so half-precision masks use a smaller penalty.
            penalty = 65500 if p_mask.dtype == torch.float16 else 1e30
            logits = logits * (1 - p_mask) - penalty * p_mask

        return logits
- # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->XLNet
class XLNetPoolerAnswerClass(nn.Module):
    """
    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.

    Args:
        config ([`XLNetConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model.
    """

    def __init__(self, config: XLNetConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        cls_index: torch.LongTensor | None = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                The hidden states of the first tokens for the labeled span.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                The position of the first token for the labeled span.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.

        <Tip>

        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
        `start_states`.

        </Tip>

        Returns:
            `torch.FloatTensor`: The SQuAD 2.0 answer class.
        """
        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
        hidden_size = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, (
            "One of start_states, start_positions should be not None"
        )
        if start_positions is not None:
            index = start_positions[:, None, None].expand(-1, -1, hidden_size)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, index).squeeze(-2)  # shape (bsz, hsz)

        if cls_index is None:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
        else:
            index = cls_index[:, None, None].expand(-1, -1, hidden_size)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, index).squeeze(-2)  # shape (bsz, hsz)

        hidden = self.activation(self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)))
        return self.dense_1(hidden).squeeze(-1)
- # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->XLNet
class XLNetSequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence's hidden states.

    Args:
        config ([`XLNetConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config: XLNetConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        # Optional projection: to num_labels when requested (and positive), otherwise back to hidden_size.
        self.summary = nn.Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

        self.first_dropout = nn.Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = nn.Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence's hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Only used when `summary_type == "cls_index"`: position of the classification token. If `None`, the
                last token of the sequence is used.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.

        Raises:
            NotImplementedError: If `summary_type == "attn"`.
            ValueError: If `summary_type` is not one of the supported values.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                # Default to the last position of the sequence.
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError
        else:
            # Previously an unknown type fell through and crashed later with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported summary_type {self.summary_type!r}; expected one of 'last', 'first', 'mean', "
                "'cls_index'."
            )

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
@auto_docstring
class XLNetPreTrainedModel(PreTrainedModel):
    config: XLNetConfig
    base_model_prefix = "transformer"

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights of `module`.

        The base-class initialization runs first. Then every projection, bias and segment embedding of an
        `XLNetRelativeAttention` (or the `mask_emb` of an `XLNetModel`) is drawn from a normal distribution with
        mean 0 and std `config.initializer_range`.
        """
        super()._init_weights(module)

        if isinstance(module, XLNetRelativeAttention):
            attention_tensor_names = ("q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed")
            tensors = [getattr(module, name) for name in attention_tensor_names]
        elif isinstance(module, XLNetModel):
            tensors = [module.mask_emb]
        else:
            return

        std = self.config.initializer_range
        for tensor in tensors:
            init.normal_(tensor, mean=0.0, std=std)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`XLNetModel`].
    """
)
class XLNetModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_predict, hidden_size)`):
        Sequence of hidden-states at the last layer of the model.

        `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict`
        corresponds to `sequence_length`.
    mems (`list[torch.FloatTensor]` of length `config.n_layers`):
        Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
        token ids which have their past given to this model should not be passed as `input_ids` as they have
        already been computed.
    """

    # Only required field; every other field defaults to None.
    last_hidden_state: torch.FloatTensor
    mems: list[torch.FloatTensor] | None = None
    # Optional per-layer tensors; stay None unless the producing model fills them in.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`XLNetLMHeadModel`].
    """
)
class XLNetLMHeadModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, num_predict, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

        `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict`
        corresponds to `sequence_length`.
    mems (`list[torch.FloatTensor]` of length `config.n_layers`):
        Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
        token ids which have their past given to this model should not be passed as `input_ids` as they have
        already been computed.
    """

    # All fields are optional and default to None.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    mems: list[torch.FloatTensor] | None = None
    # Optional per-layer tensors; stay None unless the producing model fills them in.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`XLNetForSequenceClassification`].
    """
)
class XLNetForSequenceClassificationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    mems (`list[torch.FloatTensor]` of length `config.n_layers`):
        Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
        token ids which have their past given to this model should not be passed as `input_ids` as they have
        already been computed.
    """

    # All fields are optional and default to None.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    mems: list[torch.FloatTensor] | None = None
    # Optional per-layer tensors; stay None unless the producing model fills them in.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`XLNetForTokenClassification`].
    """
)
class XLNetForTokenClassificationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Classification scores (before SoftMax).
    mems (`list[torch.FloatTensor]` of length `config.n_layers`):
        Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
        token ids which have their past given to this model should not be passed as `input_ids` as they have
        already been computed.
    """

    # All fields are optional and default to None.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    mems: list[torch.FloatTensor] | None = None
    # Optional per-layer tensors; stay None unless the producing model fills them in.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`XLNetForMultipleChoice`].
    """
)
class XLNetForMultipleChoiceOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
        `num_choices` is the second dimension of the input tensors (see the `input_ids` input).

        Classification scores (before SoftMax).
    mems (`list[torch.FloatTensor]` of length `config.n_layers`):
        Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
        token ids which have their past given to this model should not be passed as `input_ids` as they have
        already been computed.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    mems: list[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`XLNetForQuestionAnsweringSimple`].
    """
)
class XLNetForQuestionAnsweringSimpleOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
    start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        Span-start scores (before SoftMax).
    end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        Span-end scores (before SoftMax).
    mems (`list[torch.FloatTensor]` of length `config.n_layers`):
        Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
        token ids which have their past given to this model should not be passed as `input_ids` as they have
        already been computed.
    """

    # Field order defines the positional constructor and tuple order of this output; do not reorder.
    loss: torch.FloatTensor | None = None
    start_logits: torch.FloatTensor | None = None
    end_logits: torch.FloatTensor | None = None
    mems: list[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`XLNetForQuestionAnswering`].
    """
)
class XLNetForQuestionAnsweringOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
        Classification loss as the sum of start token, end token (and is_impossible if provided) classification
        losses.
    start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the top `config.start_n_top` start token possibilities (beam-search).
    start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Indices for the top `config.start_n_top` start token possibilities (beam-search).
    end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
        (beam-search).
    end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
    cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the `is_impossible` label of the answers.
    mems (`list[torch.FloatTensor]` of length `config.n_layers`):
        Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
        token ids which have their past given to this model should not be passed as `input_ids` as they have
        already been computed.
    """

    # Field order defines the positional constructor and tuple order of this output; do not reorder.
    loss: torch.FloatTensor | None = None
    start_top_log_probs: torch.FloatTensor | None = None
    start_top_index: torch.LongTensor | None = None
    end_top_log_probs: torch.FloatTensor | None = None
    end_top_index: torch.LongTensor | None = None
    cls_logits: torch.FloatTensor | None = None
    mems: list[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@auto_docstring
class XLNetModel(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
        self.n_layer = config.n_layer

        self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Trainable vector used as the initial query stream (`output_g`) in `forward` when `target_mapping` is given.
        self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, config.d_model))
        self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embedding

    def set_input_embeddings(self, new_embeddings):
        self.word_embedding = new_embeddings

    def create_mask(self, qlen, mlen):
        """
        Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.

        Args:
            qlen: Sequence length (number of query positions).
            mlen: Memory length (number of cached positions that precede the current segment in the keys).

        Returns:
            A float tensor of shape `(qlen, mlen + qlen)` on `self.device`.

        ::

                  same_length=False:      same_length=True:
                  <mlen > <  qlen >       <mlen > <  qlen >
               ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
                 [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
            qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
                 [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]

        """
        mask = torch.ones((qlen, qlen + mlen), device=self.device)
        if self.same_length:
            # Additionally hide the oldest keys so every query attends to exactly `mlen + 1` positions.
            mask_lo = mask[:, :qlen].tril(-1)
            mask.triu_(mlen + 1)
            mask[:, :qlen] += mask_lo
        else:
            mask.triu_(mlen + 1)

        return mask

    def cache_mem(self, curr_out, prev_mem):
        """
        Build the memory to carry over to the next forward pass for one layer.

        Args:
            curr_out: Hidden states of the current segment, length-first (`[len, bsz, d_model]` in this module).
            prev_mem: Memory from the previous pass for this layer, or `None`.

        Returns:
            The new memory, detached from the autograd graph.
        """
        # cache hidden states into memory.
        if self.reuse_len is not None and self.reuse_len > 0:
            curr_out = curr_out[: self.reuse_len]

        if self.mem_len is None or self.mem_len == 0:
            # If `use_mems` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
            # and returns all of the past and current hidden states.
            cutoff = 0
        else:
            # If `use_mems` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
            # states. This is the preferred setting for training and long-form generation.
            cutoff = -self.mem_len
        if prev_mem is None:
            # No previous memory for this layer: the new memory comes from the current output only.
            new_mem = curr_out[cutoff:]
        else:
            new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]

        # Gradients are not propagated into the cached memory.
        return new_mem.detach()

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        """
        Sinusoidal embedding of the 1-D position sequence `pos_seq`.

        Returns a tensor of shape `(len(pos_seq), 1, 2 * len(inv_freq))`, or `(len(pos_seq), bsz, 2 * len(inv_freq))`
        (an expanded view, no copy) when `bsz` is given.
        """
        sinusoid_inp = torch.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None, device=None):
        """
        Build the relative positional encoding for `qlen` queries attending to `klen` keys.

        Relative distances run from `klen` down to `-qlen + 1` (`attn_type="bi"`) or down to `0` (`attn_type="uni"`),
        optionally clamped to `[-clamp_len, clamp_len]`. With `bi_data`, forward and backward (negated) distances are
        embedded separately, each for `bsz // 2` batch entries, and concatenated along the batch dimension.

        Raises:
            ValueError: if `self.attn_type` is neither `"bi"` nor `"uni"`.
        """
        # create relative positional encoding.
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64, device=device).float()
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == "bi":
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == "uni":
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError(f"Unknown `attn_type` {self.attn_type}.")

        if self.bi_data:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64, device=device).float()
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64, device=device).float()

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64, device=device).float()
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        return pos_emb

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        mems: torch.Tensor | None = None,
        perm_mask: torch.Tensor | None = None,
        target_mapping: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        input_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        use_mems: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,  # delete after deprecation warning is removed
    ) -> tuple | XLNetModelOutput:
        r"""
        mems (`list[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
            decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
            they have already been computed.

            `use_mems` has to be set to `True` to make use of `mems`.
        perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
            Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:

            - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
            - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.

            If not set, each token attends to all the others (full bidirectional attention). Only used during
            pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
            Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
            on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
            (generation).
        input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
            Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
            real tokens and 1 for padding which is kept for compatibility with the original code base.

            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **masked**,
            - 0 for tokens that are **not masked**.

            You can only use one of `input_mask` and `attention_mask`; passing both raises a `ValueError`.
        use_mems (`bool`, *optional*):
            Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
            states from previous forward passes to compute attention, which can significantly improve performance for
            sequential decoding tasks.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if self.training:
            use_mems = use_mems if use_mems is not None else self.config.use_mems_train
        else:
            use_mems = use_mems if use_mems is not None else self.config.use_mems_eval

        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
        perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None

        mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen

        dtype_float = self.dtype
        device = self.device

        # Attention mask
        # causal attention mask
        if self.attn_type == "uni":
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == "bi":
            attn_mask = None
        else:
            raise ValueError(f"Unsupported attention type: {self.attn_type}")

        # data mask: input mask & perm mask
        # An explicit check (not `assert`) so it survives `python -O` and reports the whole message.
        if input_mask is not None and attention_mask is not None:
            raise ValueError(
                "You can only use one of input_mask (uses 1 for padding) "
                "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
            )
        if input_mask is None and attention_mask is not None:
            input_mask = 1.0 - attention_mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            if mlen > 0:
                mems_mask = data_mask.new_zeros([data_mask.shape[0], mlen, bsz])
                data_mask = torch.cat([mems_mask, data_mask], dim=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                # Out-of-place on purpose: the causal mask is [qlen, klen, 1, 1] while the data mask carries the
                # batch dimension, and an in-place add cannot broadcast its left operand to a larger shape.
                attn_mask = attn_mask + data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = (attn_mask > 0).to(dtype_float)
            # The content stream (h) may attend to its own position, so clear the diagonal of the current segment;
            # the query stream (g) keeps `attn_mask` unchanged.
            non_tgt_mask = -torch.eye(qlen, dtype=attn_mask.dtype, device=attn_mask.device)
            if mlen > 0:
                non_tgt_mask = torch.cat([attn_mask.new_zeros([qlen, mlen]), non_tgt_mask], dim=-1)
            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
        else:
            non_tgt_mask = None

        # Word embeddings and prepare h & g hidden states
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k)
        if target_mapping is not None:
            word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
            # else:  # We removed the inp_q input which was same as target mapping
            #     inp_q_ext = inp_q[:, :, None]
            #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None

        # Segment embedding
        if token_type_ids is not None:
            # Convert `token_type_ids` to one-hot `seg_mat`
            if mlen > 0:
                mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
                cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
            else:
                cat_ids = token_type_ids

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
            seg_mat = nn.functional.one_hot(seg_mat, num_classes=2).to(dtype_float)
        else:
            seg_mat = None

        # Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, device=output_h.device)
        pos_emb = self.dropout(pos_emb)

        new_mems = ()
        if mems is None:
            mems = [None] * len(self.layer)

        attentions = [] if output_attentions else None
        hidden_states = [] if output_hidden_states else None
        for i, layer_module in enumerate(self.layer):
            if use_mems:
                # cache new mems
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if output_hidden_states:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)

            outputs = layer_module(
                output_h,
                output_g,
                attn_mask_h=non_tgt_mask,
                attn_mask_g=attn_mask,
                r=pos_emb,
                seg_mat=seg_mat,
                mems=mems[i],
                target_mapping=target_mapping,
                output_attentions=output_attentions,
            )
            output_h, output_g = outputs[:2]
            if output_attentions:
                attentions.append(outputs[2])

        # Add last hidden state
        if output_hidden_states:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

        output = self.dropout(output_g if output_g is not None else output_h)

        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        output = output.permute(1, 0, 2).contiguous()

        if not use_mems:
            new_mems = None

        if output_hidden_states:
            if output_g is not None:
                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
            else:
                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)

        if output_attentions:
            if target_mapping is not None:
                # when target_mapping is provided, there are 2-tuple of attentions
                attentions = tuple(
                    tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions
                )
            else:
                attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)

        if not return_dict:
            return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)

        return XLNetModelOutput(
            last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
        )
@auto_docstring(
    custom_intro="""
    XLNet Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
    """
)
class XLNetLMHeadModel(XLNetPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_loss.weight": "transformer.word_embedding.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.attn_type = config.attn_type
        self.same_length = config.same_length

        self.transformer = XLNetModel(config)
        self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_loss

    def set_output_embeddings(self, new_embeddings):
        self.lm_loss = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, use_mems=None, is_first_iteration=False, **kwargs
    ):
        """
        Build the model inputs for one generation step.

        A dummy token (id 0) is appended to `input_ids`; a `perm_mask` hides it from every position and a
        `target_mapping` selects it as the single position to predict. When `past_key_values` is non-empty, only the
        last `offset` tokens are re-fed and the same number of positions is trimmed from the end of each cached
        memory. `attention_mask` and `use_cache` are dropped; every other unused kwarg is forwarded unchanged.
        """
        # Overwritten -- this model has unique input preparation
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # The new token and the two last generated tokens are recomputed on every pass; the rest is reloaded from
        # the `past` cache. A purely auto-regressive model would use offset = 1; offset = 2 seems to compute
        # slightly better.
        offset = 2
        has_past = bool(past_key_values)

        context = input_ids[:, -offset:] if has_past else input_ids
        placeholder = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
        input_ids = torch.cat([context, placeholder], dim=1)
        seq_len = input_ids.shape[1]

        # No earlier position may look at the placeholder.
        perm_mask = torch.zeros((batch_size, seq_len, seq_len), dtype=torch.float, device=device)
        perm_mask[:, :, -1] = 1.0

        # Exactly one prediction per sequence: the placeholder.
        target_mapping = torch.zeros((batch_size, 1, seq_len), dtype=torch.float, device=device)
        target_mapping[:, 0, -1] = 1.0

        model_inputs = {
            "input_ids": input_ids,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
            "use_mems": use_mems,
        }
        if has_past:
            model_inputs["mems"] = tuple(cached[:-offset, :, :] for cached in past_key_values)

        # `attention_mask` is rebuilt inside XLNetModel.forward().
        # TODO: `use_cache` is silently ignored here; it should be honoured.
        dropped = {"attention_mask", "use_cache"}
        for name, value in kwargs.items():
            if name in dropped or name in model_inputs:
                continue
            model_inputs[name] = value

        return model_inputs

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        mems: torch.Tensor | None = None,
        perm_mask: torch.Tensor | None = None,
        target_mapping: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        input_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        use_mems: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ) -> tuple | XLNetLMHeadModelOutput:
        r"""
        mems (`list[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
            decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
            they have already been computed.

            `use_mems` has to be set to `True` to make use of `mems`.
        perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
            Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:

            - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
            - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.

            If not set, each token attends to all the others (full bidirectional attention). Only used during
            pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
            Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
            on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
            (generation).
        input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
            Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
            real tokens and 1 for padding which is kept for compatibility with the original code base.

            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **masked**,
            - 0 for tokens that are **not masked**.

            You can only use one of `input_mask` and `attention_mask`.
        labels (`torch.LongTensor` of shape `(batch_size, num_predict)`, *optional*):
            Labels for masked language modeling. `num_predict` corresponds to `target_mapping.shape[1]`. If
            `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`.

            The labels should correspond to the masked input words that should be predicted and depend on
            `target_mapping`. Note that in order to perform standard auto-regressive language modeling a *<mask>* token
            has to be added to the `input_ids` (see the `prepare_inputs_for_generation` function and examples below).

            Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are ignored, the loss
            is only computed for labels in `[0, ..., config.vocab_size]`.
        use_mems (`bool`, *optional*):
            Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
            states from previous forward passes to compute attention, which can significantly improve performance for
            sequential decoding tasks.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, XLNetLMHeadModel
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("xlnet/xlnet-large-cased")
        >>> model = XLNetLMHeadModel.from_pretrained("xlnet/xlnet-large-cased")

        >>> # We show how to setup inputs to predict a next token using a bi-directional context.
        >>> input_ids = torch.tensor(
        ...     tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)
        ... ).unsqueeze(0)  # We will predict the masked token
        >>> perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
        >>> perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
        >>> # Shape [1, 1, seq_length] => let's predict one token
        >>> target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)
        >>> # Our first (and only) prediction will be the last token of the sequence (the masked token)
        >>> target_mapping[0, 0, -1] = 1.0

        >>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
        >>> # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
        >>> next_token_logits = outputs[0]

        >>> # The XLNetLMHeadModel can be trained with standard auto-regressive language modeling in the same way.
        >>> input_ids = torch.tensor(
        ...     tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)
        ... ).unsqueeze(0)  # We will predict the masked token
        >>> labels = torch.tensor(tokenizer.encode("cute", add_special_tokens=False)).unsqueeze(0)
        >>> assert labels.shape[0] == 1, "only one word will be predicted"
        >>> perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
        >>> # Previous tokens don't see last token as is done in standard auto-regressive lm training
        >>> perm_mask[:, :, -1] = 1.0
        >>> # Shape [1, 1, seq_length] => let's predict one token
        >>> target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)
        >>> # Our first (and only) prediction will be the last token of the sequence (the masked token)
        >>> target_mapping[0, 0, -1] = 1.0

        >>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=labels)
        >>> loss = outputs.loss
        >>> # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
        >>> next_token_logits = outputs.logits
        ```"""
        return_dict = self.config.return_dict if return_dict is None else return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            inputs_embeds=inputs_embeds,
            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        sequence_output = transformer_outputs[0]

        # Only project the positions that are needed (an int 0 keeps every position).
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_loss(sequence_output[:, kept_positions, :])

        loss = None
        if labels is not None:
            # Flatten the tokens
            vocab_dim = logits.size(-1)
            loss = CrossEntropyLoss()(logits.view(-1, vocab_dim), labels.view(-1))

        if return_dict:
            return XLNetLMHeadModelOutput(
                loss=loss,
                logits=logits,
                mems=transformer_outputs.mems,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        output = (logits,) + transformer_outputs[1:]
        return output if loss is None else (loss,) + output

    @staticmethod
    def _reorder_cache(mems: list[torch.Tensor], beam_idx: torch.Tensor) -> list[torch.Tensor]:
        """
        Re-order the `mems` cache when [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is
        called, so that each cached memory follows its beam: entries are gathered along dimension 1 using
        `beam_idx` (moved to each memory's device).
        """
        return [cached.index_select(1, beam_idx.to(cached.device)) for cached in mems]
- @auto_docstring(
- custom_intro="""
- XLNet Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g.
- for GLUE tasks.
- """
- )
- class XLNetForSequenceClassification(XLNetPreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.config = config
- self.transformer = XLNetModel(config)
- self.sequence_summary = XLNetSequenceSummary(config)
- self.logits_proj = nn.Linear(config.d_model, config.num_labels)
- # Initialize weights and apply final processing
- self.post_init()
- @auto_docstring
- def forward(
- self,
- input_ids: torch.Tensor | None = None,
- attention_mask: torch.Tensor | None = None,
- mems: torch.Tensor | None = None,
- perm_mask: torch.Tensor | None = None,
- target_mapping: torch.Tensor | None = None,
- token_type_ids: torch.Tensor | None = None,
- input_mask: torch.Tensor | None = None,
- inputs_embeds: torch.Tensor | None = None,
- labels: torch.Tensor | None = None,
- use_mems: bool | None = None,
- output_attentions: bool | None = None,
- output_hidden_states: bool | None = None,
- return_dict: bool | None = None,
- **kwargs, # delete when `use_cache` is removed in XLNetModel
- ) -> tuple | XLNetForSequenceClassificationOutput:
- r"""
- mems (`list[torch.FloatTensor]` of length `config.n_layers`):
- Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
- decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
- they have already been computed.
- `use_mems` has to be set to `True` to make use of `mems`.
- perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
- Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:
- - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
- - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.
- If not set, each token attends to all the others (full bidirectional attention). Only used during
- pretraining (to define factorization order) or for sequential decoding (generation).
- target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
- Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
- on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
- (generation).
- input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
- Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
- real tokens and 1 for padding which is kept for compatibility with the original code base.
- Mask values selected in `[0, 1]`:
- - 1 for tokens that are **masked**,
- - 0 for tokens that are **not masked**.
- You can only uses one of `input_mask` and `attention_mask`.
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- use_mems (`bool`, *optional*):
- Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
- states from previous forward passes to compute attention, which can significantly improve performance for
- sequential decoding tasks.
- """
- return_dict = return_dict if return_dict is not None else self.config.return_dict
- transformer_outputs = self.transformer(
- input_ids,
- attention_mask=attention_mask,
- mems=mems,
- perm_mask=perm_mask,
- target_mapping=target_mapping,
- token_type_ids=token_type_ids,
- input_mask=input_mask,
- inputs_embeds=inputs_embeds,
- use_mems=use_mems,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- **kwargs,
- )
- output = transformer_outputs[0]
- output = self.sequence_summary(output)
- logits = self.logits_proj(output)
- loss = None
- if labels is not None:
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(logits, labels)
- if not return_dict:
- output = (logits,) + transformer_outputs[1:]
- return ((loss,) + output) if loss is not None else output
- return XLNetForSequenceClassificationOutput(
- loss=loss,
- logits=logits,
- mems=transformer_outputs.mems,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
@auto_docstring
class XLNetForTokenClassification(XLNetPreTrainedModel):
    """XLNet with a per-token linear classification head on top of the final hidden states."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        # Maps each token's hidden state to `config.num_labels` scores.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        mems: torch.Tensor | None = None,
        perm_mask: torch.Tensor | None = None,
        target_mapping: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        input_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        use_mems: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ) -> tuple | XLNetForTokenClassificationOutput:
        r"""
        mems (`list[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
            decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
            they have already been computed.

            `use_mems` has to be set to `True` to make use of `mems`.
        perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
            Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:

            - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
            - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.

            If not set, each token attends to all the others (full bidirectional attention). Only used during
            pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
            Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
            on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
            (generation).
        input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
            Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
            real tokens and 1 for padding which is kept for compatibility with the original code base.

            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **masked**,
            - 0 for tokens that are **not masked**.

            You can only use one of `input_mask` and `attention_mask`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Per-token labels for computing the token classification loss (Cross-Entropy). Indices should be in
            `[0, ..., config.num_labels - 1]`.
        use_mems (`bool`, *optional*):
            Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
            states from previous forward passes to compute attention, which can significantly improve performance for
            sequential decoding tasks.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            inputs_embeds=inputs_embeds,
            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            # Forward extra kwargs (e.g. the deprecated `use_cache`) like every other XLNet head does;
            # previously they were accepted here and silently dropped.
            **kwargs,
        )

        sequence_output = outputs[0]

        # One score vector per token: (..., sequence_length, num_labels).
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Flatten tokens so every position is an independent classification example.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetForTokenClassificationOutput(
            loss=loss,
            logits=logits,
            mems=outputs.mems,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
class XLNetForMultipleChoice(XLNetPreTrainedModel):
    """XLNet with a multiple-choice head: one scalar score per (example, choice), softmaxed over choices."""

    def __init__(self, config):
        super().__init__(config)

        self.transformer = XLNetModel(config)
        self.sequence_summary = XLNetSequenceSummary(config)
        # Projects each choice's summary vector to a single score.
        self.logits_proj = nn.Linear(config.d_model, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        input_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        mems: torch.Tensor | None = None,
        perm_mask: torch.Tensor | None = None,
        target_mapping: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        use_mems: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ) -> tuple | XLNetForMultipleChoiceOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        input_mask (`torch.FloatTensor` of shape `batch_size, num_choices, sequence_length`, *optional*):
            Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
            real tokens and 1 for padding which is kept for compatibility with the original code base.

            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **masked**,
            - 0 for tokens that are **not masked**.

            You can only use one of `input_mask` and `attention_mask`.
        mems (`list[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
            decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
            they have already been computed.

            `use_mems` has to be set to `True` to make use of `mems`.
        perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
            Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:

            - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
            - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.

            If not set, each token attends to all the others (full bidirectional attention). Only used during
            pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
            Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
            on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
            (generation).
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices - 1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        use_mems (`bool`, *optional*):
            Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
            states from previous forward passes to compute attention, which can significantly improve performance for
            sequential decoding tasks.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if input_ids is None and inputs_embeds is None:
            # Fail with a clear message instead of an AttributeError on `None.shape` below.
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice dimension into the batch dimension so the transformer sees ordinary sequences.
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # NOTE(review): `mems`, `perm_mask` and `target_mapping` are forwarded as-is (not flattened like the inputs
        # above); presumably callers must already supply them for the flattened batch — confirm against XLNetModel.
        transformer_outputs = self.transformer(
            flat_input_ids,
            token_type_ids=flat_token_type_ids,
            input_mask=flat_input_mask,
            attention_mask=flat_attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            inputs_embeds=flat_inputs_embeds,
            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        output = transformer_outputs[0]

        output = self.sequence_summary(output)
        logits = self.logits_proj(output)
        # Unfold back to one row of `num_choices` scores per original example.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels.view(-1))

        if not return_dict:
            output = (reshaped_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetForMultipleChoiceOutput(
            loss=loss,
            logits=reshaped_logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """
)
class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        mems: torch.Tensor | None = None,
        perm_mask: torch.Tensor | None = None,
        target_mapping: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        input_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        use_mems: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ) -> tuple | XLNetForQuestionAnsweringSimpleOutput:
        r"""
        mems (`list[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
            decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
            they have already been computed.

            `use_mems` has to be set to `True` to make use of `mems`.
        perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
            Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:

            - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
            - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.

            If not set, each token attends to all the others (full bidirectional attention). Only used during
            pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
            Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
            on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
            (generation).
        input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
            Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
            real tokens and 1 for padding which is kept for compatibility with the original code base.

            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **masked**,
            - 0 for tokens that are **not masked**.

            You can only use one of `input_mask` and `attention_mask`.
        use_mems (`bool`, *optional*):
            Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
            states from previous forward passes to compute attention, which can significantly improve performance for
            sequential decoding tasks.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            inputs_embeds=inputs_embeds,
            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # `split(1)` yields `num_labels` chunks; unpacking into two assumes `config.num_labels == 2`.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the batch split may add a trailing dimension; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs: clamp them onto `ignored_index`
            # (one past the last position) and tell the loss to ignore that index.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return XLNetForQuestionAnsweringSimpleOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            mems=outputs.mems,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
    """XLNet with the beam-search span head (start/end poolers plus an answerability classifier)."""

    def __init__(self, config):
        super().__init__(config)
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.transformer = XLNetModel(config)
        self.start_logits = XLNetPoolerStartLogits(config)
        self.end_logits = XLNetPoolerEndLogits(config)
        self.answer_class = XLNetPoolerAnswerClass(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        mems: torch.Tensor | None = None,
        perm_mask: torch.Tensor | None = None,
        target_mapping: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        input_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        is_impossible: torch.Tensor | None = None,
        cls_index: torch.Tensor | None = None,
        p_mask: torch.Tensor | None = None,
        use_mems: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,  # delete when `use_cache` is removed in XLNetModel
    ) -> tuple | XLNetForQuestionAnsweringOutput:
        r"""
        mems (`list[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
            decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
            they have already been computed.

            `use_mems` has to be set to `True` to make use of `mems`.
        perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
            Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:

            - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
            - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.

            If not set, each token attends to all the others (full bidirectional attention). Only used during
            pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
            Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
            on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
            (generation).
        input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
            Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
            real tokens and 1 for padding which is kept for compatibility with the original code base.

            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **masked**,
            - 0 for tokens that are **not masked**.

            You can only use one of `input_mask` and `attention_mask`.
        is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels whether a question has an answer or no answer (SQuAD 2.0). Cast to the logits' floating dtype
            before computing the binary cross-entropy loss.
        cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the classification token to use as input for computing plausibility of the
            answer.
        p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means token should be
            masked. 0.0 means token is not masked.
        use_mems (`bool`, *optional*):
            Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
            states from previous forward passes to compute attention, which can significantly improve performance for
            sequential decoding tasks.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, XLNetForQuestionAnswering
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("xlnet/xlnet-base-cased")
        >>> model = XLNetForQuestionAnswering.from_pretrained("xlnet/xlnet-base-cased")

        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
        ...     0
        ... )  # Batch size 1
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])
        >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)

        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            inputs_embeds=inputs_embeds,
            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = transformer_outputs[0]
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, drop the trailing dimension added by batch splitting. Rebind the locals with
            # out-of-place `squeeze` so the caller's label tensors are not modified in place.
            start_positions, end_positions, cls_index, is_impossible = (
                x.squeeze(-1) if x is not None and x.dim() > 1 else x
                for x in (start_positions, end_positions, cls_index, is_impossible)
            )

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                # BCEWithLogitsLoss requires a floating-point target; `is_impossible` is documented as a LongTensor.
                cls_loss = loss_fct_cls(cls_logits, is_impossible.to(cls_logits.dtype))

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            if not return_dict:
                return (total_loss,) + transformer_outputs[1:]
            return XLNetForQuestionAnsweringOutput(
                loss=total_loss,
                mems=transformer_outputs.mems,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        # during inference, compute the end logits based on beam search
        _, slen, hsz = hidden_states.size()
        # NOTE: despite the `*_log_probs` names (kept because they are public output fields), these are softmax
        # probabilities, not log-probabilities.
        start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)

        start_top_log_probs, start_top_index = torch.topk(
            start_log_probs, self.start_n_top, dim=-1
        )  # shape (bsz, start_n_top)
        start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
        start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
        start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

        hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
            start_states
        )  # shape (bsz, slen, start_n_top, hsz)
        p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
        end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
        end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

        end_top_log_probs, end_top_index = torch.topk(
            end_log_probs, self.end_n_top, dim=1
        )  # shape (bsz, end_n_top, start_n_top)
        end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
        end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

        start_states = torch.einsum(
            "blh,bl->bh", hidden_states, start_log_probs
        )  # get the representation of START as weighted sum of hidden states
        cls_logits = self.answer_class(
            hidden_states, start_states=start_states, cls_index=cls_index
        )  # Shape (batch size,): one single `cls_logits` for each sample

        if not return_dict:
            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
            return outputs + transformer_outputs[1:]
        return XLNetForQuestionAnsweringOutput(
            start_top_log_probs=start_top_log_probs,
            start_top_index=start_top_index,
            end_top_log_probs=end_top_log_probs,
            end_top_index=end_top_index,
            cls_logits=cls_logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
# Public names exported by `from <this module> import *` (listed alphabetically).
__all__ = [
    "XLNetForMultipleChoice",
    "XLNetForQuestionAnswering",
    "XLNetForQuestionAnsweringSimple",
    "XLNetForSequenceClassification",
    "XLNetForTokenClassification",
    "XLNetLMHeadModel",
    "XLNetModel",
    "XLNetPreTrainedModel",
]
|