modeling_flaubert.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651
  1. # Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Flaubert model, based on XLM."""
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ... import initialization as init
  23. from ...activations import gelu, get_activation
  24. from ...cache_utils import DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. MaskedLMOutput,
  29. MultipleChoiceModelOutput,
  30. QuestionAnsweringModelOutput,
  31. SequenceClassifierOutput,
  32. TokenClassifierOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...pytorch_utils import apply_chunking_to_forward
  36. from ...utils import ModelOutput, auto_docstring, logging
  37. from .configuration_flaubert import FlaubertConfig
# Module-level logger for this file, obtained from the package's `logging` utilities (imported from `...utils`).
logger = logging.get_logger(__name__)
  39. # Copied from transformers.models.xlm.modeling_xlm.create_sinusoidal_embeddings
  40. def create_sinusoidal_embeddings(n_pos, dim, out):
  41. position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
  42. out.requires_grad = False
  43. out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  44. out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
  45. out.detach_()
  46. return out
  47. # Copied from transformers.models.xlm.modeling_xlm.get_masks
  48. def get_masks(slen, lengths, causal, padding_mask=None):
  49. """
  50. Generate hidden states mask, and optionally an attention mask.
  51. """
  52. alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
  53. if padding_mask is not None:
  54. mask = padding_mask
  55. else:
  56. assert lengths.max().item() <= slen
  57. mask = alen < lengths[:, None]
  58. # attention mask is the same as mask, or triangular inferior attention (causal)
  59. bs = lengths.size(0)
  60. if causal:
  61. attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
  62. else:
  63. attn_mask = mask
  64. # sanity check
  65. assert mask.size() == (bs, slen)
  66. assert causal is False or attn_mask.size() == (bs, slen, slen)
  67. return mask, attn_mask
  68. # Copied from transformers.models.xlm.modeling_xlm.MultiHeadAttention
  69. class MultiHeadAttention(nn.Module):
  70. def __init__(self, n_heads, dim, config, layer_idx: int = 0):
  71. super().__init__()
  72. self.layer_id = layer_idx
  73. self.dim = dim
  74. self.n_heads = n_heads
  75. self.head_dim = dim // n_heads
  76. self.dropout = config.attention_dropout
  77. assert self.dim % self.n_heads == 0
  78. self.q_lin = nn.Linear(dim, dim)
  79. self.k_lin = nn.Linear(dim, dim)
  80. self.v_lin = nn.Linear(dim, dim)
  81. self.out_lin = nn.Linear(dim, dim)
  82. def forward(
  83. self,
  84. input,
  85. mask,
  86. kv=None,
  87. cache=None,
  88. output_attentions=False,
  89. **kwargs,
  90. ):
  91. """
  92. Self-attention (if kv is None) or attention over source sentence (provided by kv).
  93. """
  94. # Input is (bs, qlen, dim)
  95. # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
  96. bs, qlen, dim = input.size()
  97. is_cross_attention = kv is not None
  98. mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1)
  99. q = self.q_lin(input).view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
  100. if cache is not None:
  101. if isinstance(cache, EncoderDecoderCache):
  102. is_updated = cache.is_updated.get(self.layer_id)
  103. if is_cross_attention:
  104. # after the first generated id, we can subsequently re-use all key/value_states from cache
  105. curr_past_key_values = cache.cross_attention_cache
  106. else:
  107. curr_past_key_values = cache.self_attention_cache
  108. else:
  109. curr_past_key_values = cache
  110. current_states = kv if is_cross_attention else input
  111. if is_cross_attention and cache is not None and is_updated:
  112. # reuse k,v, cross_attentions
  113. k = curr_past_key_values.key_cache[self.layer_id]
  114. v = curr_past_key_values.value_cache[self.layer_id]
  115. else:
  116. k = self.k_lin(current_states)
  117. v = self.v_lin(current_states)
  118. k = k.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
  119. v = v.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
  120. if cache is not None:
  121. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  122. k, v = curr_past_key_values.update(k, v, self.layer_id)
  123. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  124. if is_cross_attention:
  125. cache.is_updated[self.layer_id] = True
  126. q = q / math.sqrt(self.head_dim) # (bs, n_heads, qlen, head_dim)
  127. scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
  128. mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
  129. scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)
  130. weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
  131. weights = nn.functional.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
  132. context = torch.matmul(weights, v) # (bs, n_heads, qlen, head_dim)
  133. context = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.head_dim)
  134. outputs = (self.out_lin(context),)
  135. if output_attentions:
  136. outputs = outputs + (weights,)
  137. return outputs
  138. # Copied from transformers.models.xlm.modeling_xlm.TransformerFFN
  139. class TransformerFFN(nn.Module):
  140. def __init__(self, in_dim, dim_hidden, out_dim, config):
  141. super().__init__()
  142. self.dropout = config.dropout
  143. self.lin1 = nn.Linear(in_dim, dim_hidden)
  144. self.lin2 = nn.Linear(dim_hidden, out_dim)
  145. self.act = gelu if config.gelu_activation else nn.functional.relu
  146. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  147. self.seq_len_dim = 1
  148. def forward(self, input):
  149. return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
  150. def ff_chunk(self, input):
  151. x = self.lin1(input)
  152. x = self.act(x)
  153. x = self.lin2(x)
  154. x = nn.functional.dropout(x, p=self.dropout, training=self.training)
  155. return x
  156. @auto_docstring(
  157. custom_intro="""
  158. The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.
  159. """
  160. )
  161. # Copied from transformers.models.xlm.modeling_xlm.XLMPredLayer with XLM->Flaubert
  162. class FlaubertPredLayer(nn.Module):
  163. """
  164. Prediction layer (cross_entropy or adaptive_softmax).
  165. """
  166. def __init__(self, config):
  167. super().__init__()
  168. self.asm = config.asm
  169. self.n_words = config.n_words
  170. self.pad_index = config.pad_index
  171. dim = config.emb_dim
  172. if config.asm is False:
  173. self.proj = nn.Linear(dim, config.n_words, bias=True)
  174. else:
  175. self.proj = nn.AdaptiveLogSoftmaxWithLoss(
  176. in_features=dim,
  177. n_classes=config.n_words,
  178. cutoffs=config.asm_cutoffs,
  179. div_value=config.asm_div_value,
  180. head_bias=True, # default is False
  181. )
  182. def forward(self, x, y=None):
  183. """Compute the loss, and optionally the scores."""
  184. outputs = ()
  185. if self.asm is False:
  186. scores = self.proj(x)
  187. outputs = (scores,) + outputs
  188. if y is not None:
  189. loss = nn.functional.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
  190. outputs = (loss,) + outputs
  191. else:
  192. scores = self.proj.log_prob(x)
  193. outputs = (scores,) + outputs
  194. if y is not None:
  195. _, loss = self.proj(x, y)
  196. outputs = (loss,) + outputs
  197. return outputs
  198. @dataclass
  199. @auto_docstring(
  200. custom_intro="""
  201. Base class for outputs of question answering models using a [`~modeling_utils.FlaubertSQuADHead`].
  202. """
  203. )
  204. # Copied from transformers.models.xlm.modeling_xlm.XLMSquadHeadOutput with XLM->Flaubert
  205. class FlaubertSquadHeadOutput(ModelOutput):
  206. r"""
  207. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
  208. Classification loss as the sum of start token, end token (and is_impossible if provided) classification
  209. losses.
  210. start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  211. Log probabilities for the top config.start_n_top start token possibilities (beam-search).
  212. start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  213. Indices for the top config.start_n_top start token possibilities (beam-search).
  214. end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  215. Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
  216. (beam-search).
  217. end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  218. Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
  219. cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  220. Log probabilities for the `is_impossible` label of the answers.
  221. """
  222. loss: torch.FloatTensor | None = None
  223. start_top_log_probs: torch.FloatTensor | None = None
  224. start_top_index: torch.LongTensor | None = None
  225. end_top_log_probs: torch.FloatTensor | None = None
  226. end_top_index: torch.LongTensor | None = None
  227. cls_logits: torch.FloatTensor | None = None
  228. # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->Flaubert
  229. class FlaubertPoolerStartLogits(nn.Module):
  230. """
  231. Compute SQuAD start logits from sequence hidden states.
  232. Args:
  233. config ([`FlaubertConfig`]):
  234. The config used by the model, will be used to grab the `hidden_size` of the model.
  235. """
  236. def __init__(self, config: FlaubertConfig):
  237. super().__init__()
  238. self.dense = nn.Linear(config.hidden_size, 1)
  239. def forward(self, hidden_states: torch.FloatTensor, p_mask: torch.FloatTensor | None = None) -> torch.FloatTensor:
  240. """
  241. Args:
  242. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  243. The final hidden states of the model.
  244. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  245. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  246. should be masked.
  247. Returns:
  248. `torch.FloatTensor`: The start logits for SQuAD.
  249. """
  250. x = self.dense(hidden_states).squeeze(-1)
  251. if p_mask is not None:
  252. if p_mask.dtype == torch.float16:
  253. x = x * (1 - p_mask) - 65500 * p_mask
  254. else:
  255. x = x * (1 - p_mask) - 1e30 * p_mask
  256. return x
  257. # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->Flaubert
  258. class FlaubertPoolerEndLogits(nn.Module):
  259. """
  260. Compute SQuAD end logits from sequence hidden states.
  261. Args:
  262. config ([`FlaubertConfig`]):
  263. The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
  264. to use.
  265. """
  266. def __init__(self, config: FlaubertConfig):
  267. super().__init__()
  268. self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
  269. self.activation = nn.Tanh()
  270. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  271. self.dense_1 = nn.Linear(config.hidden_size, 1)
  272. def forward(
  273. self,
  274. hidden_states: torch.FloatTensor,
  275. start_states: torch.FloatTensor | None = None,
  276. start_positions: torch.LongTensor | None = None,
  277. p_mask: torch.FloatTensor | None = None,
  278. ) -> torch.FloatTensor:
  279. """
  280. Args:
  281. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  282. The final hidden states of the model.
  283. start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  284. The hidden states of the first tokens for the labeled span.
  285. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  286. The position of the first token for the labeled span.
  287. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  288. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  289. should be masked.
  290. <Tip>
  291. One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
  292. `start_states`.
  293. </Tip>
  294. Returns:
  295. `torch.FloatTensor`: The end logits for SQuAD.
  296. """
  297. assert start_states is not None or start_positions is not None, (
  298. "One of start_states, start_positions should be not None"
  299. )
  300. if start_positions is not None:
  301. slen, hsz = hidden_states.shape[-2:]
  302. start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  303. start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
  304. start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
  305. x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
  306. x = self.activation(x)
  307. x = self.LayerNorm(x)
  308. x = self.dense_1(x).squeeze(-1)
  309. if p_mask is not None:
  310. if p_mask.dtype == torch.float16:
  311. x = x * (1 - p_mask) - 65500 * p_mask
  312. else:
  313. x = x * (1 - p_mask) - 1e30 * p_mask
  314. return x
  315. # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->Flaubert
  316. class FlaubertPoolerAnswerClass(nn.Module):
  317. """
  318. Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
  319. Args:
  320. config ([`FlaubertConfig`]):
  321. The config used by the model, will be used to grab the `hidden_size` of the model.
  322. """
  323. def __init__(self, config: FlaubertConfig):
  324. super().__init__()
  325. self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
  326. self.activation = nn.Tanh()
  327. self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
  328. def forward(
  329. self,
  330. hidden_states: torch.FloatTensor,
  331. start_states: torch.FloatTensor | None = None,
  332. start_positions: torch.LongTensor | None = None,
  333. cls_index: torch.LongTensor | None = None,
  334. ) -> torch.FloatTensor:
  335. """
  336. Args:
  337. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  338. The final hidden states of the model.
  339. start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  340. The hidden states of the first tokens for the labeled span.
  341. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  342. The position of the first token for the labeled span.
  343. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  344. Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
  345. <Tip>
  346. One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
  347. `start_states`.
  348. </Tip>
  349. Returns:
  350. `torch.FloatTensor`: The SQuAD 2.0 answer class.
  351. """
  352. # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
  353. hsz = hidden_states.shape[-1]
  354. assert start_states is not None or start_positions is not None, (
  355. "One of start_states, start_positions should be not None"
  356. )
  357. if start_positions is not None:
  358. start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  359. start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
  360. if cls_index is not None:
  361. cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  362. cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
  363. else:
  364. cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
  365. x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
  366. x = self.activation(x)
  367. x = self.dense_1(x).squeeze(-1)
  368. return x
  369. # Copied from transformers.models.xlm.modeling_xlm.XLMSQuADHead with XLM->Flaubert
  370. class FlaubertSQuADHead(nn.Module):
  371. r"""
  372. A SQuAD head inspired by XLNet.
  373. Args:
  374. config ([`FlaubertConfig`]):
  375. The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
  376. to use.
  377. """
  378. def __init__(self, config: FlaubertConfig):
  379. super().__init__()
  380. self.start_n_top = config.start_n_top
  381. self.end_n_top = config.end_n_top
  382. self.start_logits = FlaubertPoolerStartLogits(config)
  383. self.end_logits = FlaubertPoolerEndLogits(config)
  384. self.answer_class = FlaubertPoolerAnswerClass(config)
  385. @auto_docstring
  386. def forward(
  387. self,
  388. hidden_states: torch.FloatTensor,
  389. start_positions: torch.LongTensor | None = None,
  390. end_positions: torch.LongTensor | None = None,
  391. cls_index: torch.LongTensor | None = None,
  392. is_impossible: torch.LongTensor | None = None,
  393. p_mask: torch.FloatTensor | None = None,
  394. return_dict: bool = False,
  395. ) -> FlaubertSquadHeadOutput | tuple[torch.FloatTensor]:
  396. r"""
  397. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  398. Final hidden states of the model on the sequence tokens.
  399. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  400. Positions of the first token for the labeled span.
  401. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  402. Positions of the last token for the labeled span.
  403. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  404. Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
  405. is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  406. Whether the question has a possible answer in the paragraph or not.
  407. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  408. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  409. should be masked.
  410. """
  411. start_logits = self.start_logits(hidden_states, p_mask=p_mask)
  412. if start_positions is not None and end_positions is not None:
  413. # If we are on multi-GPU, let's remove the dimension added by batch splitting
  414. for x in (start_positions, end_positions, cls_index, is_impossible):
  415. if x is not None and x.dim() > 1:
  416. x.squeeze_(-1)
  417. # during training, compute the end logits based on the ground truth of the start position
  418. end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
  419. loss_fct = CrossEntropyLoss()
  420. start_loss = loss_fct(start_logits, start_positions)
  421. end_loss = loss_fct(end_logits, end_positions)
  422. total_loss = (start_loss + end_loss) / 2
  423. if cls_index is not None and is_impossible is not None:
  424. # Predict answerability from the representation of CLS and START
  425. cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
  426. loss_fct_cls = nn.BCEWithLogitsLoss()
  427. cls_loss = loss_fct_cls(cls_logits, is_impossible)
  428. # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
  429. total_loss += cls_loss * 0.5
  430. return FlaubertSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
  431. else:
  432. # during inference, compute the end logits based on beam search
  433. bsz, slen, hsz = hidden_states.size()
  434. start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen)
  435. start_top_log_probs, start_top_index = torch.topk(
  436. start_log_probs, self.start_n_top, dim=-1
  437. ) # shape (bsz, start_n_top)
  438. start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
  439. start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
  440. start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
  441. hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
  442. start_states
  443. ) # shape (bsz, slen, start_n_top, hsz)
  444. p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
  445. end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
  446. end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
  447. end_top_log_probs, end_top_index = torch.topk(
  448. end_log_probs, self.end_n_top, dim=1
  449. ) # shape (bsz, end_n_top, start_n_top)
  450. end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
  451. end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
  452. start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
  453. cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
  454. if not return_dict:
  455. return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
  456. else:
  457. return FlaubertSquadHeadOutput(
  458. start_top_log_probs=start_top_log_probs,
  459. start_top_index=start_top_index,
  460. end_top_log_probs=end_top_log_probs,
  461. end_top_index=end_top_index,
  462. cls_logits=cls_logits,
  463. )
  464. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Flaubert
  465. class FlaubertSequenceSummary(nn.Module):
  466. r"""
  467. Compute a single vector summary of a sequence hidden states.
  468. Args:
  469. config ([`FlaubertConfig`]):
  470. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  471. config class of your model for the default values it uses):
  472. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  473. - `"last"` -- Take the last token hidden state (like XLNet)
  474. - `"first"` -- Take the first token hidden state (like Bert)
  475. - `"mean"` -- Take the mean of all tokens hidden states
  476. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  477. - `"attn"` -- Not implemented now, use multi-head attention
  478. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  479. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  480. (otherwise to `config.hidden_size`).
  481. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  482. another string or `None` will add no activation.
  483. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  484. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  485. """
  486. def __init__(self, config: FlaubertConfig):
  487. super().__init__()
  488. self.summary_type = getattr(config, "summary_type", "last")
  489. if self.summary_type == "attn":
  490. # We should use a standard multi-head attention module with absolute positional embedding for that.
  491. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  492. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  493. raise NotImplementedError
  494. self.summary = nn.Identity()
  495. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  496. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  497. num_classes = config.num_labels
  498. else:
  499. num_classes = config.hidden_size
  500. self.summary = nn.Linear(config.hidden_size, num_classes)
  501. activation_string = getattr(config, "summary_activation", None)
  502. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  503. self.first_dropout = nn.Identity()
  504. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  505. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  506. self.last_dropout = nn.Identity()
  507. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  508. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  509. def forward(
  510. self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
  511. ) -> torch.FloatTensor:
  512. """
  513. Compute a single vector summary of a sequence hidden states.
  514. Args:
  515. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  516. The hidden states of the last layer.
  517. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  518. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  519. Returns:
  520. `torch.FloatTensor`: The summary of the sequence hidden states.
  521. """
  522. if self.summary_type == "last":
  523. output = hidden_states[:, -1]
  524. elif self.summary_type == "first":
  525. output = hidden_states[:, 0]
  526. elif self.summary_type == "mean":
  527. output = hidden_states.mean(dim=1)
  528. elif self.summary_type == "cls_index":
  529. if cls_index is None:
  530. cls_index = torch.full_like(
  531. hidden_states[..., :1, :],
  532. hidden_states.shape[-2] - 1,
  533. dtype=torch.long,
  534. )
  535. else:
  536. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  537. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  538. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  539. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  540. elif self.summary_type == "attn":
  541. raise NotImplementedError
  542. output = self.first_dropout(output)
  543. output = self.summary(output)
  544. output = self.activation(output)
  545. output = self.last_dropout(output)
  546. return output
  547. @auto_docstring
  548. # Copied from transformers.models.xlm.modeling_xlm.XLMPreTrainedModel with XLM->Flaubert
  549. class FlaubertPreTrainedModel(PreTrainedModel):
  550. config: FlaubertConfig
  551. base_model_prefix = "transformer"
  552. @property
  553. def dummy_inputs(self):
  554. inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
  555. attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
  556. if self.config.use_lang_emb and self.config.n_langs > 1:
  557. langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
  558. else:
  559. langs_list = None
  560. return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
  561. @torch.no_grad()
  562. def _init_weights(self, module):
  563. """Initialize the weights."""
  564. if isinstance(module, nn.Embedding):
  565. if self.config is not None and self.config.embed_init_std is not None:
  566. init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
  567. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  568. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  569. init.zeros_(module.weight[module.padding_idx])
  570. if isinstance(module, nn.Linear):
  571. if self.config is not None and self.config.init_std is not None:
  572. init.normal_(module.weight, mean=0, std=self.config.init_std)
  573. if module.bias is not None:
  574. init.constant_(module.bias, 0.0)
  575. if isinstance(module, nn.LayerNorm):
  576. init.zeros_(module.bias)
  577. init.ones_(module.weight)
  578. if isinstance(module, FlaubertModel):
  579. if self.config.sinusoidal_embeddings:
  580. init.copy_(
  581. module.position_embeddings.weight,
  582. create_sinusoidal_embeddings(
  583. self.config.max_position_embeddings,
  584. self.config.emb_dim,
  585. out=torch.empty_like(module.position_embeddings.weight),
  586. ),
  587. )
  588. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  589. @auto_docstring
  590. class FlaubertModel(FlaubertPreTrainedModel):
  591. def __init__(self, config): # , dico, is_encoder, with_output):
  592. super().__init__(config)
  593. # encoder / decoder, output layer
  594. self.is_encoder = config.is_encoder
  595. self.is_decoder = not config.is_encoder
  596. if self.is_decoder:
  597. raise NotImplementedError("Currently Flaubert can only be used as an encoder")
  598. # self.with_output = with_output
  599. self.causal = config.causal
  600. # dictionary / languages
  601. self.n_langs = config.n_langs
  602. self.use_lang_emb = config.use_lang_emb
  603. self.n_words = config.n_words
  604. self.eos_index = config.eos_index
  605. self.pad_index = config.pad_index
  606. # self.dico = dico
  607. # self.id2lang = config.id2lang
  608. # self.lang2id = config.lang2id
  609. # assert len(self.dico) == self.n_words
  610. # assert len(self.id2lang) == len(self.lang2id) == self.n_langs
  611. # model parameters
  612. self.dim = config.emb_dim # 512 by default
  613. self.hidden_dim = self.dim * 4 # 2048 by default
  614. self.n_heads = config.n_heads # 8 by default
  615. self.n_layers = config.n_layers
  616. self.dropout = config.dropout
  617. self.attention_dropout = config.attention_dropout
  618. assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"
  619. # embeddings
  620. self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
  621. if config.n_langs > 1 and config.use_lang_emb:
  622. self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
  623. self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
  624. self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
  625. # transformer layers
  626. self.attentions = nn.ModuleList()
  627. self.layer_norm1 = nn.ModuleList()
  628. self.ffns = nn.ModuleList()
  629. self.layer_norm2 = nn.ModuleList()
  630. # if self.is_decoder:
  631. # self.layer_norm15 = nn.ModuleList()
  632. # self.encoder_attn = nn.ModuleList()
  633. for i in range(self.n_layers):
  634. self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config, layer_idx=i))
  635. self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  636. # if self.is_decoder:
  637. # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  638. # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
  639. self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
  640. self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  641. self.layerdrop = getattr(config, "layerdrop", 0.0)
  642. self.pre_norm = getattr(config, "pre_norm", False)
  643. self.register_buffer(
  644. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  645. )
  646. # Initialize weights and apply final processing
  647. self.post_init()
  648. # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings
  649. def get_input_embeddings(self):
  650. return self.embeddings
  651. # Copied from transformers.models.xlm.modeling_xlm.XLMModel.set_input_embeddings
  652. def set_input_embeddings(self, new_embeddings):
  653. self.embeddings = new_embeddings
  654. @auto_docstring
  655. def forward(
  656. self,
  657. input_ids: torch.LongTensor | None = None,
  658. attention_mask: torch.FloatTensor | None = None,
  659. langs: torch.Tensor | None = None,
  660. token_type_ids: torch.LongTensor | None = None,
  661. position_ids: torch.LongTensor | None = None,
  662. lengths: torch.LongTensor | None = None,
  663. cache: dict[str, torch.FloatTensor] | None = None,
  664. inputs_embeds: torch.FloatTensor | None = None,
  665. output_attentions: bool | None = None,
  666. output_hidden_states: bool | None = None,
  667. return_dict: bool | None = None,
  668. **kwargs,
  669. ) -> tuple | BaseModelOutput:
  670. r"""
  671. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  672. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  673. languages ids which can be obtained from the language names by using two conversion mappings provided in
  674. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  675. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  676. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  677. See usage examples detailed in the [multilingual documentation](../multilingual).
  678. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  679. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  680. also use `attention_mask` for the same result (see above), kept here for compatibility. Indices selected in
  681. `[0, ..., input_ids.size(-1)]`:
  682. cache (`dict[str, torch.FloatTensor]`, *optional*):
  683. Dictionary strings to `torch.FloatTensor` that contains precomputed hidden-states (key and values in the
  684. attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential
  685. decoding. The dictionary object will be modified in-place during the forward pass to add newly computed
  686. hidden-states.
  687. """
  688. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  689. output_hidden_states = (
  690. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  691. )
  692. return_dict = return_dict if return_dict is not None else self.config.return_dict
  693. # removed: src_enc=None, src_len=None
  694. if input_ids is not None:
  695. bs, slen = input_ids.size()
  696. else:
  697. bs, slen = inputs_embeds.size()[:-1]
  698. device = input_ids.device if input_ids is not None else inputs_embeds.device
  699. if cache is None:
  700. cache = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  701. if lengths is None:
  702. if input_ids is not None:
  703. lengths = (input_ids != self.pad_index).sum(dim=1).long()
  704. else:
  705. lengths = torch.full((bs,), slen, device=device, dtype=torch.long)
  706. # mask = input_ids != self.pad_index
  707. # check inputs
  708. assert lengths.size(0) == bs
  709. assert lengths.max().item() <= slen
  710. # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
  711. # assert (src_enc is None) == (src_len is None)
  712. # if src_enc is not None:
  713. # assert self.is_decoder
  714. # assert src_enc.size(0) == bs
  715. # generate masks
  716. mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
  717. # if self.is_decoder and src_enc is not None:
  718. # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
  719. # Setting the position-ids to the registered buffer in constructor, it helps
  720. # when tracing the model without passing position-ids, solves
  721. # issues similar to issue #5664
  722. if position_ids is None:
  723. if hasattr(self, "position_ids"):
  724. position_ids = self.position_ids[:, :slen]
  725. position_ids = position_ids.expand((bs, slen))
  726. else:
  727. position_ids = torch.arange(slen, dtype=torch.long, device=device)
  728. position_ids = position_ids.unsqueeze(0).expand((bs, slen))
  729. else:
  730. assert position_ids.size() == (bs, slen) # (slen, bs)
  731. # position_ids = position_ids.transpose(0, 1)
  732. # langs
  733. if langs is not None:
  734. assert langs.size() == (bs, slen) # (slen, bs)
  735. # langs = langs.transpose(0, 1)
  736. # do not recompute cached elements
  737. if cache is not None and input_ids is not None:
  738. _slen = slen - cache.get_seq_length()
  739. input_ids = input_ids[:, -_slen:]
  740. position_ids = position_ids[:, -_slen:]
  741. if langs is not None:
  742. langs = langs[:, -_slen:]
  743. mask = mask[:, -_slen:]
  744. attn_mask = attn_mask[:, -_slen:]
  745. # embeddings
  746. if inputs_embeds is None:
  747. inputs_embeds = self.embeddings(input_ids)
  748. tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
  749. if langs is not None and self.use_lang_emb and self.config.n_langs > 1:
  750. tensor = tensor + self.lang_embeddings(langs)
  751. if token_type_ids is not None:
  752. tensor = tensor + self.embeddings(token_type_ids)
  753. tensor = self.layer_norm_emb(tensor)
  754. tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)
  755. tensor *= mask.unsqueeze(-1).to(tensor.dtype)
  756. # transformer layers
  757. hidden_states = () if output_hidden_states else None
  758. attentions = () if output_attentions else None
  759. for i in range(self.n_layers):
  760. # LayerDrop
  761. if self.training:
  762. dropout_probability = torch.rand([])
  763. if dropout_probability < self.layerdrop:
  764. continue
  765. if output_hidden_states:
  766. hidden_states = hidden_states + (tensor,)
  767. # self attention
  768. if not self.pre_norm:
  769. attn_outputs = self.attentions[i](
  770. tensor,
  771. attn_mask,
  772. cache=cache,
  773. output_attentions=output_attentions,
  774. )
  775. attn = attn_outputs[0]
  776. if output_attentions:
  777. attentions = attentions + (attn_outputs[1],)
  778. attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
  779. tensor = tensor + attn
  780. tensor = self.layer_norm1[i](tensor)
  781. else:
  782. tensor_normalized = self.layer_norm1[i](tensor)
  783. attn_outputs = self.attentions[i](tensor_normalized, attn_mask, cache=cache[i])
  784. attn = attn_outputs[0]
  785. if output_attentions:
  786. attentions = attentions + (attn_outputs[1],)
  787. attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
  788. tensor = tensor + attn
  789. # FFN
  790. if not self.pre_norm:
  791. tensor = tensor + self.ffns[i](tensor)
  792. tensor = self.layer_norm2[i](tensor)
  793. else:
  794. tensor_normalized = self.layer_norm2[i](tensor)
  795. tensor = tensor + self.ffns[i](tensor_normalized)
  796. tensor *= mask.unsqueeze(-1).to(tensor.dtype)
  797. # Add last hidden state
  798. if output_hidden_states:
  799. hidden_states = hidden_states + (tensor,)
  800. if not return_dict:
  801. return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
  802. return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
  803. @auto_docstring(
  804. custom_intro="""
  805. The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input
  806. embeddings).
  807. """
  808. )
  809. class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin):
  810. _tied_weights_keys = {"pred_layer.proj.weight": "transformer.embeddings.weight"}
  811. def __init__(self, config):
  812. super().__init__(config)
  813. self.transformer = FlaubertModel(config)
  814. self.pred_layer = FlaubertPredLayer(config)
  815. # Initialize weights and apply final processing
  816. self.post_init()
  817. def get_output_embeddings(self):
  818. return self.pred_layer.proj
  819. def set_output_embeddings(self, new_embeddings):
  820. self.pred_layer.proj = new_embeddings
  821. def prepare_inputs_for_generation(self, input_ids, **kwargs):
  822. # Overwritten -- uses a language id
  823. mask_token_id = self.config.mask_token_id
  824. lang_id = self.config.lang_id
  825. effective_batch_size = input_ids.shape[0]
  826. mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
  827. input_ids = torch.cat([input_ids, mask_token], dim=1)
  828. if lang_id is not None:
  829. langs = torch.full_like(input_ids, lang_id)
  830. else:
  831. langs = None
  832. return {"input_ids": input_ids, "langs": langs}
  833. @auto_docstring
  834. def forward(
  835. self,
  836. input_ids: torch.Tensor | None = None,
  837. attention_mask: torch.Tensor | None = None,
  838. langs: torch.Tensor | None = None,
  839. token_type_ids: torch.Tensor | None = None,
  840. position_ids: torch.Tensor | None = None,
  841. lengths: torch.Tensor | None = None,
  842. cache: dict[str, torch.Tensor] | None = None,
  843. inputs_embeds: torch.Tensor | None = None,
  844. labels: torch.Tensor | None = None,
  845. output_attentions: bool | None = None,
  846. output_hidden_states: bool | None = None,
  847. return_dict: bool | None = None,
  848. **kwargs,
  849. ) -> tuple | MaskedLMOutput:
  850. r"""
  851. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  852. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  853. languages ids which can be obtained from the language names by using two conversion mappings provided in
  854. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  855. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  856. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  857. See usage examples detailed in the [multilingual documentation](../multilingual).
  858. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  859. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  860. also use `attention_mask` for the same result (see above), kept here for compatibility. Indices selected in
  861. `[0, ..., input_ids.size(-1)]`:
  862. cache (`dict[str, torch.FloatTensor]`, *optional*):
  863. Dictionary strings to `torch.FloatTensor` that contains precomputed hidden-states (key and values in the
  864. attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential
  865. decoding. The dictionary object will be modified in-place during the forward pass to add newly computed
  866. hidden-states.
  867. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  868. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  869. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  870. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  871. """
  872. return_dict = return_dict if return_dict is not None else self.config.return_dict
  873. transformer_outputs = self.transformer(
  874. input_ids,
  875. attention_mask=attention_mask,
  876. langs=langs,
  877. token_type_ids=token_type_ids,
  878. position_ids=position_ids,
  879. lengths=lengths,
  880. cache=cache,
  881. inputs_embeds=inputs_embeds,
  882. output_attentions=output_attentions,
  883. output_hidden_states=output_hidden_states,
  884. return_dict=return_dict,
  885. )
  886. output = transformer_outputs[0]
  887. outputs = self.pred_layer(output, labels) # (loss, logits) or (logits,) depending on if labels are provided.
  888. if not return_dict:
  889. return outputs + transformer_outputs[1:]
  890. return MaskedLMOutput(
  891. loss=outputs[0] if labels is not None else None,
  892. logits=outputs[0] if labels is None else outputs[1],
  893. hidden_states=transformer_outputs.hidden_states,
  894. attentions=transformer_outputs.attentions,
  895. )
  896. @auto_docstring(
  897. custom_intro="""
  898. Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
  899. e.g. for GLUE tasks.
  900. """
  901. )
  902. # Copied from transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  903. class FlaubertForSequenceClassification(FlaubertPreTrainedModel):
  904. def __init__(self, config):
  905. super().__init__(config)
  906. self.num_labels = config.num_labels
  907. self.config = config
  908. self.transformer = FlaubertModel(config)
  909. self.sequence_summary = FlaubertSequenceSummary(config)
  910. # Initialize weights and apply final processing
  911. self.post_init()
  912. @auto_docstring
  913. def forward(
  914. self,
  915. input_ids: torch.Tensor | None = None,
  916. attention_mask: torch.Tensor | None = None,
  917. langs: torch.Tensor | None = None,
  918. token_type_ids: torch.Tensor | None = None,
  919. position_ids: torch.Tensor | None = None,
  920. lengths: torch.Tensor | None = None,
  921. cache: dict[str, torch.Tensor] | None = None,
  922. inputs_embeds: torch.Tensor | None = None,
  923. labels: torch.Tensor | None = None,
  924. output_attentions: bool | None = None,
  925. output_hidden_states: bool | None = None,
  926. return_dict: bool | None = None,
  927. **kwargs,
  928. ) -> tuple | SequenceClassifierOutput:
  929. r"""
  930. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  931. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  932. languages ids which can be obtained from the language names by using two conversion mappings provided in
  933. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  934. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  935. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  936. See usage examples detailed in the [multilingual documentation](../multilingual).
  937. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  938. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  939. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  940. `[0, ..., input_ids.size(-1)]`.
  941. cache (`dict[str, torch.FloatTensor]`, *optional*):
  942. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  943. decoding.
  944. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  945. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  946. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  947. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  948. """
  949. return_dict = return_dict if return_dict is not None else self.config.return_dict
  950. transformer_outputs = self.transformer(
  951. input_ids,
  952. attention_mask=attention_mask,
  953. langs=langs,
  954. token_type_ids=token_type_ids,
  955. position_ids=position_ids,
  956. lengths=lengths,
  957. cache=cache,
  958. inputs_embeds=inputs_embeds,
  959. output_attentions=output_attentions,
  960. output_hidden_states=output_hidden_states,
  961. return_dict=return_dict,
  962. )
  963. output = transformer_outputs[0]
  964. logits = self.sequence_summary(output)
  965. loss = None
  966. if labels is not None:
  967. if self.config.problem_type is None:
  968. if self.num_labels == 1:
  969. self.config.problem_type = "regression"
  970. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  971. self.config.problem_type = "single_label_classification"
  972. else:
  973. self.config.problem_type = "multi_label_classification"
  974. if self.config.problem_type == "regression":
  975. loss_fct = MSELoss()
  976. if self.num_labels == 1:
  977. loss = loss_fct(logits.squeeze(), labels.squeeze())
  978. else:
  979. loss = loss_fct(logits, labels)
  980. elif self.config.problem_type == "single_label_classification":
  981. loss_fct = CrossEntropyLoss()
  982. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  983. elif self.config.problem_type == "multi_label_classification":
  984. loss_fct = BCEWithLogitsLoss()
  985. loss = loss_fct(logits, labels)
  986. if not return_dict:
  987. output = (logits,) + transformer_outputs[1:]
  988. return ((loss,) + output) if loss is not None else output
  989. return SequenceClassifierOutput(
  990. loss=loss,
  991. logits=logits,
  992. hidden_states=transformer_outputs.hidden_states,
  993. attentions=transformer_outputs.attentions,
  994. )
  995. @auto_docstring
  996. # Copied from transformers.models.xlm.modeling_xlm.XLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  997. class FlaubertForTokenClassification(FlaubertPreTrainedModel):
  998. def __init__(self, config):
  999. super().__init__(config)
  1000. self.num_labels = config.num_labels
  1001. self.transformer = FlaubertModel(config)
  1002. self.dropout = nn.Dropout(config.dropout)
  1003. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1004. # Initialize weights and apply final processing
  1005. self.post_init()
  1006. @auto_docstring
  1007. def forward(
  1008. self,
  1009. input_ids: torch.Tensor | None = None,
  1010. attention_mask: torch.Tensor | None = None,
  1011. langs: torch.Tensor | None = None,
  1012. token_type_ids: torch.Tensor | None = None,
  1013. position_ids: torch.Tensor | None = None,
  1014. lengths: torch.Tensor | None = None,
  1015. cache: dict[str, torch.Tensor] | None = None,
  1016. inputs_embeds: torch.Tensor | None = None,
  1017. labels: torch.Tensor | None = None,
  1018. output_attentions: bool | None = None,
  1019. output_hidden_states: bool | None = None,
  1020. return_dict: bool | None = None,
  1021. **kwargs,
  1022. ) -> tuple | TokenClassifierOutput:
  1023. r"""
  1024. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1025. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1026. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1027. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1028. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1029. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1030. See usage examples detailed in the [multilingual documentation](../multilingual).
  1031. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1032. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1033. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1034. `[0, ..., input_ids.size(-1)]`.
  1035. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1036. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1037. decoding.
  1038. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1039. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1040. """
  1041. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1042. outputs = self.transformer(
  1043. input_ids,
  1044. attention_mask=attention_mask,
  1045. langs=langs,
  1046. token_type_ids=token_type_ids,
  1047. position_ids=position_ids,
  1048. lengths=lengths,
  1049. cache=cache,
  1050. inputs_embeds=inputs_embeds,
  1051. output_attentions=output_attentions,
  1052. output_hidden_states=output_hidden_states,
  1053. return_dict=return_dict,
  1054. )
  1055. sequence_output = outputs[0]
  1056. sequence_output = self.dropout(sequence_output)
  1057. logits = self.classifier(sequence_output)
  1058. loss = None
  1059. if labels is not None:
  1060. loss_fct = CrossEntropyLoss()
  1061. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1062. if not return_dict:
  1063. output = (logits,) + outputs[1:]
  1064. return ((loss,) + output) if loss is not None else output
  1065. return TokenClassifierOutput(
  1066. loss=loss,
  1067. logits=logits,
  1068. hidden_states=outputs.hidden_states,
  1069. attentions=outputs.attentions,
  1070. )
  1071. @auto_docstring(
  1072. custom_intro="""
  1073. Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
  1074. layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
  1075. """
  1076. )
  1077. # Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  1078. class FlaubertForQuestionAnsweringSimple(FlaubertPreTrainedModel):
  1079. def __init__(self, config):
  1080. super().__init__(config)
  1081. self.transformer = FlaubertModel(config)
  1082. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1083. # Initialize weights and apply final processing
  1084. self.post_init()
  1085. @auto_docstring
  1086. def forward(
  1087. self,
  1088. input_ids: torch.Tensor | None = None,
  1089. attention_mask: torch.Tensor | None = None,
  1090. langs: torch.Tensor | None = None,
  1091. token_type_ids: torch.Tensor | None = None,
  1092. position_ids: torch.Tensor | None = None,
  1093. lengths: torch.Tensor | None = None,
  1094. cache: dict[str, torch.Tensor] | None = None,
  1095. inputs_embeds: torch.Tensor | None = None,
  1096. start_positions: torch.Tensor | None = None,
  1097. end_positions: torch.Tensor | None = None,
  1098. output_attentions: bool | None = None,
  1099. output_hidden_states: bool | None = None,
  1100. return_dict: bool | None = None,
  1101. **kwargs,
  1102. ) -> tuple | QuestionAnsweringModelOutput:
  1103. r"""
  1104. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1105. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1106. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1107. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1108. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1109. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1110. See usage examples detailed in the [multilingual documentation](../multilingual).
  1111. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1112. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1113. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1114. `[0, ..., input_ids.size(-1)]`.
  1115. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1116. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1117. decoding.
  1118. """
  1119. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1120. transformer_outputs = self.transformer(
  1121. input_ids,
  1122. attention_mask=attention_mask,
  1123. langs=langs,
  1124. token_type_ids=token_type_ids,
  1125. position_ids=position_ids,
  1126. lengths=lengths,
  1127. cache=cache,
  1128. inputs_embeds=inputs_embeds,
  1129. output_attentions=output_attentions,
  1130. output_hidden_states=output_hidden_states,
  1131. return_dict=return_dict,
  1132. )
  1133. sequence_output = transformer_outputs[0]
  1134. logits = self.qa_outputs(sequence_output)
  1135. start_logits, end_logits = logits.split(1, dim=-1)
  1136. start_logits = start_logits.squeeze(-1).contiguous()
  1137. end_logits = end_logits.squeeze(-1).contiguous()
  1138. total_loss = None
  1139. if start_positions is not None and end_positions is not None:
  1140. # If we are on multi-GPU, split add a dimension
  1141. if len(start_positions.size()) > 1:
  1142. start_positions = start_positions.squeeze(-1)
  1143. if len(end_positions.size()) > 1:
  1144. end_positions = end_positions.squeeze(-1)
  1145. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1146. ignored_index = start_logits.size(1)
  1147. start_positions = start_positions.clamp(0, ignored_index)
  1148. end_positions = end_positions.clamp(0, ignored_index)
  1149. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1150. start_loss = loss_fct(start_logits, start_positions)
  1151. end_loss = loss_fct(end_logits, end_positions)
  1152. total_loss = (start_loss + end_loss) / 2
  1153. if not return_dict:
  1154. output = (start_logits, end_logits) + transformer_outputs[1:]
  1155. return ((total_loss,) + output) if total_loss is not None else output
  1156. return QuestionAnsweringModelOutput(
  1157. loss=total_loss,
  1158. start_logits=start_logits,
  1159. end_logits=end_logits,
  1160. hidden_states=transformer_outputs.hidden_states,
  1161. attentions=transformer_outputs.attentions,
  1162. )
  # Structured output of `FlaubertForQuestionAnswering.forward` when `return_dict` is truthy; when
  # `return_dict` is falsy that method returns a plain tuple instead.
  # `hidden_states` and `attentions` are not described in the docstring below: forward() copies them
  # straight from the base FlaubertModel output (`transformer_outputs.hidden_states` / `.attentions`).
  # NOTE(review): `ModelOutput` subclasses can usually also be indexed like tuples, so the field order
  # below likely matters to positional callers -- keep it stable.
  @dataclass
  @auto_docstring(
      custom_intro="""
      Base class for outputs of question answering models using a `SquadHead`.
      """
  )
  # Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnsweringOutput with XLM->Flaubert
  class FlaubertForQuestionAnsweringOutput(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
          Classification loss as the sum of start token, end token (and is_impossible if provided) classification
          losses.
      start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Log probabilities for the top config.start_n_top start token possibilities (beam-search).
      start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Indices for the top config.start_n_top start token possibilities (beam-search).
      end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
          (beam-search).
      end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
      cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Log probabilities for the `is_impossible` label of the answers.
      """

      loss: torch.FloatTensor | None = None
      start_top_log_probs: torch.FloatTensor | None = None
      start_top_index: torch.LongTensor | None = None
      end_top_log_probs: torch.FloatTensor | None = None
      end_top_index: torch.LongTensor | None = None
      cls_logits: torch.FloatTensor | None = None
      hidden_states: tuple[torch.FloatTensor] | None = None
      attentions: tuple[torch.FloatTensor] | None = None
  # Question-answering model: the bare FlaubertModel encoder followed by a FlaubertSQuADHead.
  # forward() runs `self.transformer`, then hands its first output to `self.qa_outputs` together with
  # the optional supervision tensors (start/end positions, cls_index, is_impossible, p_mask).
  # With a truthy `return_dict` it returns a FlaubertForQuestionAnsweringOutput. Otherwise it returns
  # the head's outputs concatenated with the encoder's remaining outputs (`transformer_outputs[1:]`).
  # The docstring example loads a Flaubert checkpoint instead of the XLM one used by the XLM original.
  # The extra `FacebookAI/xlm-mlm-en-2048->flaubert/flaubert_base_cased` rule in the `# Copied from`
  # marker records that substitution, so this copy still matches its source after replacement.
  # NOTE(review): `cache` is annotated as a dict but documented as an `EncoderDecoderCache` -- confirm
  # which type FlaubertModel actually expects.
  @auto_docstring
  # Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert,FacebookAI/xlm-mlm-en-2048->flaubert/flaubert_base_cased
  class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.transformer = FlaubertModel(config)
          self.qa_outputs = FlaubertSQuADHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          langs: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          lengths: torch.Tensor | None = None,
          cache: dict[str, torch.Tensor] | None = None,
          inputs_embeds: torch.Tensor | None = None,
          start_positions: torch.Tensor | None = None,
          end_positions: torch.Tensor | None = None,
          is_impossible: torch.Tensor | None = None,
          cls_index: torch.Tensor | None = None,
          p_mask: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | FlaubertForQuestionAnsweringOutput:
          r"""
          langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
              languages ids which can be obtained from the language names by using two conversion mappings provided in
              the configuration of the model (only provided for multilingual models). More precisely, the *language name
              to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
              *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).

              See usage examples detailed in the [multilingual documentation](../multilingual).
          lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Length of each sentence that can be used to avoid performing attention on padding token indices. You can
              also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
              `[0, ..., input_ids.size(-1)]`.
          cache (`dict[str, torch.FloatTensor]`, *optional*):
              Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
              decoding.
          is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels whether a question has an answer or no answer (SQuAD 2.0)
          cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for position (index) of the classification token to use as input for computing plausibility of the
              answer.
          p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means token should be
              masked. 0.0 mean token is not masked.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, FlaubertForQuestionAnswering
          >>> import torch

          >>> tokenizer = AutoTokenizer.from_pretrained("flaubert/flaubert_base_cased")
          >>> model = FlaubertForQuestionAnswering.from_pretrained("flaubert/flaubert_base_cased")

          >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
          ...     0
          ... )  # Batch size 1
          >>> start_positions = torch.tensor([1])
          >>> end_positions = torch.tensor([3])

          >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
          >>> loss = outputs.loss
          ```"""
          return_dict = return_dict if return_dict is not None else self.config.return_dict

          transformer_outputs = self.transformer(
              input_ids,
              attention_mask=attention_mask,
              langs=langs,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              lengths=lengths,
              cache=cache,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          output = transformer_outputs[0]

          outputs = self.qa_outputs(
              output,
              start_positions=start_positions,
              end_positions=end_positions,
              cls_index=cls_index,
              is_impossible=is_impossible,
              p_mask=p_mask,
              return_dict=return_dict,
          )

          if not return_dict:
              return outputs + transformer_outputs[1:]

          return FlaubertForQuestionAnsweringOutput(
              loss=outputs.loss,
              start_top_log_probs=outputs.start_top_log_probs,
              start_top_index=outputs.start_top_index,
              end_top_log_probs=outputs.end_top_log_probs,
              end_top_index=outputs.end_top_index,
              cls_logits=outputs.cls_logits,
              hidden_states=transformer_outputs.hidden_states,
              attentions=transformer_outputs.attentions,
          )
  # Multiple-choice head on top of FlaubertModel.
  # forward() works in five steps:
  #   1. Read `num_choices` from dimension 1 of `input_ids`, or of `inputs_embeds` when `input_ids`
  #      is None.
  #   2. Flatten every per-choice input so the leading dimensions collapse into one: 2-D id/mask
  #      tensors become (-1, seq_len) and `inputs_embeds` becomes (-1, seq_len, hidden).
  #   3. Encode all choices in a single FlaubertModel call.
  #   4. Summarise each sequence with `self.sequence_summary`. Its output feeds
  #      `logits_proj = nn.Linear(config.num_labels, 1)`, so it must end in `config.num_labels`
  #      features. The projection gives one score per choice.
  #   5. Reshape the scores to (-1, num_choices). When `labels` is given, the loss is
  #      CrossEntropyLoss over the choices.
  # NOTE(review): `lengths` is documented as a usable input, but forward() discards it with a warning
  # and asks callers to use the attention mask instead.
  # NOTE(review): one of `input_ids` / `inputs_embeds` must be provided; if both are None, reading
  # `.shape` on None fails before the encoder is reached.
  # NOTE(review): the class body is synchronised from XLMForMultipleChoice through the `# Copied from`
  # marker below, presumably enforced by the repo's copy-consistency tooling. Make changes on the XLM
  # side and re-propagate them rather than editing this copy directly.
  @auto_docstring
  # Copied from transformers.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  class FlaubertForMultipleChoice(FlaubertPreTrainedModel):
      def __init__(self, config, *inputs, **kwargs):
          super().__init__(config, *inputs, **kwargs)

          self.transformer = FlaubertModel(config)
          self.sequence_summary = FlaubertSequenceSummary(config)
          self.logits_proj = nn.Linear(config.num_labels, 1)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          langs: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          lengths: torch.Tensor | None = None,
          cache: dict[str, torch.Tensor] | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | MultipleChoiceModelOutput:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
              Indices of input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)
          langs (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
              languages ids which can be obtained from the language names by using two conversion mappings provided in
              the configuration of the model (only provided for multilingual models). More precisely, the *language name
              to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
              *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).

              See usage examples detailed in the [multilingual documentation](../multilingual).
          token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
              1]`:

              - 0 corresponds to a *sentence A* token,
              - 1 corresponds to a *sentence B* token.

              [What are token type IDs?](../glossary#token-type-ids)
          position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
              config.max_position_embeddings - 1]`.

              [What are position IDs?](../glossary#position-ids)
          lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Length of each sentence that can be used to avoid performing attention on padding token indices. You can
              also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
              `[0, ..., input_ids.size(-1)]`.
          cache (`dict[str, torch.FloatTensor]`, *optional*):
              Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
              decoding.
          inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
              Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
              is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
              model's internal embedding lookup matrix.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
              num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
              `input_ids` above)
          """
          return_dict = return_dict if return_dict is not None else self.config.return_dict
          num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

          input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
          attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
          token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
          position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
          langs = langs.view(-1, langs.size(-1)) if langs is not None else None
          inputs_embeds = (
              inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
              if inputs_embeds is not None
              else None
          )

          if lengths is not None:
              logger.warning(
                  "The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the "
                  "attention mask instead."
              )
              lengths = None

          transformer_outputs = self.transformer(
              input_ids=input_ids,
              attention_mask=attention_mask,
              langs=langs,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              lengths=lengths,
              cache=cache,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          output = transformer_outputs[0]
          logits = self.sequence_summary(output)
          logits = self.logits_proj(logits)
          reshaped_logits = logits.view(-1, num_choices)

          loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(reshaped_logits, labels)

          if not return_dict:
              output = (reshaped_logits,) + transformer_outputs[1:]
              return ((loss,) + output) if loss is not None else output

          return MultipleChoiceModelOutput(
              loss=loss,
              logits=reshaped_logits,
              hidden_states=transformer_outputs.hidden_states,
              attentions=transformer_outputs.attentions,
          )
  # Public API of this module: the names exported by `from <this module> import *`.
  # FlaubertForQuestionAnsweringOutput is defined above but intentionally not listed here.
  # NOTE(review): presumably also read by the library's lazy-import machinery -- keep the entries in
  # sync with the public model classes defined in this file.
  __all__ = [
      "FlaubertForMultipleChoice",
      "FlaubertForQuestionAnswering",
      "FlaubertForQuestionAnsweringSimple",
      "FlaubertForSequenceClassification",
      "FlaubertForTokenClassification",
      "FlaubertModel",
      "FlaubertWithLMHeadModel",
      "FlaubertPreTrainedModel",
  ]