modeling_xlm.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601
  1. # Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. PyTorch XLM model.
  16. """
  17. import math
  18. from collections.abc import Callable
  19. from dataclasses import dataclass
  20. import numpy as np
  21. import torch
  22. from torch import nn
  23. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  24. from ... import initialization as init
  25. from ...activations import gelu, get_activation
  26. from ...cache_utils import DynamicCache, EncoderDecoderCache
  27. from ...generation import GenerationMixin
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. MaskedLMOutput,
  31. MultipleChoiceModelOutput,
  32. QuestionAnsweringModelOutput,
  33. SequenceClassifierOutput,
  34. TokenClassifierOutput,
  35. )
  36. from ...modeling_utils import PreTrainedModel
  37. from ...pytorch_utils import apply_chunking_to_forward
  38. from ...utils import ModelOutput, auto_docstring, logging
  39. from .configuration_xlm import XLMConfig
# Module-level logger for this file, obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
  41. def create_sinusoidal_embeddings(n_pos, dim, out):
  42. position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
  43. out.requires_grad = False
  44. out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  45. out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
  46. out.detach_()
  47. return out
  48. def get_masks(slen, lengths, causal, padding_mask=None):
  49. """
  50. Generate hidden states mask, and optionally an attention mask.
  51. """
  52. alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
  53. if padding_mask is not None:
  54. mask = padding_mask
  55. else:
  56. assert lengths.max().item() <= slen
  57. mask = alen < lengths[:, None]
  58. # attention mask is the same as mask, or triangular inferior attention (causal)
  59. bs = lengths.size(0)
  60. if causal:
  61. attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
  62. else:
  63. attn_mask = mask
  64. # sanity check
  65. assert mask.size() == (bs, slen)
  66. assert causal is False or attn_mask.size() == (bs, slen, slen)
  67. return mask, attn_mask
  68. @dataclass
  69. @auto_docstring(
  70. custom_intro="""
  71. Base class for outputs of question answering models using a [`~modeling_utils.XLMSQuADHead`].
  72. """
  73. )
  74. class XLMSquadHeadOutput(ModelOutput):
  75. r"""
  76. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
  77. Classification loss as the sum of start token, end token (and is_impossible if provided) classification
  78. losses.
  79. start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  80. Log probabilities for the top config.start_n_top start token possibilities (beam-search).
  81. start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  82. Indices for the top config.start_n_top start token possibilities (beam-search).
  83. end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  84. Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
  85. (beam-search).
  86. end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  87. Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
  88. cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  89. Log probabilities for the `is_impossible` label of the answers.
  90. """
  91. loss: torch.FloatTensor | None = None
  92. start_top_log_probs: torch.FloatTensor | None = None
  93. start_top_index: torch.LongTensor | None = None
  94. end_top_log_probs: torch.FloatTensor | None = None
  95. end_top_index: torch.LongTensor | None = None
  96. cls_logits: torch.FloatTensor | None = None
  97. class XLMPoolerStartLogits(nn.Module):
  98. """
  99. Compute SQuAD start logits from sequence hidden states.
  100. Args:
  101. config ([`XLMConfig`]):
  102. The config used by the model, will be used to grab the `hidden_size` of the model.
  103. """
  104. def __init__(self, config: XLMConfig):
  105. super().__init__()
  106. self.dense = nn.Linear(config.hidden_size, 1)
  107. def forward(self, hidden_states: torch.FloatTensor, p_mask: torch.FloatTensor | None = None) -> torch.FloatTensor:
  108. """
  109. Args:
  110. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  111. The final hidden states of the model.
  112. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  113. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  114. should be masked.
  115. Returns:
  116. `torch.FloatTensor`: The start logits for SQuAD.
  117. """
  118. x = self.dense(hidden_states).squeeze(-1)
  119. if p_mask is not None:
  120. if p_mask.dtype == torch.float16:
  121. x = x * (1 - p_mask) - 65500 * p_mask
  122. else:
  123. x = x * (1 - p_mask) - 1e30 * p_mask
  124. return x
  125. class XLMPoolerEndLogits(nn.Module):
  126. """
  127. Compute SQuAD end logits from sequence hidden states.
  128. Args:
  129. config ([`XLMConfig`]):
  130. The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
  131. to use.
  132. """
  133. def __init__(self, config: XLMConfig):
  134. super().__init__()
  135. self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
  136. self.activation = nn.Tanh()
  137. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  138. self.dense_1 = nn.Linear(config.hidden_size, 1)
  139. def forward(
  140. self,
  141. hidden_states: torch.FloatTensor,
  142. start_states: torch.FloatTensor | None = None,
  143. start_positions: torch.LongTensor | None = None,
  144. p_mask: torch.FloatTensor | None = None,
  145. ) -> torch.FloatTensor:
  146. """
  147. Args:
  148. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  149. The final hidden states of the model.
  150. start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  151. The hidden states of the first tokens for the labeled span.
  152. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  153. The position of the first token for the labeled span.
  154. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  155. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  156. should be masked.
  157. <Tip>
  158. One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
  159. `start_states`.
  160. </Tip>
  161. Returns:
  162. `torch.FloatTensor`: The end logits for SQuAD.
  163. """
  164. assert start_states is not None or start_positions is not None, (
  165. "One of start_states, start_positions should be not None"
  166. )
  167. if start_positions is not None:
  168. slen, hsz = hidden_states.shape[-2:]
  169. start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  170. start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
  171. start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
  172. x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
  173. x = self.activation(x)
  174. x = self.LayerNorm(x)
  175. x = self.dense_1(x).squeeze(-1)
  176. if p_mask is not None:
  177. if p_mask.dtype == torch.float16:
  178. x = x * (1 - p_mask) - 65500 * p_mask
  179. else:
  180. x = x * (1 - p_mask) - 1e30 * p_mask
  181. return x
  182. class XLMPoolerAnswerClass(nn.Module):
  183. """
  184. Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
  185. Args:
  186. config ([`XLMConfig`]):
  187. The config used by the model, will be used to grab the `hidden_size` of the model.
  188. """
  189. def __init__(self, config: XLMConfig):
  190. super().__init__()
  191. self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
  192. self.activation = nn.Tanh()
  193. self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
  194. def forward(
  195. self,
  196. hidden_states: torch.FloatTensor,
  197. start_states: torch.FloatTensor | None = None,
  198. start_positions: torch.LongTensor | None = None,
  199. cls_index: torch.LongTensor | None = None,
  200. ) -> torch.FloatTensor:
  201. """
  202. Args:
  203. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  204. The final hidden states of the model.
  205. start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  206. The hidden states of the first tokens for the labeled span.
  207. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  208. The position of the first token for the labeled span.
  209. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  210. Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
  211. <Tip>
  212. One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
  213. `start_states`.
  214. </Tip>
  215. Returns:
  216. `torch.FloatTensor`: The SQuAD 2.0 answer class.
  217. """
  218. # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
  219. hsz = hidden_states.shape[-1]
  220. assert start_states is not None or start_positions is not None, (
  221. "One of start_states, start_positions should be not None"
  222. )
  223. if start_positions is not None:
  224. start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  225. start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
  226. if cls_index is not None:
  227. cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  228. cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
  229. else:
  230. cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
  231. x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
  232. x = self.activation(x)
  233. x = self.dense_1(x).squeeze(-1)
  234. return x
  235. class XLMSQuADHead(nn.Module):
  236. r"""
  237. A SQuAD head inspired by XLNet.
  238. Args:
  239. config ([`XLMConfig`]):
  240. The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
  241. to use.
  242. """
  243. def __init__(self, config: XLMConfig):
  244. super().__init__()
  245. self.start_n_top = config.start_n_top
  246. self.end_n_top = config.end_n_top
  247. self.start_logits = XLMPoolerStartLogits(config)
  248. self.end_logits = XLMPoolerEndLogits(config)
  249. self.answer_class = XLMPoolerAnswerClass(config)
  250. @auto_docstring
  251. def forward(
  252. self,
  253. hidden_states: torch.FloatTensor,
  254. start_positions: torch.LongTensor | None = None,
  255. end_positions: torch.LongTensor | None = None,
  256. cls_index: torch.LongTensor | None = None,
  257. is_impossible: torch.LongTensor | None = None,
  258. p_mask: torch.FloatTensor | None = None,
  259. return_dict: bool = False,
  260. ) -> XLMSquadHeadOutput | tuple[torch.FloatTensor]:
  261. r"""
  262. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  263. Final hidden states of the model on the sequence tokens.
  264. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  265. Positions of the first token for the labeled span.
  266. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  267. Positions of the last token for the labeled span.
  268. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  269. Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
  270. is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  271. Whether the question has a possible answer in the paragraph or not.
  272. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  273. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  274. should be masked.
  275. """
  276. start_logits = self.start_logits(hidden_states, p_mask=p_mask)
  277. if start_positions is not None and end_positions is not None:
  278. # If we are on multi-GPU, let's remove the dimension added by batch splitting
  279. for x in (start_positions, end_positions, cls_index, is_impossible):
  280. if x is not None and x.dim() > 1:
  281. x.squeeze_(-1)
  282. # during training, compute the end logits based on the ground truth of the start position
  283. end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
  284. loss_fct = CrossEntropyLoss()
  285. start_loss = loss_fct(start_logits, start_positions)
  286. end_loss = loss_fct(end_logits, end_positions)
  287. total_loss = (start_loss + end_loss) / 2
  288. if cls_index is not None and is_impossible is not None:
  289. # Predict answerability from the representation of CLS and START
  290. cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
  291. loss_fct_cls = nn.BCEWithLogitsLoss()
  292. cls_loss = loss_fct_cls(cls_logits, is_impossible)
  293. # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
  294. total_loss += cls_loss * 0.5
  295. return XLMSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
  296. else:
  297. # during inference, compute the end logits based on beam search
  298. bsz, slen, hsz = hidden_states.size()
  299. start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen)
  300. start_top_log_probs, start_top_index = torch.topk(
  301. start_log_probs, self.start_n_top, dim=-1
  302. ) # shape (bsz, start_n_top)
  303. start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
  304. start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
  305. start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
  306. hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
  307. start_states
  308. ) # shape (bsz, slen, start_n_top, hsz)
  309. p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
  310. end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
  311. end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
  312. end_top_log_probs, end_top_index = torch.topk(
  313. end_log_probs, self.end_n_top, dim=1
  314. ) # shape (bsz, end_n_top, start_n_top)
  315. end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
  316. end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
  317. start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
  318. cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
  319. if not return_dict:
  320. return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
  321. else:
  322. return XLMSquadHeadOutput(
  323. start_top_log_probs=start_top_log_probs,
  324. start_top_index=start_top_index,
  325. end_top_log_probs=end_top_log_probs,
  326. end_top_index=end_top_index,
  327. cls_logits=cls_logits,
  328. )
  329. class XLMSequenceSummary(nn.Module):
  330. r"""
  331. Compute a single vector summary of a sequence hidden states.
  332. Args:
  333. config ([`XLMConfig`]):
  334. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  335. config class of your model for the default values it uses):
  336. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  337. - `"last"` -- Take the last token hidden state (like XLNet)
  338. - `"first"` -- Take the first token hidden state (like Bert)
  339. - `"mean"` -- Take the mean of all tokens hidden states
  340. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  341. - `"attn"` -- Not implemented now, use multi-head attention
  342. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  343. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  344. (otherwise to `config.hidden_size`).
  345. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  346. another string or `None` will add no activation.
  347. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  348. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  349. """
  350. def __init__(self, config: XLMConfig):
  351. super().__init__()
  352. self.summary_type = getattr(config, "summary_type", "last")
  353. if self.summary_type == "attn":
  354. # We should use a standard multi-head attention module with absolute positional embedding for that.
  355. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  356. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  357. raise NotImplementedError
  358. self.summary = nn.Identity()
  359. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  360. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  361. num_classes = config.num_labels
  362. else:
  363. num_classes = config.hidden_size
  364. self.summary = nn.Linear(config.hidden_size, num_classes)
  365. activation_string = getattr(config, "summary_activation", None)
  366. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  367. self.first_dropout = nn.Identity()
  368. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  369. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  370. self.last_dropout = nn.Identity()
  371. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  372. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  373. def forward(
  374. self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
  375. ) -> torch.FloatTensor:
  376. """
  377. Compute a single vector summary of a sequence hidden states.
  378. Args:
  379. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  380. The hidden states of the last layer.
  381. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  382. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  383. Returns:
  384. `torch.FloatTensor`: The summary of the sequence hidden states.
  385. """
  386. if self.summary_type == "last":
  387. output = hidden_states[:, -1]
  388. elif self.summary_type == "first":
  389. output = hidden_states[:, 0]
  390. elif self.summary_type == "mean":
  391. output = hidden_states.mean(dim=1)
  392. elif self.summary_type == "cls_index":
  393. if cls_index is None:
  394. cls_index = torch.full_like(
  395. hidden_states[..., :1, :],
  396. hidden_states.shape[-2] - 1,
  397. dtype=torch.long,
  398. )
  399. else:
  400. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  401. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  402. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  403. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  404. elif self.summary_type == "attn":
  405. raise NotImplementedError
  406. output = self.first_dropout(output)
  407. output = self.summary(output)
  408. output = self.activation(output)
  409. output = self.last_dropout(output)
  410. return output
  411. class MultiHeadAttention(nn.Module):
  412. def __init__(self, n_heads, dim, config, layer_idx: int = 0):
  413. super().__init__()
  414. self.layer_id = layer_idx
  415. self.dim = dim
  416. self.n_heads = n_heads
  417. self.head_dim = dim // n_heads
  418. self.dropout = config.attention_dropout
  419. assert self.dim % self.n_heads == 0
  420. self.q_lin = nn.Linear(dim, dim)
  421. self.k_lin = nn.Linear(dim, dim)
  422. self.v_lin = nn.Linear(dim, dim)
  423. self.out_lin = nn.Linear(dim, dim)
  424. def forward(
  425. self,
  426. input,
  427. mask,
  428. kv=None,
  429. cache=None,
  430. output_attentions=False,
  431. **kwargs,
  432. ):
  433. """
  434. Self-attention (if kv is None) or attention over source sentence (provided by kv).
  435. """
  436. # Input is (bs, qlen, dim)
  437. # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
  438. bs, qlen, dim = input.size()
  439. is_cross_attention = kv is not None
  440. mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1)
  441. q = self.q_lin(input).view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
  442. if cache is not None:
  443. if isinstance(cache, EncoderDecoderCache):
  444. is_updated = cache.is_updated.get(self.layer_id)
  445. if is_cross_attention:
  446. # after the first generated id, we can subsequently re-use all key/value_states from cache
  447. curr_past_key_values = cache.cross_attention_cache
  448. else:
  449. curr_past_key_values = cache.self_attention_cache
  450. else:
  451. curr_past_key_values = cache
  452. current_states = kv if is_cross_attention else input
  453. if is_cross_attention and cache is not None and is_updated:
  454. # reuse k,v, cross_attentions
  455. k = curr_past_key_values.key_cache[self.layer_id]
  456. v = curr_past_key_values.value_cache[self.layer_id]
  457. else:
  458. k = self.k_lin(current_states)
  459. v = self.v_lin(current_states)
  460. k = k.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
  461. v = v.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
  462. if cache is not None:
  463. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  464. k, v = curr_past_key_values.update(k, v, self.layer_id)
  465. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  466. if is_cross_attention:
  467. cache.is_updated[self.layer_id] = True
  468. q = q / math.sqrt(self.head_dim) # (bs, n_heads, qlen, head_dim)
  469. scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
  470. mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
  471. scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
  472. weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
  473. weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
  474. context = torch.matmul(weights, v) # (bs, n_heads, qlen, head_dim)
  475. context = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.head_dim)
  476. outputs = (self.out_lin(context),)
  477. if output_attentions:
  478. outputs = outputs + (weights,)
  479. return outputs
  480. class TransformerFFN(nn.Module):
  481. def __init__(self, in_dim, dim_hidden, out_dim, config):
  482. super().__init__()
  483. self.dropout = config.dropout
  484. self.lin1 = nn.Linear(in_dim, dim_hidden)
  485. self.lin2 = nn.Linear(dim_hidden, out_dim)
  486. self.act = gelu if config.gelu_activation else nn.functional.relu
  487. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  488. self.seq_len_dim = 1
  489. def forward(self, input):
  490. return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
  491. def ff_chunk(self, input):
  492. x = self.lin1(input)
  493. x = self.act(x)
  494. x = self.lin2(x)
  495. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  496. return x
@auto_docstring
class XLMPreTrainedModel(PreTrainedModel):
    # Base class for the XLM models: declares the config class, the attribute prefix of the base model, and the
    # dummy-input / weight-initialization hooks below.
    config: XLMConfig
    base_model_prefix = "transformer"

    @property
    def dummy_inputs(self):
        """
        Small fixed batch of 3 sequences of 5 tokens.

        Returns:
            `dict`: `input_ids` and `attention_mask` tensors, plus `langs` (same 0/1 pattern as the attention mask)
            when `config.use_lang_emb` is set and `config.n_langs > 1`; otherwise `langs` is `None`.
        """
        inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        if self.config.use_lang_emb and self.config.n_langs > 1:
            langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        else:
            langs_list = None
        return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights of `module`.

        - `nn.Embedding`: weights drawn from normal(0, `config.embed_init_std`) when that std is set. The
          `padding_idx` row is zeroed unless the weight is already flagged `_is_hf_initialized`.
        - `nn.Linear`: weights drawn from normal(0, `config.init_std`) and bias set to 0 when that std is set.
        - `nn.LayerNorm`: bias set to 0, weight set to 1.
        - `XLMModel`: when `config.sinusoidal_embeddings` is set, position embeddings are overwritten with fixed
          sinusoidal encodings. `position_ids` is always reset to `0..n-1` with shape `(1, n)`.
        """
        if isinstance(module, nn.Embedding):
            if self.config is not None and self.config.embed_init_std is not None:
                init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        if isinstance(module, nn.Linear):
            if self.config is not None and self.config.init_std is not None:
                init.normal_(module.weight, mean=0, std=self.config.init_std)
                if module.bias is not None:
                    init.constant_(module.bias, 0.0)
        if isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        if isinstance(module, XLMModel):
            if self.config.sinusoidal_embeddings:
                init.copy_(
                    module.position_embeddings.weight,
                    create_sinusoidal_embeddings(
                        self.config.max_position_embeddings,
                        self.config.emb_dim,
                        out=torch.empty_like(module.position_embeddings.weight),
                    ),
                )
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of question answering models using a `XLMSQuADHead`.
    """
)
class XLMForQuestionAnsweringOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
        Classification loss as the sum of start token, end token (and is_impossible if provided) classification
        losses.
    start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the top config.start_n_top start token possibilities (beam-search).
    start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Indices for the top config.start_n_top start token possibilities (beam-search).
    end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
        (beam-search).
    end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
    cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
        Log probabilities for the `is_impossible` label of the answers.
    """

    # Field order defines the positional/tuple order of this ModelOutput; do not reorder.
    loss: torch.FloatTensor | None = None
    start_top_log_probs: torch.FloatTensor | None = None
    start_top_index: torch.LongTensor | None = None
    end_top_log_probs: torch.FloatTensor | None = None
    end_top_index: torch.LongTensor | None = None
    cls_logits: torch.FloatTensor | None = None
    # Copied from the base transformer's output by `XLMForQuestionAnswering` (per-layer hidden states, if requested).
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Copied from the base transformer's output by `XLMForQuestionAnswering` (per-layer attentions, if requested).
    attentions: tuple[torch.FloatTensor, ...] | None = None
  569. @auto_docstring
  570. class XLMModel(XLMPreTrainedModel):
  571. def __init__(self, config):
  572. super().__init__(config)
  573. # encoder / decoder, output layer
  574. self.is_encoder = config.is_encoder
  575. self.is_decoder = not config.is_encoder
  576. if self.is_decoder:
  577. raise NotImplementedError("Currently XLM can only be used as an encoder")
  578. # self.with_output = with_output
  579. self.causal = config.causal
  580. # dictionary / languages
  581. self.n_langs = config.n_langs
  582. self.use_lang_emb = config.use_lang_emb
  583. self.n_words = config.n_words
  584. self.eos_index = config.eos_index
  585. self.pad_index = config.pad_index
  586. # self.dico = dico
  587. # self.id2lang = config.id2lang
  588. # self.lang2id = config.lang2id
  589. # assert len(self.dico) == self.n_words
  590. # assert len(self.id2lang) == len(self.lang2id) == self.n_langs
  591. # model parameters
  592. self.dim = config.emb_dim # 512 by default
  593. self.hidden_dim = self.dim * 4 # 2048 by default
  594. self.n_heads = config.n_heads # 8 by default
  595. self.n_layers = config.n_layers
  596. self.dropout = config.dropout
  597. self.attention_dropout = config.attention_dropout
  598. assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"
  599. # embeddings
  600. self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
  601. if config.n_langs > 1 and config.use_lang_emb:
  602. self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
  603. self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
  604. self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
  605. # transformer layers
  606. self.attentions = nn.ModuleList()
  607. self.layer_norm1 = nn.ModuleList()
  608. self.ffns = nn.ModuleList()
  609. self.layer_norm2 = nn.ModuleList()
  610. # if self.is_decoder:
  611. # self.layer_norm15 = nn.ModuleList()
  612. # self.encoder_attn = nn.ModuleList()
  613. for i in range(self.n_layers):
  614. self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config, layer_idx=i))
  615. self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  616. # if self.is_decoder:
  617. # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  618. # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
  619. self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
  620. self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  621. # Initialize weights and apply final processing
  622. self.register_buffer(
  623. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  624. )
  625. self.post_init()
  626. def get_input_embeddings(self):
  627. return self.embeddings
  628. def set_input_embeddings(self, new_embeddings):
  629. self.embeddings = new_embeddings
  630. @auto_docstring
  631. def forward(
  632. self,
  633. input_ids: torch.Tensor | None = None,
  634. attention_mask: torch.Tensor | None = None,
  635. langs: torch.Tensor | None = None,
  636. token_type_ids: torch.Tensor | None = None,
  637. position_ids: torch.Tensor | None = None,
  638. lengths: torch.Tensor | None = None,
  639. cache: dict[str, torch.Tensor] | None = None,
  640. inputs_embeds: torch.Tensor | None = None,
  641. output_attentions: bool | None = None,
  642. output_hidden_states: bool | None = None,
  643. return_dict: bool | None = None,
  644. **kwargs, # Dummy kwargs for now
  645. ) -> tuple | BaseModelOutput:
  646. r"""
  647. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  648. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  649. languages ids which can be obtained from the language names by using two conversion mappings provided in
  650. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  651. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  652. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  653. See usage examples detailed in the [multilingual documentation](../multilingual).
  654. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  655. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  656. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  657. `[0, ..., input_ids.size(-1)]`.
  658. cache (`dict[str, torch.FloatTensor]`, *optional*):
  659. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  660. decoding.
  661. """
  662. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  663. output_hidden_states = (
  664. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  665. )
  666. return_dict = return_dict if return_dict is not None else self.config.return_dict
  667. if input_ids is not None:
  668. bs, slen = input_ids.size()
  669. else:
  670. bs, slen = inputs_embeds.size()[:-1]
  671. device = input_ids.device if input_ids is not None else inputs_embeds.device
  672. if cache is None:
  673. cache = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  674. if lengths is None:
  675. if input_ids is not None:
  676. lengths = (input_ids != self.pad_index).sum(dim=1).long()
  677. else:
  678. lengths = torch.full((bs,), slen, device=device, dtype=torch.long)
  679. # check inputs
  680. assert lengths.size(0) == bs
  681. assert lengths.max().item() <= slen
  682. # generate masks
  683. mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
  684. # position_ids
  685. if position_ids is None:
  686. position_ids = self.position_ids[:, :slen]
  687. else:
  688. assert position_ids.size() == (bs, slen) # (slen, bs)
  689. # langs
  690. if langs is not None:
  691. assert langs.size() == (bs, slen) # (slen, bs)
  692. # do not recompute cached elements
  693. if cache is not None and input_ids is not None:
  694. _slen = slen - cache.get_seq_length()
  695. input_ids = input_ids[:, -_slen:]
  696. position_ids = position_ids[:, -_slen:]
  697. if langs is not None:
  698. langs = langs[:, -_slen:]
  699. mask = mask[:, -_slen:]
  700. attn_mask = attn_mask[:, -_slen:]
  701. # embeddings
  702. if inputs_embeds is None:
  703. inputs_embeds = self.embeddings(input_ids)
  704. tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
  705. if langs is not None and self.use_lang_emb and self.n_langs > 1:
  706. tensor = tensor + self.lang_embeddings(langs)
  707. if token_type_ids is not None:
  708. tensor = tensor + self.embeddings(token_type_ids)
  709. tensor = self.layer_norm_emb(tensor)
  710. tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)
  711. tensor *= mask.unsqueeze(-1).to(tensor.dtype)
  712. # transformer layers
  713. hidden_states = () if output_hidden_states else None
  714. attentions = () if output_attentions else None
  715. for i in range(self.n_layers):
  716. if output_hidden_states:
  717. hidden_states = hidden_states + (tensor,)
  718. # self attention
  719. attn_outputs = self.attentions[i](
  720. tensor,
  721. attn_mask,
  722. cache=cache,
  723. output_attentions=output_attentions,
  724. )
  725. attn = attn_outputs[0]
  726. if output_attentions:
  727. attentions = attentions + (attn_outputs[1],)
  728. attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
  729. tensor = tensor + attn
  730. tensor = self.layer_norm1[i](tensor)
  731. # FFN
  732. tensor = tensor + self.ffns[i](tensor)
  733. tensor = self.layer_norm2[i](tensor)
  734. tensor *= mask.unsqueeze(-1).to(tensor.dtype)
  735. # Add last hidden state
  736. if output_hidden_states:
  737. hidden_states = hidden_states + (tensor,)
  738. if not return_dict:
  739. return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
  740. return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
  741. class XLMPredLayer(nn.Module):
  742. """
  743. Prediction layer (cross_entropy or adaptive_softmax).
  744. """
  745. def __init__(self, config):
  746. super().__init__()
  747. self.asm = config.asm
  748. self.n_words = config.n_words
  749. self.pad_index = config.pad_index
  750. dim = config.emb_dim
  751. if config.asm is False:
  752. self.proj = nn.Linear(dim, config.n_words, bias=True)
  753. else:
  754. self.proj = nn.AdaptiveLogSoftmaxWithLoss(
  755. in_features=dim,
  756. n_classes=config.n_words,
  757. cutoffs=config.asm_cutoffs,
  758. div_value=config.asm_div_value,
  759. head_bias=True, # default is False
  760. )
  761. def forward(self, x, y=None):
  762. """Compute the loss, and optionally the scores."""
  763. outputs = ()
  764. if self.asm is False:
  765. scores = self.proj(x)
  766. outputs = (scores,) + outputs
  767. if y is not None:
  768. loss = nn.functional.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
  769. outputs = (loss,) + outputs
  770. else:
  771. scores = self.proj.log_prob(x)
  772. outputs = (scores,) + outputs
  773. if y is not None:
  774. _, loss = self.proj(x, y)
  775. outputs = (loss,) + outputs
  776. return outputs
  777. @auto_docstring(
  778. custom_intro="""
  779. The XLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
  780. embeddings).
  781. """
  782. )
  783. class XLMWithLMHeadModel(XLMPreTrainedModel, GenerationMixin):
  784. _tied_weights_keys = {"pred_layer.proj.weight": "transformer.embeddings.weight"}
  785. def __init__(self, config):
  786. super().__init__(config)
  787. self.transformer = XLMModel(config)
  788. self.pred_layer = XLMPredLayer(config)
  789. # Initialize weights and apply final processing
  790. self.post_init()
  791. def get_output_embeddings(self):
  792. return self.pred_layer.proj
  793. def set_output_embeddings(self, new_embeddings):
  794. self.pred_layer.proj = new_embeddings
  795. def prepare_inputs_for_generation(self, input_ids, is_first_iteration=False, **kwargs):
  796. # Overwritten -- this model uses config options to prepare inputs
  797. mask_token_id = self.config.mask_token_id
  798. lang_id = self.config.lang_id
  799. effective_batch_size = input_ids.shape[0]
  800. mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
  801. input_ids = torch.cat([input_ids, mask_token], dim=1)
  802. if lang_id is not None:
  803. langs = torch.full_like(input_ids, lang_id)
  804. else:
  805. langs = None
  806. model_inputs = {"input_ids": input_ids, "langs": langs}
  807. # They are calculated on the fly on XLMModel.forward()
  808. kwargs.pop("token_type_ids", None)
  809. kwargs.pop("attention_mask", None)
  810. kwargs.pop("position_ids", None)
  811. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  812. for key, value in kwargs.items():
  813. if key not in model_inputs:
  814. model_inputs[key] = value
  815. return model_inputs
  816. @auto_docstring
  817. def forward(
  818. self,
  819. input_ids: torch.Tensor | None = None,
  820. attention_mask: torch.Tensor | None = None,
  821. langs: torch.Tensor | None = None,
  822. token_type_ids: torch.Tensor | None = None,
  823. position_ids: torch.Tensor | None = None,
  824. lengths: torch.Tensor | None = None,
  825. cache: dict[str, torch.Tensor] | None = None,
  826. inputs_embeds: torch.Tensor | None = None,
  827. labels: torch.Tensor | None = None,
  828. output_attentions: bool | None = None,
  829. output_hidden_states: bool | None = None,
  830. return_dict: bool | None = None,
  831. logits_to_keep: int | torch.Tensor = 0,
  832. **kwargs,
  833. ) -> tuple | MaskedLMOutput:
  834. r"""
  835. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  836. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  837. languages ids which can be obtained from the language names by using two conversion mappings provided in
  838. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  839. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  840. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  841. See usage examples detailed in the [multilingual documentation](../multilingual).
  842. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  843. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  844. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  845. `[0, ..., input_ids.size(-1)]`.
  846. cache (`dict[str, torch.FloatTensor]`, *optional*):
  847. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  848. decoding.
  849. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  850. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  851. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  852. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  853. """
  854. return_dict = return_dict if return_dict is not None else self.config.return_dict
  855. transformer_outputs = self.transformer(
  856. input_ids,
  857. attention_mask=attention_mask,
  858. langs=langs,
  859. token_type_ids=token_type_ids,
  860. position_ids=position_ids,
  861. lengths=lengths,
  862. cache=cache,
  863. inputs_embeds=inputs_embeds,
  864. output_attentions=output_attentions,
  865. output_hidden_states=output_hidden_states,
  866. return_dict=return_dict,
  867. **kwargs,
  868. )
  869. hidden_states = transformer_outputs[0]
  870. # Only compute necessary logits
  871. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  872. outputs = self.pred_layer(
  873. hidden_states[:, slice_indices, :],
  874. labels,
  875. ) # (loss, logits) or (logits,) depending on if labels are provided.
  876. if not return_dict:
  877. return outputs + transformer_outputs[1:]
  878. return MaskedLMOutput(
  879. loss=outputs[0] if labels is not None else None,
  880. logits=outputs[0] if labels is None else outputs[1],
  881. hidden_states=transformer_outputs.hidden_states,
  882. attentions=transformer_outputs.attentions,
  883. )
  884. @auto_docstring(
  885. custom_intro="""
  886. XLM Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g.
  887. for GLUE tasks.
  888. """
  889. )
  890. class XLMForSequenceClassification(XLMPreTrainedModel):
  891. def __init__(self, config):
  892. super().__init__(config)
  893. self.num_labels = config.num_labels
  894. self.config = config
  895. self.transformer = XLMModel(config)
  896. self.sequence_summary = XLMSequenceSummary(config)
  897. # Initialize weights and apply final processing
  898. self.post_init()
  899. @auto_docstring
  900. def forward(
  901. self,
  902. input_ids: torch.Tensor | None = None,
  903. attention_mask: torch.Tensor | None = None,
  904. langs: torch.Tensor | None = None,
  905. token_type_ids: torch.Tensor | None = None,
  906. position_ids: torch.Tensor | None = None,
  907. lengths: torch.Tensor | None = None,
  908. cache: dict[str, torch.Tensor] | None = None,
  909. inputs_embeds: torch.Tensor | None = None,
  910. labels: torch.Tensor | None = None,
  911. output_attentions: bool | None = None,
  912. output_hidden_states: bool | None = None,
  913. return_dict: bool | None = None,
  914. **kwargs,
  915. ) -> tuple | SequenceClassifierOutput:
  916. r"""
  917. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  918. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  919. languages ids which can be obtained from the language names by using two conversion mappings provided in
  920. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  921. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  922. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  923. See usage examples detailed in the [multilingual documentation](../multilingual).
  924. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  925. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  926. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  927. `[0, ..., input_ids.size(-1)]`.
  928. cache (`dict[str, torch.FloatTensor]`, *optional*):
  929. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  930. decoding.
  931. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  932. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  933. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  934. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  935. """
  936. return_dict = return_dict if return_dict is not None else self.config.return_dict
  937. transformer_outputs = self.transformer(
  938. input_ids,
  939. attention_mask=attention_mask,
  940. langs=langs,
  941. token_type_ids=token_type_ids,
  942. position_ids=position_ids,
  943. lengths=lengths,
  944. cache=cache,
  945. inputs_embeds=inputs_embeds,
  946. output_attentions=output_attentions,
  947. output_hidden_states=output_hidden_states,
  948. return_dict=return_dict,
  949. )
  950. output = transformer_outputs[0]
  951. logits = self.sequence_summary(output)
  952. loss = None
  953. if labels is not None:
  954. if self.config.problem_type is None:
  955. if self.num_labels == 1:
  956. self.config.problem_type = "regression"
  957. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  958. self.config.problem_type = "single_label_classification"
  959. else:
  960. self.config.problem_type = "multi_label_classification"
  961. if self.config.problem_type == "regression":
  962. loss_fct = MSELoss()
  963. if self.num_labels == 1:
  964. loss = loss_fct(logits.squeeze(), labels.squeeze())
  965. else:
  966. loss = loss_fct(logits, labels)
  967. elif self.config.problem_type == "single_label_classification":
  968. loss_fct = CrossEntropyLoss()
  969. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  970. elif self.config.problem_type == "multi_label_classification":
  971. loss_fct = BCEWithLogitsLoss()
  972. loss = loss_fct(logits, labels)
  973. if not return_dict:
  974. output = (logits,) + transformer_outputs[1:]
  975. return ((loss,) + output) if loss is not None else output
  976. return SequenceClassifierOutput(
  977. loss=loss,
  978. logits=logits,
  979. hidden_states=transformer_outputs.hidden_states,
  980. attentions=transformer_outputs.attentions,
  981. )
  982. @auto_docstring(
  983. custom_intro="""
  984. XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
  985. layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
  986. """
  987. )
  988. class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
  989. def __init__(self, config):
  990. super().__init__(config)
  991. self.transformer = XLMModel(config)
  992. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  993. # Initialize weights and apply final processing
  994. self.post_init()
  995. @auto_docstring
  996. def forward(
  997. self,
  998. input_ids: torch.Tensor | None = None,
  999. attention_mask: torch.Tensor | None = None,
  1000. langs: torch.Tensor | None = None,
  1001. token_type_ids: torch.Tensor | None = None,
  1002. position_ids: torch.Tensor | None = None,
  1003. lengths: torch.Tensor | None = None,
  1004. cache: dict[str, torch.Tensor] | None = None,
  1005. inputs_embeds: torch.Tensor | None = None,
  1006. start_positions: torch.Tensor | None = None,
  1007. end_positions: torch.Tensor | None = None,
  1008. output_attentions: bool | None = None,
  1009. output_hidden_states: bool | None = None,
  1010. return_dict: bool | None = None,
  1011. **kwargs,
  1012. ) -> tuple | QuestionAnsweringModelOutput:
  1013. r"""
  1014. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1015. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1016. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1017. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1018. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1019. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1020. See usage examples detailed in the [multilingual documentation](../multilingual).
  1021. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1022. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1023. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1024. `[0, ..., input_ids.size(-1)]`.
  1025. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1026. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1027. decoding.
  1028. """
  1029. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1030. transformer_outputs = self.transformer(
  1031. input_ids,
  1032. attention_mask=attention_mask,
  1033. langs=langs,
  1034. token_type_ids=token_type_ids,
  1035. position_ids=position_ids,
  1036. lengths=lengths,
  1037. cache=cache,
  1038. inputs_embeds=inputs_embeds,
  1039. output_attentions=output_attentions,
  1040. output_hidden_states=output_hidden_states,
  1041. return_dict=return_dict,
  1042. )
  1043. sequence_output = transformer_outputs[0]
  1044. logits = self.qa_outputs(sequence_output)
  1045. start_logits, end_logits = logits.split(1, dim=-1)
  1046. start_logits = start_logits.squeeze(-1).contiguous()
  1047. end_logits = end_logits.squeeze(-1).contiguous()
  1048. total_loss = None
  1049. if start_positions is not None and end_positions is not None:
  1050. # If we are on multi-GPU, split add a dimension
  1051. if len(start_positions.size()) > 1:
  1052. start_positions = start_positions.squeeze(-1)
  1053. if len(end_positions.size()) > 1:
  1054. end_positions = end_positions.squeeze(-1)
  1055. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1056. ignored_index = start_logits.size(1)
  1057. start_positions = start_positions.clamp(0, ignored_index)
  1058. end_positions = end_positions.clamp(0, ignored_index)
  1059. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1060. start_loss = loss_fct(start_logits, start_positions)
  1061. end_loss = loss_fct(end_logits, end_positions)
  1062. total_loss = (start_loss + end_loss) / 2
  1063. if not return_dict:
  1064. output = (start_logits, end_logits) + transformer_outputs[1:]
  1065. return ((total_loss,) + output) if total_loss is not None else output
  1066. return QuestionAnsweringModelOutput(
  1067. loss=total_loss,
  1068. start_logits=start_logits,
  1069. end_logits=end_logits,
  1070. hidden_states=transformer_outputs.hidden_states,
  1071. attentions=transformer_outputs.attentions,
  1072. )
  1073. @auto_docstring
  1074. class XLMForQuestionAnswering(XLMPreTrainedModel):
  1075. def __init__(self, config):
  1076. super().__init__(config)
  1077. self.transformer = XLMModel(config)
  1078. self.qa_outputs = XLMSQuADHead(config)
  1079. # Initialize weights and apply final processing
  1080. self.post_init()
  1081. @auto_docstring
  1082. def forward(
  1083. self,
  1084. input_ids: torch.Tensor | None = None,
  1085. attention_mask: torch.Tensor | None = None,
  1086. langs: torch.Tensor | None = None,
  1087. token_type_ids: torch.Tensor | None = None,
  1088. position_ids: torch.Tensor | None = None,
  1089. lengths: torch.Tensor | None = None,
  1090. cache: dict[str, torch.Tensor] | None = None,
  1091. inputs_embeds: torch.Tensor | None = None,
  1092. start_positions: torch.Tensor | None = None,
  1093. end_positions: torch.Tensor | None = None,
  1094. is_impossible: torch.Tensor | None = None,
  1095. cls_index: torch.Tensor | None = None,
  1096. p_mask: torch.Tensor | None = None,
  1097. output_attentions: bool | None = None,
  1098. output_hidden_states: bool | None = None,
  1099. return_dict: bool | None = None,
  1100. **kwargs,
  1101. ) -> tuple | XLMForQuestionAnsweringOutput:
  1102. r"""
  1103. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1104. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1105. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1106. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1107. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1108. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1109. See usage examples detailed in the [multilingual documentation](../multilingual).
  1110. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1111. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1112. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1113. `[0, ..., input_ids.size(-1)]`.
  1114. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1115. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1116. decoding.
  1117. is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1118. Labels whether a question has an answer or no answer (SQuAD 2.0)
  1119. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1120. Labels for position (index) of the classification token to use as input for computing plausibility of the
  1121. answer.
  1122. p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1123. Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means token should be
  1124. masked. 0.0 mean token is not masked.
  1125. Example:
  1126. ```python
  1127. >>> from transformers import AutoTokenizer, XLMForQuestionAnswering
  1128. >>> import torch
  1129. >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048")
  1130. >>> model = XLMForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048")
  1131. >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
  1132. ... 0
  1133. ... ) # Batch size 1
  1134. >>> start_positions = torch.tensor([1])
  1135. >>> end_positions = torch.tensor([3])
  1136. >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
  1137. >>> loss = outputs.loss
  1138. ```"""
  1139. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1140. transformer_outputs = self.transformer(
  1141. input_ids,
  1142. attention_mask=attention_mask,
  1143. langs=langs,
  1144. token_type_ids=token_type_ids,
  1145. position_ids=position_ids,
  1146. lengths=lengths,
  1147. cache=cache,
  1148. inputs_embeds=inputs_embeds,
  1149. output_attentions=output_attentions,
  1150. output_hidden_states=output_hidden_states,
  1151. return_dict=return_dict,
  1152. )
  1153. output = transformer_outputs[0]
  1154. outputs = self.qa_outputs(
  1155. output,
  1156. start_positions=start_positions,
  1157. end_positions=end_positions,
  1158. cls_index=cls_index,
  1159. is_impossible=is_impossible,
  1160. p_mask=p_mask,
  1161. return_dict=return_dict,
  1162. )
  1163. if not return_dict:
  1164. return outputs + transformer_outputs[1:]
  1165. return XLMForQuestionAnsweringOutput(
  1166. loss=outputs.loss,
  1167. start_top_log_probs=outputs.start_top_log_probs,
  1168. start_top_index=outputs.start_top_index,
  1169. end_top_log_probs=outputs.end_top_log_probs,
  1170. end_top_index=outputs.end_top_index,
  1171. cls_logits=outputs.cls_logits,
  1172. hidden_states=transformer_outputs.hidden_states,
  1173. attentions=transformer_outputs.attentions,
  1174. )
  1175. @auto_docstring
  1176. class XLMForTokenClassification(XLMPreTrainedModel):
  1177. def __init__(self, config):
  1178. super().__init__(config)
  1179. self.num_labels = config.num_labels
  1180. self.transformer = XLMModel(config)
  1181. self.dropout = nn.Dropout(config.dropout)
  1182. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1183. # Initialize weights and apply final processing
  1184. self.post_init()
  1185. @auto_docstring
  1186. def forward(
  1187. self,
  1188. input_ids: torch.Tensor | None = None,
  1189. attention_mask: torch.Tensor | None = None,
  1190. langs: torch.Tensor | None = None,
  1191. token_type_ids: torch.Tensor | None = None,
  1192. position_ids: torch.Tensor | None = None,
  1193. lengths: torch.Tensor | None = None,
  1194. cache: dict[str, torch.Tensor] | None = None,
  1195. inputs_embeds: torch.Tensor | None = None,
  1196. labels: torch.Tensor | None = None,
  1197. output_attentions: bool | None = None,
  1198. output_hidden_states: bool | None = None,
  1199. return_dict: bool | None = None,
  1200. **kwargs,
  1201. ) -> tuple | TokenClassifierOutput:
  1202. r"""
  1203. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1204. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1205. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1206. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1207. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1208. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1209. See usage examples detailed in the [multilingual documentation](../multilingual).
  1210. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1211. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1212. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1213. `[0, ..., input_ids.size(-1)]`.
  1214. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1215. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1216. decoding.
  1217. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1218. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1219. """
  1220. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1221. outputs = self.transformer(
  1222. input_ids,
  1223. attention_mask=attention_mask,
  1224. langs=langs,
  1225. token_type_ids=token_type_ids,
  1226. position_ids=position_ids,
  1227. lengths=lengths,
  1228. cache=cache,
  1229. inputs_embeds=inputs_embeds,
  1230. output_attentions=output_attentions,
  1231. output_hidden_states=output_hidden_states,
  1232. return_dict=return_dict,
  1233. )
  1234. sequence_output = outputs[0]
  1235. sequence_output = self.dropout(sequence_output)
  1236. logits = self.classifier(sequence_output)
  1237. loss = None
  1238. if labels is not None:
  1239. loss_fct = CrossEntropyLoss()
  1240. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1241. if not return_dict:
  1242. output = (logits,) + outputs[1:]
  1243. return ((loss,) + output) if loss is not None else output
  1244. return TokenClassifierOutput(
  1245. loss=loss,
  1246. logits=logits,
  1247. hidden_states=outputs.hidden_states,
  1248. attentions=outputs.attentions,
  1249. )
  1250. @auto_docstring
  1251. class XLMForMultipleChoice(XLMPreTrainedModel):
  1252. def __init__(self, config, *inputs, **kwargs):
  1253. super().__init__(config, *inputs, **kwargs)
  1254. self.transformer = XLMModel(config)
  1255. self.sequence_summary = XLMSequenceSummary(config)
  1256. self.logits_proj = nn.Linear(config.num_labels, 1)
  1257. # Initialize weights and apply final processing
  1258. self.post_init()
  1259. @auto_docstring
  1260. def forward(
  1261. self,
  1262. input_ids: torch.Tensor | None = None,
  1263. attention_mask: torch.Tensor | None = None,
  1264. langs: torch.Tensor | None = None,
  1265. token_type_ids: torch.Tensor | None = None,
  1266. position_ids: torch.Tensor | None = None,
  1267. lengths: torch.Tensor | None = None,
  1268. cache: dict[str, torch.Tensor] | None = None,
  1269. inputs_embeds: torch.Tensor | None = None,
  1270. labels: torch.Tensor | None = None,
  1271. output_attentions: bool | None = None,
  1272. output_hidden_states: bool | None = None,
  1273. return_dict: bool | None = None,
  1274. **kwargs,
  1275. ) -> tuple | MultipleChoiceModelOutput:
  1276. r"""
  1277. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1278. Indices of input sequence tokens in the vocabulary.
  1279. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1280. [`PreTrainedTokenizer.__call__`] for details.
  1281. [What are input IDs?](../glossary#input-ids)
  1282. langs (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1283. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1284. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1285. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1286. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1287. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1288. See usage examples detailed in the [multilingual documentation](../multilingual).
  1289. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1290. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1291. 1]`:
  1292. - 0 corresponds to a *sentence A* token,
  1293. - 1 corresponds to a *sentence B* token.
  1294. [What are token type IDs?](../glossary#token-type-ids)
  1295. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1296. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1297. config.max_position_embeddings - 1]`.
  1298. [What are position IDs?](../glossary#position-ids)
  1299. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1300. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1301. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1302. `[0, ..., input_ids.size(-1)]`.
  1303. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1304. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1305. decoding.
  1306. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1307. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1308. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1309. model's internal embedding lookup matrix.
  1310. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1311. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1312. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1313. `input_ids` above)
  1314. """
  1315. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1316. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1317. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1318. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1319. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1320. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1321. langs = langs.view(-1, langs.size(-1)) if langs is not None else None
  1322. inputs_embeds = (
  1323. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1324. if inputs_embeds is not None
  1325. else None
  1326. )
  1327. if lengths is not None:
  1328. logger.warning(
  1329. "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
  1330. "attention mask instead."
  1331. )
  1332. lengths = None
  1333. transformer_outputs = self.transformer(
  1334. input_ids=input_ids,
  1335. attention_mask=attention_mask,
  1336. langs=langs,
  1337. token_type_ids=token_type_ids,
  1338. position_ids=position_ids,
  1339. lengths=lengths,
  1340. cache=cache,
  1341. inputs_embeds=inputs_embeds,
  1342. output_attentions=output_attentions,
  1343. output_hidden_states=output_hidden_states,
  1344. return_dict=return_dict,
  1345. )
  1346. output = transformer_outputs[0]
  1347. logits = self.sequence_summary(output)
  1348. logits = self.logits_proj(logits)
  1349. reshaped_logits = logits.view(-1, num_choices)
  1350. loss = None
  1351. if labels is not None:
  1352. loss_fct = CrossEntropyLoss()
  1353. loss = loss_fct(reshaped_logits, labels)
  1354. if not return_dict:
  1355. output = (reshaped_logits,) + transformer_outputs[1:]
  1356. return ((loss,) + output) if loss is not None else output
  1357. return MultipleChoiceModelOutput(
  1358. loss=loss,
  1359. logits=reshaped_logits,
  1360. hidden_states=transformer_outputs.hidden_states,
  1361. attentions=transformer_outputs.attentions,
  1362. )
  1363. __all__ = [
  1364. "XLMForMultipleChoice",
  1365. "XLMForQuestionAnswering",
  1366. "XLMForQuestionAnsweringSimple",
  1367. "XLMForSequenceClassification",
  1368. "XLMForTokenClassification",
  1369. "XLMModel",
  1370. "XLMPreTrainedModel",
  1371. "XLMWithLMHeadModel",
  1372. ]