modeling_luke.py 94 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133
  1. # Copyright Studio Ousia and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LUKE model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN, gelu
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  24. from ...modeling_utils import PreTrainedModel
  25. from ...pytorch_utils import apply_chunking_to_forward
  26. from ...utils import ModelOutput, auto_docstring, logging
  27. from .configuration_luke import LukeConfig
# Module-level logger obtained from the library's logging utilities, keyed by this module's name.
logger = logging.get_logger(__name__)
  29. @dataclass
  30. @auto_docstring(
  31. custom_intro="""
  32. Base class for outputs of the LUKE model.
  33. """
  34. )
  35. class BaseLukeModelOutputWithPooling(BaseModelOutputWithPooling):
  36. r"""
  37. pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
  38. Last layer hidden-state of the first token of the sequence (classification token) further processed by a
  39. Linear layer and a Tanh activation function.
  40. entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`):
  41. Sequence of entity hidden-states at the output of the last layer of the model.
  42. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  43. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  44. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  45. layer plus the initial entity embedding outputs.
  46. """
  47. entity_last_hidden_state: torch.FloatTensor | None = None
  48. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  49. @dataclass
  50. @auto_docstring(
  51. custom_intro="""
  52. Base class for model's outputs, with potential hidden states and attentions.
  53. """
  54. )
  55. class BaseLukeModelOutput(BaseModelOutput):
  56. r"""
  57. entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`):
  58. Sequence of entity hidden-states at the output of the last layer of the model.
  59. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  60. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  61. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  62. layer plus the initial entity embedding outputs.
  63. """
  64. entity_last_hidden_state: torch.FloatTensor | None = None
  65. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  66. @dataclass
  67. @auto_docstring(
  68. custom_intro="""
  69. Base class for model's outputs, with potential hidden states and attentions.
  70. """
  71. )
  72. class LukeMaskedLMOutput(ModelOutput):
  73. r"""
  74. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  75. The sum of masked language modeling (MLM) loss and entity prediction loss.
  76. mlm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  77. Masked language modeling (MLM) loss.
  78. mep_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  79. Masked entity prediction (MEP) loss.
  80. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  81. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  82. entity_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  83. Prediction scores of the entity prediction head (scores for each entity vocabulary token before SoftMax).
  84. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  85. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  86. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  87. layer plus the initial entity embedding outputs.
  88. """
  89. loss: torch.FloatTensor | None = None
  90. mlm_loss: torch.FloatTensor | None = None
  91. mep_loss: torch.FloatTensor | None = None
  92. logits: torch.FloatTensor | None = None
  93. entity_logits: torch.FloatTensor | None = None
  94. hidden_states: tuple[torch.FloatTensor] | None = None
  95. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  96. attentions: tuple[torch.FloatTensor, ...] | None = None
  97. @dataclass
  98. @auto_docstring(
  99. custom_intro="""
  100. Outputs of entity classification models.
  101. """
  102. )
  103. class EntityClassificationOutput(ModelOutput):
  104. r"""
  105. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  106. Classification loss.
  107. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  108. Classification scores (before SoftMax).
  109. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  110. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  111. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  112. layer plus the initial entity embedding outputs.
  113. """
  114. loss: torch.FloatTensor | None = None
  115. logits: torch.FloatTensor | None = None
  116. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  117. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  118. attentions: tuple[torch.FloatTensor, ...] | None = None
  119. @dataclass
  120. @auto_docstring(
  121. custom_intro="""
  122. Outputs of entity pair classification models.
  123. """
  124. )
  125. class EntityPairClassificationOutput(ModelOutput):
  126. r"""
  127. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  128. Classification loss.
  129. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  130. Classification scores (before SoftMax).
  131. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  132. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  133. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  134. layer plus the initial entity embedding outputs.
  135. """
  136. loss: torch.FloatTensor | None = None
  137. logits: torch.FloatTensor | None = None
  138. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  139. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  140. attentions: tuple[torch.FloatTensor, ...] | None = None
  141. @dataclass
  142. @auto_docstring(
  143. custom_intro="""
  144. Outputs of entity span classification models.
  145. """
  146. )
  147. class EntitySpanClassificationOutput(ModelOutput):
  148. r"""
  149. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  150. Classification loss.
  151. logits (`torch.FloatTensor` of shape `(batch_size, entity_length, config.num_labels)`):
  152. Classification scores (before SoftMax).
  153. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  154. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  155. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  156. layer plus the initial entity embedding outputs.
  157. """
  158. loss: torch.FloatTensor | None = None
  159. logits: torch.FloatTensor | None = None
  160. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  161. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  162. attentions: tuple[torch.FloatTensor, ...] | None = None
  163. @dataclass
  164. @auto_docstring(
  165. custom_intro="""
  166. Outputs of sentence classification models.
  167. """
  168. )
  169. class LukeSequenceClassifierOutput(ModelOutput):
  170. r"""
  171. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  172. Classification (or regression if config.num_labels==1) loss.
  173. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  174. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  175. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  176. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  177. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  178. layer plus the initial entity embedding outputs.
  179. """
  180. loss: torch.FloatTensor | None = None
  181. logits: torch.FloatTensor | None = None
  182. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  183. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  184. attentions: tuple[torch.FloatTensor, ...] | None = None
  185. @dataclass
  186. @auto_docstring(
  187. custom_intro="""
  188. Base class for outputs of token classification models.
  189. """
  190. )
  191. class LukeTokenClassifierOutput(ModelOutput):
  192. r"""
  193. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  194. Classification loss.
  195. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
  196. Classification scores (before SoftMax).
  197. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  198. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  199. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  200. layer plus the initial entity embedding outputs.
  201. """
  202. loss: torch.FloatTensor | None = None
  203. logits: torch.FloatTensor | None = None
  204. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  205. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  206. attentions: tuple[torch.FloatTensor, ...] | None = None
  207. @dataclass
  208. @auto_docstring(
  209. custom_intro="""
  210. Outputs of question answering models.
  211. """
  212. )
  213. class LukeQuestionAnsweringModelOutput(ModelOutput):
  214. r"""
  215. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  216. Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
  217. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  218. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  219. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  220. layer plus the initial entity embedding outputs.
  221. """
  222. loss: torch.FloatTensor | None = None
  223. start_logits: torch.FloatTensor | None = None
  224. end_logits: torch.FloatTensor | None = None
  225. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  226. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  227. attentions: tuple[torch.FloatTensor, ...] | None = None
  228. @dataclass
  229. @auto_docstring(
  230. custom_intro="""
  231. Outputs of multiple choice models.
  232. """
  233. )
  234. class LukeMultipleChoiceModelOutput(ModelOutput):
  235. r"""
  236. loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
  237. Classification loss.
  238. logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
  239. *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
  240. Classification scores (before SoftMax).
  241. entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  242. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  243. shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
  244. layer plus the initial entity embedding outputs.
  245. """
  246. loss: torch.FloatTensor | None = None
  247. logits: torch.FloatTensor | None = None
  248. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  249. entity_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  250. attentions: tuple[torch.FloatTensor, ...] | None = None
  251. class LukeEmbeddings(nn.Module):
  252. """
  253. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  254. """
  255. def __init__(self, config):
  256. super().__init__()
  257. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  258. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  259. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  260. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  261. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  262. # End copy
  263. self.padding_idx = config.pad_token_id
  264. self.position_embeddings = nn.Embedding(
  265. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  266. )
  267. def forward(
  268. self,
  269. input_ids=None,
  270. token_type_ids=None,
  271. position_ids=None,
  272. inputs_embeds=None,
  273. ):
  274. if position_ids is None:
  275. if input_ids is not None:
  276. # Create the position ids from the input token ids. Any padded tokens remain padded.
  277. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
  278. else:
  279. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  280. if input_ids is not None:
  281. input_shape = input_ids.size()
  282. else:
  283. input_shape = inputs_embeds.size()[:-1]
  284. if token_type_ids is None:
  285. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  286. if inputs_embeds is None:
  287. inputs_embeds = self.word_embeddings(input_ids)
  288. position_embeddings = self.position_embeddings(position_ids)
  289. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  290. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  291. embeddings = self.LayerNorm(embeddings)
  292. embeddings = self.dropout(embeddings)
  293. return embeddings
  294. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  295. """
  296. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  297. Args:
  298. inputs_embeds: torch.Tensor
  299. Returns: torch.Tensor
  300. """
  301. input_shape = inputs_embeds.size()[:-1]
  302. sequence_length = input_shape[1]
  303. position_ids = torch.arange(
  304. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  305. )
  306. return position_ids.unsqueeze(0).expand(input_shape)
  307. class LukeEntityEmbeddings(nn.Module):
  308. def __init__(self, config: LukeConfig):
  309. super().__init__()
  310. self.config = config
  311. self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0)
  312. if config.entity_emb_size != config.hidden_size:
  313. self.entity_embedding_dense = nn.Linear(config.entity_emb_size, config.hidden_size, bias=False)
  314. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  315. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  316. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  317. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  318. def forward(
  319. self,
  320. entity_ids: torch.LongTensor,
  321. position_ids: torch.LongTensor,
  322. token_type_ids: torch.LongTensor | None = None,
  323. ):
  324. if token_type_ids is None:
  325. token_type_ids = torch.zeros_like(entity_ids)
  326. entity_embeddings = self.entity_embeddings(entity_ids)
  327. if self.config.entity_emb_size != self.config.hidden_size:
  328. entity_embeddings = self.entity_embedding_dense(entity_embeddings)
  329. position_embeddings = self.position_embeddings(position_ids.clamp(min=0))
  330. position_embedding_mask = (position_ids != -1).type_as(position_embeddings).unsqueeze(-1)
  331. position_embeddings = position_embeddings * position_embedding_mask
  332. position_embeddings = torch.sum(position_embeddings, dim=-2)
  333. position_embeddings = position_embeddings / position_embedding_mask.sum(dim=-2).clamp(min=1e-7)
  334. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  335. embeddings = entity_embeddings + position_embeddings + token_type_embeddings
  336. embeddings = self.LayerNorm(embeddings)
  337. embeddings = self.dropout(embeddings)
  338. return embeddings
  339. class LukeSelfAttention(nn.Module):
  340. def __init__(self, config):
  341. super().__init__()
  342. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  343. raise ValueError(
  344. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  345. f"heads {config.num_attention_heads}."
  346. )
  347. self.num_attention_heads = config.num_attention_heads
  348. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  349. self.all_head_size = self.num_attention_heads * self.attention_head_size
  350. self.use_entity_aware_attention = config.use_entity_aware_attention
  351. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  352. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  353. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  354. if self.use_entity_aware_attention:
  355. self.w2e_query = nn.Linear(config.hidden_size, self.all_head_size)
  356. self.e2w_query = nn.Linear(config.hidden_size, self.all_head_size)
  357. self.e2e_query = nn.Linear(config.hidden_size, self.all_head_size)
  358. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  359. def transpose_for_scores(self, x):
  360. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  361. x = x.view(*new_x_shape)
  362. return x.permute(0, 2, 1, 3)
  363. def forward(
  364. self,
  365. word_hidden_states,
  366. entity_hidden_states,
  367. attention_mask=None,
  368. output_attentions=False,
  369. ):
  370. word_size = word_hidden_states.size(1)
  371. if entity_hidden_states is None:
  372. concat_hidden_states = word_hidden_states
  373. else:
  374. concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1)
  375. key_layer = self.transpose_for_scores(self.key(concat_hidden_states))
  376. value_layer = self.transpose_for_scores(self.value(concat_hidden_states))
  377. if self.use_entity_aware_attention and entity_hidden_states is not None:
  378. # compute query vectors using word-word (w2w), word-entity (w2e), entity-word (e2w), entity-entity (e2e)
  379. # query layers
  380. w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states))
  381. w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states))
  382. e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states))
  383. e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states))
  384. # compute w2w, w2e, e2w, and e2e key vectors used with the query vectors computed above
  385. w2w_key_layer = key_layer[:, :, :word_size, :]
  386. e2w_key_layer = key_layer[:, :, :word_size, :]
  387. w2e_key_layer = key_layer[:, :, word_size:, :]
  388. e2e_key_layer = key_layer[:, :, word_size:, :]
  389. # compute attention scores based on the dot product between the query and key vectors
  390. w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2))
  391. w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2))
  392. e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2))
  393. e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2))
  394. # combine attention scores to create the final attention score matrix
  395. word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3)
  396. entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3)
  397. attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2)
  398. else:
  399. query_layer = self.transpose_for_scores(self.query(concat_hidden_states))
  400. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  401. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  402. if attention_mask is not None:
  403. # Apply the attention mask is (precomputed for all layers in LukeModel forward() function)
  404. attention_scores = attention_scores + attention_mask
  405. # Normalize the attention scores to probabilities.
  406. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  407. # This is actually dropping out entire tokens to attend to, which might
  408. # seem a bit unusual, but is taken from the original Transformer paper.
  409. attention_probs = self.dropout(attention_probs)
  410. context_layer = torch.matmul(attention_probs, value_layer)
  411. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  412. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  413. context_layer = context_layer.view(*new_context_layer_shape)
  414. output_word_hidden_states = context_layer[:, :word_size, :]
  415. if entity_hidden_states is None:
  416. output_entity_hidden_states = None
  417. else:
  418. output_entity_hidden_states = context_layer[:, word_size:, :]
  419. if output_attentions:
  420. outputs = (output_word_hidden_states, output_entity_hidden_states, attention_probs)
  421. else:
  422. outputs = (output_word_hidden_states, output_entity_hidden_states)
  423. return outputs
  424. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  425. class LukeSelfOutput(nn.Module):
  426. def __init__(self, config):
  427. super().__init__()
  428. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  429. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  430. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  431. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  432. hidden_states = self.dense(hidden_states)
  433. hidden_states = self.dropout(hidden_states)
  434. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  435. return hidden_states
  436. class LukeAttention(nn.Module):
  437. def __init__(self, config):
  438. super().__init__()
  439. self.self = LukeSelfAttention(config)
  440. self.output = LukeSelfOutput(config)
  441. def forward(
  442. self,
  443. word_hidden_states,
  444. entity_hidden_states,
  445. attention_mask=None,
  446. output_attentions=False,
  447. ):
  448. word_size = word_hidden_states.size(1)
  449. self_outputs = self.self(
  450. word_hidden_states,
  451. entity_hidden_states,
  452. attention_mask,
  453. output_attentions,
  454. )
  455. if entity_hidden_states is None:
  456. concat_self_outputs = self_outputs[0]
  457. concat_hidden_states = word_hidden_states
  458. else:
  459. concat_self_outputs = torch.cat(self_outputs[:2], dim=1)
  460. concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1)
  461. attention_output = self.output(concat_self_outputs, concat_hidden_states)
  462. word_attention_output = attention_output[:, :word_size, :]
  463. if entity_hidden_states is None:
  464. entity_attention_output = None
  465. else:
  466. entity_attention_output = attention_output[:, word_size:, :]
  467. # add attentions if we output them
  468. outputs = (word_attention_output, entity_attention_output) + self_outputs[2:]
  469. return outputs
  470. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  471. class LukeIntermediate(nn.Module):
  472. def __init__(self, config):
  473. super().__init__()
  474. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  475. if isinstance(config.hidden_act, str):
  476. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  477. else:
  478. self.intermediate_act_fn = config.hidden_act
  479. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  480. hidden_states = self.dense(hidden_states)
  481. hidden_states = self.intermediate_act_fn(hidden_states)
  482. return hidden_states
  483. # Copied from transformers.models.bert.modeling_bert.BertOutput
  484. class LukeOutput(nn.Module):
  485. def __init__(self, config):
  486. super().__init__()
  487. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  488. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  489. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  490. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  491. hidden_states = self.dense(hidden_states)
  492. hidden_states = self.dropout(hidden_states)
  493. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  494. return hidden_states
  495. class LukeLayer(GradientCheckpointingLayer):
  496. def __init__(self, config):
  497. super().__init__()
  498. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  499. self.seq_len_dim = 1
  500. self.attention = LukeAttention(config)
  501. self.intermediate = LukeIntermediate(config)
  502. self.output = LukeOutput(config)
  503. def forward(
  504. self,
  505. word_hidden_states,
  506. entity_hidden_states,
  507. attention_mask=None,
  508. output_attentions=False,
  509. ):
  510. word_size = word_hidden_states.size(1)
  511. self_attention_outputs = self.attention(
  512. word_hidden_states,
  513. entity_hidden_states,
  514. attention_mask,
  515. output_attentions=output_attentions,
  516. )
  517. if entity_hidden_states is None:
  518. concat_attention_output = self_attention_outputs[0]
  519. else:
  520. concat_attention_output = torch.cat(self_attention_outputs[:2], dim=1)
  521. outputs = self_attention_outputs[2:] # add self attentions if we output attention weights
  522. layer_output = apply_chunking_to_forward(
  523. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, concat_attention_output
  524. )
  525. word_layer_output = layer_output[:, :word_size, :]
  526. if entity_hidden_states is None:
  527. entity_layer_output = None
  528. else:
  529. entity_layer_output = layer_output[:, word_size:, :]
  530. outputs = (word_layer_output, entity_layer_output) + outputs
  531. return outputs
  532. def feed_forward_chunk(self, attention_output):
  533. intermediate_output = self.intermediate(attention_output)
  534. layer_output = self.output(intermediate_output, attention_output)
  535. return layer_output
  536. class LukeEncoder(nn.Module):
  537. def __init__(self, config):
  538. super().__init__()
  539. self.config = config
  540. self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)])
  541. self.gradient_checkpointing = False
  542. def forward(
  543. self,
  544. word_hidden_states,
  545. entity_hidden_states,
  546. attention_mask=None,
  547. output_attentions=False,
  548. output_hidden_states=False,
  549. return_dict=True,
  550. ):
  551. all_word_hidden_states = () if output_hidden_states else None
  552. all_entity_hidden_states = () if output_hidden_states else None
  553. all_self_attentions = () if output_attentions else None
  554. for i, layer_module in enumerate(self.layer):
  555. if output_hidden_states:
  556. all_word_hidden_states = all_word_hidden_states + (word_hidden_states,)
  557. all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,)
  558. layer_outputs = layer_module(
  559. word_hidden_states,
  560. entity_hidden_states,
  561. attention_mask,
  562. output_attentions,
  563. )
  564. word_hidden_states = layer_outputs[0]
  565. if entity_hidden_states is not None:
  566. entity_hidden_states = layer_outputs[1]
  567. if output_attentions:
  568. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  569. if output_hidden_states:
  570. all_word_hidden_states = all_word_hidden_states + (word_hidden_states,)
  571. all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,)
  572. if not return_dict:
  573. return tuple(
  574. v
  575. for v in [
  576. word_hidden_states,
  577. all_word_hidden_states,
  578. all_self_attentions,
  579. entity_hidden_states,
  580. all_entity_hidden_states,
  581. ]
  582. if v is not None
  583. )
  584. return BaseLukeModelOutput(
  585. last_hidden_state=word_hidden_states,
  586. hidden_states=all_word_hidden_states,
  587. attentions=all_self_attentions,
  588. entity_last_hidden_state=entity_hidden_states,
  589. entity_hidden_states=all_entity_hidden_states,
  590. )
  591. # Copied from transformers.models.bert.modeling_bert.BertPooler
  592. class LukePooler(nn.Module):
  593. def __init__(self, config):
  594. super().__init__()
  595. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  596. self.activation = nn.Tanh()
  597. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  598. # We "pool" the model by simply taking the hidden state corresponding
  599. # to the first token.
  600. first_token_tensor = hidden_states[:, 0]
  601. pooled_output = self.dense(first_token_tensor)
  602. pooled_output = self.activation(pooled_output)
  603. return pooled_output
  604. class EntityPredictionHeadTransform(nn.Module):
  605. def __init__(self, config):
  606. super().__init__()
  607. self.dense = nn.Linear(config.hidden_size, config.entity_emb_size)
  608. if isinstance(config.hidden_act, str):
  609. self.transform_act_fn = ACT2FN[config.hidden_act]
  610. else:
  611. self.transform_act_fn = config.hidden_act
  612. self.LayerNorm = nn.LayerNorm(config.entity_emb_size, eps=config.layer_norm_eps)
  613. def forward(self, hidden_states):
  614. hidden_states = self.dense(hidden_states)
  615. hidden_states = self.transform_act_fn(hidden_states)
  616. hidden_states = self.LayerNorm(hidden_states)
  617. return hidden_states
  618. class EntityPredictionHead(nn.Module):
  619. def __init__(self, config):
  620. super().__init__()
  621. self.config = config
  622. self.transform = EntityPredictionHeadTransform(config)
  623. self.decoder = nn.Linear(config.entity_emb_size, config.entity_vocab_size, bias=False)
  624. self.bias = nn.Parameter(torch.zeros(config.entity_vocab_size))
  625. def forward(self, hidden_states):
  626. hidden_states = self.transform(hidden_states)
  627. hidden_states = self.decoder(hidden_states) + self.bias
  628. return hidden_states
  629. @auto_docstring
  630. class LukePreTrainedModel(PreTrainedModel):
  631. config: LukeConfig
  632. base_model_prefix = "luke"
  633. supports_gradient_checkpointing = True
  634. _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
  635. @torch.no_grad()
  636. def _init_weights(self, module: nn.Module):
  637. """Initialize the weights"""
  638. if isinstance(module, nn.Linear):
  639. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  640. if module.bias is not None:
  641. init.zeros_(module.bias)
  642. elif isinstance(module, nn.Embedding):
  643. if module.embedding_dim == 1: # embedding for bias parameters
  644. init.zeros_(module.weight)
  645. else:
  646. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  647. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  648. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  649. init.zeros_(module.weight[module.padding_idx])
  650. elif isinstance(module, nn.LayerNorm):
  651. init.zeros_(module.bias)
  652. init.ones_(module.weight)
  653. @auto_docstring(
  654. custom_intro="""
  655. The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any
  656. """
  657. )
  658. class LukeModel(LukePreTrainedModel):
  659. def __init__(self, config: LukeConfig, add_pooling_layer: bool = True):
  660. r"""
  661. add_pooling_layer (bool, *optional*, defaults to `True`):
  662. Whether to add a pooling layer
  663. """
  664. super().__init__(config)
  665. self.config = config
  666. self.embeddings = LukeEmbeddings(config)
  667. self.entity_embeddings = LukeEntityEmbeddings(config)
  668. self.encoder = LukeEncoder(config)
  669. self.pooler = LukePooler(config) if add_pooling_layer else None
  670. # Initialize weights and apply final processing
  671. self.post_init()
  672. def get_input_embeddings(self):
  673. return self.embeddings.word_embeddings
  674. def set_input_embeddings(self, value):
  675. self.embeddings.word_embeddings = value
  676. def get_entity_embeddings(self):
  677. return self.entity_embeddings.entity_embeddings
  678. def set_entity_embeddings(self, value):
  679. self.entity_embeddings.entity_embeddings = value
  680. @auto_docstring
  681. def forward(
  682. self,
  683. input_ids: torch.LongTensor | None = None,
  684. attention_mask: torch.FloatTensor | None = None,
  685. token_type_ids: torch.LongTensor | None = None,
  686. position_ids: torch.LongTensor | None = None,
  687. entity_ids: torch.LongTensor | None = None,
  688. entity_attention_mask: torch.FloatTensor | None = None,
  689. entity_token_type_ids: torch.LongTensor | None = None,
  690. entity_position_ids: torch.LongTensor | None = None,
  691. inputs_embeds: torch.FloatTensor | None = None,
  692. output_attentions: bool | None = None,
  693. output_hidden_states: bool | None = None,
  694. return_dict: bool | None = None,
  695. **kwargs,
  696. ) -> tuple | BaseLukeModelOutputWithPooling:
  697. r"""
  698. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  699. Indices of entity tokens in the entity vocabulary.
  700. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  701. [`PreTrainedTokenizer.__call__`] for details.
  702. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  703. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  704. - 1 for entity tokens that are **not masked**,
  705. - 0 for entity tokens that are **masked**.
  706. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  707. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  708. selected in `[0, 1]`:
  709. - 0 corresponds to a *portion A* entity token,
  710. - 1 corresponds to a *portion B* entity token.
  711. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  712. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  713. config.max_position_embeddings - 1]`.
  714. Examples:
  715. ```python
  716. >>> from transformers import AutoTokenizer, LukeModel
  717. >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
  718. >>> model = LukeModel.from_pretrained("studio-ousia/luke-base")
  719. # Compute the contextualized entity representation corresponding to the entity mention "Beyoncé"
  720. >>> text = "Beyoncé lives in Los Angeles."
  721. >>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé"
  722. >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
  723. >>> outputs = model(**encoding)
  724. >>> word_last_hidden_state = outputs.last_hidden_state
  725. >>> entity_last_hidden_state = outputs.entity_last_hidden_state
  726. # Input Wikipedia entities to obtain enriched contextualized representations of word tokens
  727. >>> text = "Beyoncé lives in Los Angeles."
  728. >>> entities = [
  729. ... "Beyoncé",
  730. ... "Los Angeles",
  731. ... ] # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "Los Angeles"
  732. >>> entity_spans = [
  733. ... (0, 7),
  734. ... (17, 28),
  735. ... ] # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
  736. >>> encoding = tokenizer(
  737. ... text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt"
  738. ... )
  739. >>> outputs = model(**encoding)
  740. >>> word_last_hidden_state = outputs.last_hidden_state
  741. >>> entity_last_hidden_state = outputs.entity_last_hidden_state
  742. ```"""
  743. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  744. output_hidden_states = (
  745. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  746. )
  747. return_dict = return_dict if return_dict is not None else self.config.return_dict
  748. if input_ids is not None and inputs_embeds is not None:
  749. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  750. elif input_ids is not None:
  751. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  752. input_shape = input_ids.size()
  753. elif inputs_embeds is not None:
  754. input_shape = inputs_embeds.size()[:-1]
  755. else:
  756. raise ValueError("You have to specify either input_ids or inputs_embeds")
  757. batch_size, seq_length = input_shape
  758. device = input_ids.device if input_ids is not None else inputs_embeds.device
  759. if attention_mask is None:
  760. attention_mask = torch.ones((batch_size, seq_length), device=device)
  761. if token_type_ids is None:
  762. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  763. if entity_ids is not None:
  764. entity_seq_length = entity_ids.size(1)
  765. if entity_attention_mask is None:
  766. entity_attention_mask = torch.ones((batch_size, entity_seq_length), device=device)
  767. if entity_token_type_ids is None:
  768. entity_token_type_ids = torch.zeros((batch_size, entity_seq_length), dtype=torch.long, device=device)
  769. # First, compute word embeddings
  770. word_embedding_output = self.embeddings(
  771. input_ids=input_ids,
  772. position_ids=position_ids,
  773. token_type_ids=token_type_ids,
  774. inputs_embeds=inputs_embeds,
  775. )
  776. # Second, compute extended attention mask
  777. extended_attention_mask = self.get_extended_attention_mask(attention_mask, entity_attention_mask)
  778. # Third, compute entity embeddings and concatenate with word embeddings
  779. if entity_ids is None:
  780. entity_embedding_output = None
  781. else:
  782. entity_embedding_output = self.entity_embeddings(entity_ids, entity_position_ids, entity_token_type_ids)
  783. # Fourth, send embeddings through the model
  784. encoder_outputs = self.encoder(
  785. word_embedding_output,
  786. entity_embedding_output,
  787. attention_mask=extended_attention_mask,
  788. output_attentions=output_attentions,
  789. output_hidden_states=output_hidden_states,
  790. return_dict=return_dict,
  791. )
  792. # Fifth, get the output. LukeModel outputs the same as BertModel, namely sequence_output of shape (batch_size, seq_len, hidden_size)
  793. sequence_output = encoder_outputs[0]
  794. # Sixth, we compute the pooled_output, word_sequence_output and entity_sequence_output based on the sequence_output
  795. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  796. if not return_dict:
  797. return (sequence_output, pooled_output) + encoder_outputs[1:]
  798. return BaseLukeModelOutputWithPooling(
  799. last_hidden_state=sequence_output,
  800. pooler_output=pooled_output,
  801. hidden_states=encoder_outputs.hidden_states,
  802. attentions=encoder_outputs.attentions,
  803. entity_last_hidden_state=encoder_outputs.entity_last_hidden_state,
  804. entity_hidden_states=encoder_outputs.entity_hidden_states,
  805. )
  806. def get_extended_attention_mask(
  807. self, word_attention_mask: torch.LongTensor, entity_attention_mask: torch.LongTensor | None
  808. ):
  809. """
  810. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  811. Arguments:
  812. word_attention_mask (`torch.LongTensor`):
  813. Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
  814. entity_attention_mask (`torch.LongTensor`, *optional*):
  815. Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
  816. Returns:
  817. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  818. """
  819. attention_mask = word_attention_mask
  820. if entity_attention_mask is not None:
  821. attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=-1)
  822. if attention_mask.dim() == 3:
  823. extended_attention_mask = attention_mask[:, None, :, :]
  824. elif attention_mask.dim() == 2:
  825. extended_attention_mask = attention_mask[:, None, None, :]
  826. else:
  827. raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")
  828. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  829. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  830. return extended_attention_mask
  831. def create_position_ids_from_input_ids(input_ids, padding_idx):
  832. """
  833. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  834. are ignored. This is modified from fairseq's `utils.make_positions`.
  835. Args:
  836. x: torch.Tensor x:
  837. Returns: torch.Tensor
  838. """
  839. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  840. mask = input_ids.ne(padding_idx).int()
  841. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
  842. return incremental_indices.long() + padding_idx
  843. # Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead
  844. class LukeLMHead(nn.Module):
  845. """Roberta Head for masked language modeling."""
  846. def __init__(self, config):
  847. super().__init__()
  848. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  849. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  850. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  851. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  852. def forward(self, features, **kwargs):
  853. x = self.dense(features)
  854. x = gelu(x)
  855. x = self.layer_norm(x)
  856. # project back to size of vocabulary with bias
  857. x = self.decoder(x)
  858. return x
  859. @auto_docstring(
  860. custom_intro="""
  861. The LUKE model with a language modeling head and entity prediction head on top for masked language modeling and
  862. masked entity prediction.
  863. """
  864. )
  865. class LukeForMaskedLM(LukePreTrainedModel):
  866. _tied_weights_keys = {
  867. "entity_predictions.decoder.weight": "luke.entity_embeddings.entity_embeddings.weight",
  868. "lm_head.bias": "lm_head.decoder.bias",
  869. }
  870. def __init__(self, config):
  871. super().__init__(config)
  872. self.luke = LukeModel(config)
  873. self.lm_head = LukeLMHead(config)
  874. self.entity_predictions = EntityPredictionHead(config)
  875. self.loss_fn = nn.CrossEntropyLoss()
  876. # Initialize weights and apply final processing
  877. self.post_init()
  878. def get_output_embeddings(self):
  879. return self.lm_head.decoder
  880. def set_output_embeddings(self, new_embeddings):
  881. self.lm_head.decoder = new_embeddings
  882. @auto_docstring
  883. def forward(
  884. self,
  885. input_ids: torch.LongTensor | None = None,
  886. attention_mask: torch.FloatTensor | None = None,
  887. token_type_ids: torch.LongTensor | None = None,
  888. position_ids: torch.LongTensor | None = None,
  889. entity_ids: torch.LongTensor | None = None,
  890. entity_attention_mask: torch.LongTensor | None = None,
  891. entity_token_type_ids: torch.LongTensor | None = None,
  892. entity_position_ids: torch.LongTensor | None = None,
  893. labels: torch.LongTensor | None = None,
  894. entity_labels: torch.LongTensor | None = None,
  895. inputs_embeds: torch.FloatTensor | None = None,
  896. output_attentions: bool | None = None,
  897. output_hidden_states: bool | None = None,
  898. return_dict: bool | None = None,
  899. **kwargs,
  900. ) -> tuple | LukeMaskedLMOutput:
  901. r"""
  902. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  903. Indices of entity tokens in the entity vocabulary.
  904. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  905. [`PreTrainedTokenizer.__call__`] for details.
  906. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  907. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  908. - 1 for entity tokens that are **not masked**,
  909. - 0 for entity tokens that are **masked**.
  910. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  911. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  912. selected in `[0, 1]`:
  913. - 0 corresponds to a *portion A* entity token,
  914. - 1 corresponds to a *portion B* entity token.
  915. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  916. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  917. config.max_position_embeddings - 1]`.
  918. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  919. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  920. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  921. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  922. entity_labels (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  923. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  924. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  925. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  926. """
  927. return_dict = return_dict if return_dict is not None else self.config.return_dict
  928. outputs = self.luke(
  929. input_ids=input_ids,
  930. attention_mask=attention_mask,
  931. token_type_ids=token_type_ids,
  932. position_ids=position_ids,
  933. entity_ids=entity_ids,
  934. entity_attention_mask=entity_attention_mask,
  935. entity_token_type_ids=entity_token_type_ids,
  936. entity_position_ids=entity_position_ids,
  937. inputs_embeds=inputs_embeds,
  938. output_attentions=output_attentions,
  939. output_hidden_states=output_hidden_states,
  940. return_dict=True,
  941. )
  942. loss = None
  943. mlm_loss = None
  944. logits = self.lm_head(outputs.last_hidden_state)
  945. if labels is not None:
  946. # move labels to correct device
  947. labels = labels.to(logits.device)
  948. mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
  949. if loss is None:
  950. loss = mlm_loss
  951. mep_loss = None
  952. entity_logits = None
  953. if outputs.entity_last_hidden_state is not None:
  954. entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
  955. if entity_labels is not None:
  956. mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
  957. if loss is None:
  958. loss = mep_loss
  959. else:
  960. loss = loss + mep_loss
  961. if not return_dict:
  962. return tuple(
  963. v
  964. for v in [
  965. loss,
  966. mlm_loss,
  967. mep_loss,
  968. logits,
  969. entity_logits,
  970. outputs.hidden_states,
  971. outputs.entity_hidden_states,
  972. outputs.attentions,
  973. ]
  974. if v is not None
  975. )
  976. return LukeMaskedLMOutput(
  977. loss=loss,
  978. mlm_loss=mlm_loss,
  979. mep_loss=mep_loss,
  980. logits=logits,
  981. entity_logits=entity_logits,
  982. hidden_states=outputs.hidden_states,
  983. entity_hidden_states=outputs.entity_hidden_states,
  984. attentions=outputs.attentions,
  985. )
  986. @auto_docstring(
  987. custom_intro="""
  988. The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity
  989. token) for entity classification tasks, such as Open Entity.
  990. """
  991. )
  992. class LukeForEntityClassification(LukePreTrainedModel):
  993. def __init__(self, config):
  994. super().__init__(config)
  995. self.luke = LukeModel(config)
  996. self.num_labels = config.num_labels
  997. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  998. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  999. # Initialize weights and apply final processing
  1000. self.post_init()
  1001. @auto_docstring
  1002. def forward(
  1003. self,
  1004. input_ids: torch.LongTensor | None = None,
  1005. attention_mask: torch.FloatTensor | None = None,
  1006. token_type_ids: torch.LongTensor | None = None,
  1007. position_ids: torch.LongTensor | None = None,
  1008. entity_ids: torch.LongTensor | None = None,
  1009. entity_attention_mask: torch.FloatTensor | None = None,
  1010. entity_token_type_ids: torch.LongTensor | None = None,
  1011. entity_position_ids: torch.LongTensor | None = None,
  1012. inputs_embeds: torch.FloatTensor | None = None,
  1013. labels: torch.FloatTensor | None = None,
  1014. output_attentions: bool | None = None,
  1015. output_hidden_states: bool | None = None,
  1016. return_dict: bool | None = None,
  1017. **kwargs,
  1018. ) -> tuple | EntityClassificationOutput:
  1019. r"""
  1020. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1021. Indices of entity tokens in the entity vocabulary.
  1022. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1023. [`PreTrainedTokenizer.__call__`] for details.
  1024. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1025. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1026. - 1 for entity tokens that are **not masked**,
  1027. - 0 for entity tokens that are **masked**.
  1028. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1029. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1030. selected in `[0, 1]`:
  1031. - 0 corresponds to a *portion A* entity token,
  1032. - 1 corresponds to a *portion B* entity token.
  1033. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1034. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1035. config.max_position_embeddings - 1]`.
  1036. labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
  1037. Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
  1038. used for the single-label classification. In this case, labels should contain the indices that should be in
  1039. `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy
  1040. loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0
  1041. and 1 indicate false and true, respectively.
  1042. Examples:
  1043. ```python
  1044. >>> from transformers import AutoTokenizer, LukeForEntityClassification
  1045. >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
  1046. >>> model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
  1047. >>> text = "Beyoncé lives in Los Angeles."
  1048. >>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé"
  1049. >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
  1050. >>> outputs = model(**inputs)
  1051. >>> logits = outputs.logits
  1052. >>> predicted_class_idx = logits.argmax(-1).item()
  1053. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1054. Predicted class: person
  1055. ```"""
  1056. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1057. outputs = self.luke(
  1058. input_ids=input_ids,
  1059. attention_mask=attention_mask,
  1060. token_type_ids=token_type_ids,
  1061. position_ids=position_ids,
  1062. entity_ids=entity_ids,
  1063. entity_attention_mask=entity_attention_mask,
  1064. entity_token_type_ids=entity_token_type_ids,
  1065. entity_position_ids=entity_position_ids,
  1066. inputs_embeds=inputs_embeds,
  1067. output_attentions=output_attentions,
  1068. output_hidden_states=output_hidden_states,
  1069. return_dict=True,
  1070. )
  1071. feature_vector = outputs.entity_last_hidden_state[:, 0, :]
  1072. feature_vector = self.dropout(feature_vector)
  1073. logits = self.classifier(feature_vector)
  1074. loss = None
  1075. if labels is not None:
  1076. # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
  1077. # cross entropy is used otherwise.
  1078. # move labels to correct device
  1079. labels = labels.to(logits.device)
  1080. if labels.ndim == 1:
  1081. loss = nn.functional.cross_entropy(logits, labels)
  1082. else:
  1083. loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
  1084. if not return_dict:
  1085. return tuple(
  1086. v
  1087. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1088. if v is not None
  1089. )
  1090. return EntityClassificationOutput(
  1091. loss=loss,
  1092. logits=logits,
  1093. hidden_states=outputs.hidden_states,
  1094. entity_hidden_states=outputs.entity_hidden_states,
  1095. attentions=outputs.attentions,
  1096. )
  1097. @auto_docstring(
  1098. custom_intro="""
  1099. The LUKE model with a classification head on top (a linear layer on top of the hidden states of the two entity
  1100. tokens) for entity pair classification tasks, such as TACRED.
  1101. """
  1102. )
  1103. class LukeForEntityPairClassification(LukePreTrainedModel):
  1104. def __init__(self, config):
  1105. super().__init__(config)
  1106. self.luke = LukeModel(config)
  1107. self.num_labels = config.num_labels
  1108. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1109. self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels, False)
  1110. # Initialize weights and apply final processing
  1111. self.post_init()
  1112. @auto_docstring
  1113. def forward(
  1114. self,
  1115. input_ids: torch.LongTensor | None = None,
  1116. attention_mask: torch.FloatTensor | None = None,
  1117. token_type_ids: torch.LongTensor | None = None,
  1118. position_ids: torch.LongTensor | None = None,
  1119. entity_ids: torch.LongTensor | None = None,
  1120. entity_attention_mask: torch.FloatTensor | None = None,
  1121. entity_token_type_ids: torch.LongTensor | None = None,
  1122. entity_position_ids: torch.LongTensor | None = None,
  1123. inputs_embeds: torch.FloatTensor | None = None,
  1124. labels: torch.LongTensor | None = None,
  1125. output_attentions: bool | None = None,
  1126. output_hidden_states: bool | None = None,
  1127. return_dict: bool | None = None,
  1128. **kwargs,
  1129. ) -> tuple | EntityPairClassificationOutput:
  1130. r"""
  1131. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1132. Indices of entity tokens in the entity vocabulary.
  1133. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1134. [`PreTrainedTokenizer.__call__`] for details.
  1135. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1136. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1137. - 1 for entity tokens that are **not masked**,
  1138. - 0 for entity tokens that are **masked**.
  1139. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1140. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1141. selected in `[0, 1]`:
  1142. - 0 corresponds to a *portion A* entity token,
  1143. - 1 corresponds to a *portion B* entity token.
  1144. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1145. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1146. config.max_position_embeddings - 1]`.
  1147. labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
  1148. Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
  1149. used for the single-label classification. In this case, labels should contain the indices that should be in
  1150. `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy
  1151. loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0
  1152. and 1 indicate false and true, respectively.
  1153. Examples:
  1154. ```python
  1155. >>> from transformers import AutoTokenizer, LukeForEntityPairClassification
  1156. >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
  1157. >>> model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
  1158. >>> text = "Beyoncé lives in Los Angeles."
  1159. >>> entity_spans = [
  1160. ... (0, 7),
  1161. ... (17, 28),
  1162. ... ] # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
  1163. >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
  1164. >>> outputs = model(**inputs)
  1165. >>> logits = outputs.logits
  1166. >>> predicted_class_idx = logits.argmax(-1).item()
  1167. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1168. Predicted class: per:cities_of_residence
  1169. ```"""
  1170. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1171. outputs = self.luke(
  1172. input_ids=input_ids,
  1173. attention_mask=attention_mask,
  1174. token_type_ids=token_type_ids,
  1175. position_ids=position_ids,
  1176. entity_ids=entity_ids,
  1177. entity_attention_mask=entity_attention_mask,
  1178. entity_token_type_ids=entity_token_type_ids,
  1179. entity_position_ids=entity_position_ids,
  1180. inputs_embeds=inputs_embeds,
  1181. output_attentions=output_attentions,
  1182. output_hidden_states=output_hidden_states,
  1183. return_dict=True,
  1184. )
  1185. feature_vector = torch.cat(
  1186. [outputs.entity_last_hidden_state[:, 0, :], outputs.entity_last_hidden_state[:, 1, :]], dim=1
  1187. )
  1188. feature_vector = self.dropout(feature_vector)
  1189. logits = self.classifier(feature_vector)
  1190. loss = None
  1191. if labels is not None:
  1192. # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
  1193. # cross entropy is used otherwise.
  1194. # move labels to correct device
  1195. labels = labels.to(logits.device)
  1196. if labels.ndim == 1:
  1197. loss = nn.functional.cross_entropy(logits, labels)
  1198. else:
  1199. loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
  1200. if not return_dict:
  1201. return tuple(
  1202. v
  1203. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1204. if v is not None
  1205. )
  1206. return EntityPairClassificationOutput(
  1207. loss=loss,
  1208. logits=logits,
  1209. hidden_states=outputs.hidden_states,
  1210. entity_hidden_states=outputs.entity_hidden_states,
  1211. attentions=outputs.attentions,
  1212. )
  1213. @auto_docstring(
  1214. custom_intro="""
  1215. The LUKE model with a span classification head on top (a linear layer on top of the hidden states output) for tasks
  1216. such as named entity recognition.
  1217. """
  1218. )
  1219. class LukeForEntitySpanClassification(LukePreTrainedModel):
  1220. def __init__(self, config):
  1221. super().__init__(config)
  1222. self.luke = LukeModel(config)
  1223. self.num_labels = config.num_labels
  1224. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1225. self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels)
  1226. # Initialize weights and apply final processing
  1227. self.post_init()
  1228. @auto_docstring
  1229. def forward(
  1230. self,
  1231. input_ids: torch.LongTensor | None = None,
  1232. attention_mask: torch.FloatTensor | None = None,
  1233. token_type_ids: torch.LongTensor | None = None,
  1234. position_ids: torch.LongTensor | None = None,
  1235. entity_ids: torch.LongTensor | None = None,
  1236. entity_attention_mask: torch.LongTensor | None = None,
  1237. entity_token_type_ids: torch.LongTensor | None = None,
  1238. entity_position_ids: torch.LongTensor | None = None,
  1239. entity_start_positions: torch.LongTensor | None = None,
  1240. entity_end_positions: torch.LongTensor | None = None,
  1241. inputs_embeds: torch.FloatTensor | None = None,
  1242. labels: torch.LongTensor | None = None,
  1243. output_attentions: bool | None = None,
  1244. output_hidden_states: bool | None = None,
  1245. return_dict: bool | None = None,
  1246. **kwargs,
  1247. ) -> tuple | EntitySpanClassificationOutput:
  1248. r"""
  1249. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1250. Indices of entity tokens in the entity vocabulary.
  1251. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1252. [`PreTrainedTokenizer.__call__`] for details.
  1253. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1254. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1255. - 1 for entity tokens that are **not masked**,
  1256. - 0 for entity tokens that are **masked**.
  1257. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1258. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1259. selected in `[0, 1]`:
  1260. - 0 corresponds to a *portion A* entity token,
  1261. - 1 corresponds to a *portion B* entity token.
  1262. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1263. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1264. config.max_position_embeddings - 1]`.
  1265. entity_start_positions (`torch.LongTensor`):
  1266. The start positions of entities in the word token sequence.
  1267. entity_end_positions (`torch.LongTensor`):
  1268. The end positions of entities in the word token sequence.
  1269. labels (`torch.LongTensor` of shape `(batch_size, entity_length)` or `(batch_size, entity_length, num_labels)`, *optional*):
  1270. Labels for computing the classification loss. If the shape is `(batch_size, entity_length)`, the cross
  1271. entropy loss is used for the single-label classification. In this case, labels should contain the indices
  1272. that should be in `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, entity_length,
  1273. num_labels)`, the binary cross entropy loss is used for the multi-label classification. In this case,
  1274. labels should only contain `[0, 1]`, where 0 and 1 indicate false and true, respectively.
  1275. Examples:
  1276. ```python
  1277. >>> from transformers import AutoTokenizer, LukeForEntitySpanClassification
  1278. >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
  1279. >>> model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
  1280. >>> text = "Beyoncé lives in Los Angeles"
  1281. # List all possible entity spans in the text
  1282. >>> word_start_positions = [0, 8, 14, 17, 21] # character-based start positions of word tokens
  1283. >>> word_end_positions = [7, 13, 16, 20, 28] # character-based end positions of word tokens
  1284. >>> entity_spans = []
  1285. >>> for i, start_pos in enumerate(word_start_positions):
  1286. ... for end_pos in word_end_positions[i:]:
  1287. ... entity_spans.append((start_pos, end_pos))
  1288. >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
  1289. >>> outputs = model(**inputs)
  1290. >>> logits = outputs.logits
  1291. >>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
  1292. >>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
  1293. ... if predicted_class_idx != 0:
  1294. ... print(text[span[0] : span[1]], model.config.id2label[predicted_class_idx])
  1295. Beyoncé PER
  1296. Los Angeles LOC
  1297. ```"""
  1298. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1299. outputs = self.luke(
  1300. input_ids=input_ids,
  1301. attention_mask=attention_mask,
  1302. token_type_ids=token_type_ids,
  1303. position_ids=position_ids,
  1304. entity_ids=entity_ids,
  1305. entity_attention_mask=entity_attention_mask,
  1306. entity_token_type_ids=entity_token_type_ids,
  1307. entity_position_ids=entity_position_ids,
  1308. inputs_embeds=inputs_embeds,
  1309. output_attentions=output_attentions,
  1310. output_hidden_states=output_hidden_states,
  1311. return_dict=True,
  1312. )
  1313. hidden_size = outputs.last_hidden_state.size(-1)
  1314. entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
  1315. if entity_start_positions.device != outputs.last_hidden_state.device:
  1316. entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device)
  1317. start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
  1318. entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
  1319. if entity_end_positions.device != outputs.last_hidden_state.device:
  1320. entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device)
  1321. end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
  1322. feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
  1323. feature_vector = self.dropout(feature_vector)
  1324. logits = self.classifier(feature_vector)
  1325. loss = None
  1326. if labels is not None:
  1327. # move labels to correct device
  1328. labels = labels.to(logits.device)
  1329. # When the number of dimension of `labels` is 2, cross entropy is used as the loss function. The binary
  1330. # cross entropy is used otherwise.
  1331. if labels.ndim == 2:
  1332. loss = nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
  1333. else:
  1334. loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
  1335. if not return_dict:
  1336. return tuple(
  1337. v
  1338. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1339. if v is not None
  1340. )
  1341. return EntitySpanClassificationOutput(
  1342. loss=loss,
  1343. logits=logits,
  1344. hidden_states=outputs.hidden_states,
  1345. entity_hidden_states=outputs.entity_hidden_states,
  1346. attentions=outputs.attentions,
  1347. )
  1348. @auto_docstring(
  1349. custom_intro="""
  1350. The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  1351. pooled output) e.g. for GLUE tasks.
  1352. """
  1353. )
  1354. class LukeForSequenceClassification(LukePreTrainedModel):
  1355. def __init__(self, config):
  1356. super().__init__(config)
  1357. self.num_labels = config.num_labels
  1358. self.luke = LukeModel(config)
  1359. self.dropout = nn.Dropout(
  1360. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1361. )
  1362. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1363. # Initialize weights and apply final processing
  1364. self.post_init()
  1365. @auto_docstring
  1366. def forward(
  1367. self,
  1368. input_ids: torch.LongTensor | None = None,
  1369. attention_mask: torch.FloatTensor | None = None,
  1370. token_type_ids: torch.LongTensor | None = None,
  1371. position_ids: torch.LongTensor | None = None,
  1372. entity_ids: torch.LongTensor | None = None,
  1373. entity_attention_mask: torch.FloatTensor | None = None,
  1374. entity_token_type_ids: torch.LongTensor | None = None,
  1375. entity_position_ids: torch.LongTensor | None = None,
  1376. inputs_embeds: torch.FloatTensor | None = None,
  1377. labels: torch.FloatTensor | None = None,
  1378. output_attentions: bool | None = None,
  1379. output_hidden_states: bool | None = None,
  1380. return_dict: bool | None = None,
  1381. **kwargs,
  1382. ) -> tuple | LukeSequenceClassifierOutput:
  1383. r"""
  1384. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1385. Indices of entity tokens in the entity vocabulary.
  1386. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1387. [`PreTrainedTokenizer.__call__`] for details.
  1388. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1389. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1390. - 1 for entity tokens that are **not masked**,
  1391. - 0 for entity tokens that are **masked**.
  1392. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1393. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1394. selected in `[0, 1]`:
  1395. - 0 corresponds to a *portion A* entity token,
  1396. - 1 corresponds to a *portion B* entity token.
  1397. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1398. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1399. config.max_position_embeddings - 1]`.
  1400. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1401. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1402. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1403. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1404. """
  1405. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1406. outputs = self.luke(
  1407. input_ids=input_ids,
  1408. attention_mask=attention_mask,
  1409. token_type_ids=token_type_ids,
  1410. position_ids=position_ids,
  1411. entity_ids=entity_ids,
  1412. entity_attention_mask=entity_attention_mask,
  1413. entity_token_type_ids=entity_token_type_ids,
  1414. entity_position_ids=entity_position_ids,
  1415. inputs_embeds=inputs_embeds,
  1416. output_attentions=output_attentions,
  1417. output_hidden_states=output_hidden_states,
  1418. return_dict=True,
  1419. )
  1420. pooled_output = outputs.pooler_output
  1421. pooled_output = self.dropout(pooled_output)
  1422. logits = self.classifier(pooled_output)
  1423. loss = None
  1424. if labels is not None:
  1425. # move labels to correct device
  1426. labels = labels.to(logits.device)
  1427. if self.config.problem_type is None:
  1428. if self.num_labels == 1:
  1429. self.config.problem_type = "regression"
  1430. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1431. self.config.problem_type = "single_label_classification"
  1432. else:
  1433. self.config.problem_type = "multi_label_classification"
  1434. if self.config.problem_type == "regression":
  1435. loss_fct = MSELoss()
  1436. if self.num_labels == 1:
  1437. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1438. else:
  1439. loss = loss_fct(logits, labels)
  1440. elif self.config.problem_type == "single_label_classification":
  1441. loss_fct = CrossEntropyLoss()
  1442. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1443. elif self.config.problem_type == "multi_label_classification":
  1444. loss_fct = BCEWithLogitsLoss()
  1445. loss = loss_fct(logits, labels)
  1446. if not return_dict:
  1447. return tuple(
  1448. v
  1449. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1450. if v is not None
  1451. )
  1452. return LukeSequenceClassifierOutput(
  1453. loss=loss,
  1454. logits=logits,
  1455. hidden_states=outputs.hidden_states,
  1456. entity_hidden_states=outputs.entity_hidden_states,
  1457. attentions=outputs.attentions,
  1458. )
  1459. @auto_docstring(
  1460. custom_intro="""
  1461. The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To
  1462. solve Named-Entity Recognition (NER) task using LUKE, `LukeForEntitySpanClassification` is more suitable than this
  1463. class.
  1464. """
  1465. )
  1466. class LukeForTokenClassification(LukePreTrainedModel):
  1467. def __init__(self, config):
  1468. super().__init__(config)
  1469. self.num_labels = config.num_labels
  1470. self.luke = LukeModel(config, add_pooling_layer=False)
  1471. self.dropout = nn.Dropout(
  1472. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1473. )
  1474. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1475. # Initialize weights and apply final processing
  1476. self.post_init()
  1477. @auto_docstring
  1478. def forward(
  1479. self,
  1480. input_ids: torch.LongTensor | None = None,
  1481. attention_mask: torch.FloatTensor | None = None,
  1482. token_type_ids: torch.LongTensor | None = None,
  1483. position_ids: torch.LongTensor | None = None,
  1484. entity_ids: torch.LongTensor | None = None,
  1485. entity_attention_mask: torch.FloatTensor | None = None,
  1486. entity_token_type_ids: torch.LongTensor | None = None,
  1487. entity_position_ids: torch.LongTensor | None = None,
  1488. inputs_embeds: torch.FloatTensor | None = None,
  1489. labels: torch.FloatTensor | None = None,
  1490. output_attentions: bool | None = None,
  1491. output_hidden_states: bool | None = None,
  1492. return_dict: bool | None = None,
  1493. **kwargs,
  1494. ) -> tuple | LukeTokenClassifierOutput:
  1495. r"""
  1496. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1497. Indices of entity tokens in the entity vocabulary.
  1498. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1499. [`PreTrainedTokenizer.__call__`] for details.
  1500. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1501. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1502. - 1 for entity tokens that are **not masked**,
  1503. - 0 for entity tokens that are **masked**.
  1504. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1505. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1506. selected in `[0, 1]`:
  1507. - 0 corresponds to a *portion A* entity token,
  1508. - 1 corresponds to a *portion B* entity token.
  1509. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1510. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1511. config.max_position_embeddings - 1]`.
  1512. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1513. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1514. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1515. `input_ids` above)
  1516. """
  1517. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1518. outputs = self.luke(
  1519. input_ids=input_ids,
  1520. attention_mask=attention_mask,
  1521. token_type_ids=token_type_ids,
  1522. position_ids=position_ids,
  1523. entity_ids=entity_ids,
  1524. entity_attention_mask=entity_attention_mask,
  1525. entity_token_type_ids=entity_token_type_ids,
  1526. entity_position_ids=entity_position_ids,
  1527. inputs_embeds=inputs_embeds,
  1528. output_attentions=output_attentions,
  1529. output_hidden_states=output_hidden_states,
  1530. return_dict=True,
  1531. )
  1532. sequence_output = outputs.last_hidden_state
  1533. sequence_output = self.dropout(sequence_output)
  1534. logits = self.classifier(sequence_output)
  1535. loss = None
  1536. if labels is not None:
  1537. # move labels to correct device
  1538. labels = labels.to(logits.device)
  1539. loss_fct = CrossEntropyLoss()
  1540. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1541. if not return_dict:
  1542. return tuple(
  1543. v
  1544. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1545. if v is not None
  1546. )
  1547. return LukeTokenClassifierOutput(
  1548. loss=loss,
  1549. logits=logits,
  1550. hidden_states=outputs.hidden_states,
  1551. entity_hidden_states=outputs.entity_hidden_states,
  1552. attentions=outputs.attentions,
  1553. )
  1554. @auto_docstring
  1555. class LukeForQuestionAnswering(LukePreTrainedModel):
  1556. def __init__(self, config):
  1557. super().__init__(config)
  1558. self.num_labels = config.num_labels
  1559. self.luke = LukeModel(config, add_pooling_layer=False)
  1560. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1561. # Initialize weights and apply final processing
  1562. self.post_init()
  1563. @auto_docstring
  1564. def forward(
  1565. self,
  1566. input_ids: torch.LongTensor | None = None,
  1567. attention_mask: torch.FloatTensor | None = None,
  1568. token_type_ids: torch.LongTensor | None = None,
  1569. position_ids: torch.FloatTensor | None = None,
  1570. entity_ids: torch.LongTensor | None = None,
  1571. entity_attention_mask: torch.FloatTensor | None = None,
  1572. entity_token_type_ids: torch.LongTensor | None = None,
  1573. entity_position_ids: torch.LongTensor | None = None,
  1574. inputs_embeds: torch.FloatTensor | None = None,
  1575. start_positions: torch.LongTensor | None = None,
  1576. end_positions: torch.LongTensor | None = None,
  1577. output_attentions: bool | None = None,
  1578. output_hidden_states: bool | None = None,
  1579. return_dict: bool | None = None,
  1580. **kwargs,
  1581. ) -> tuple | LukeQuestionAnsweringModelOutput:
  1582. r"""
  1583. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1584. Indices of entity tokens in the entity vocabulary.
  1585. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1586. [`PreTrainedTokenizer.__call__`] for details.
  1587. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1588. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1589. - 1 for entity tokens that are **not masked**,
  1590. - 0 for entity tokens that are **masked**.
  1591. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1592. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1593. selected in `[0, 1]`:
  1594. - 0 corresponds to a *portion A* entity token,
  1595. - 1 corresponds to a *portion B* entity token.
  1596. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1597. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1598. config.max_position_embeddings - 1]`.
  1599. """
  1600. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1601. outputs = self.luke(
  1602. input_ids=input_ids,
  1603. attention_mask=attention_mask,
  1604. token_type_ids=token_type_ids,
  1605. position_ids=position_ids,
  1606. entity_ids=entity_ids,
  1607. entity_attention_mask=entity_attention_mask,
  1608. entity_token_type_ids=entity_token_type_ids,
  1609. entity_position_ids=entity_position_ids,
  1610. inputs_embeds=inputs_embeds,
  1611. output_attentions=output_attentions,
  1612. output_hidden_states=output_hidden_states,
  1613. return_dict=True,
  1614. )
  1615. sequence_output = outputs.last_hidden_state
  1616. logits = self.qa_outputs(sequence_output)
  1617. start_logits, end_logits = logits.split(1, dim=-1)
  1618. start_logits = start_logits.squeeze(-1)
  1619. end_logits = end_logits.squeeze(-1)
  1620. total_loss = None
  1621. if start_positions is not None and end_positions is not None:
  1622. # If we are on multi-GPU, split add a dimension
  1623. if len(start_positions.size()) > 1:
  1624. start_positions = start_positions.squeeze(-1)
  1625. if len(end_positions.size()) > 1:
  1626. end_positions = end_positions.squeeze(-1)
  1627. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1628. ignored_index = start_logits.size(1)
  1629. start_positions.clamp_(0, ignored_index)
  1630. end_positions.clamp_(0, ignored_index)
  1631. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1632. start_loss = loss_fct(start_logits, start_positions)
  1633. end_loss = loss_fct(end_logits, end_positions)
  1634. total_loss = (start_loss + end_loss) / 2
  1635. if not return_dict:
  1636. return tuple(
  1637. v
  1638. for v in [
  1639. total_loss,
  1640. start_logits,
  1641. end_logits,
  1642. outputs.hidden_states,
  1643. outputs.entity_hidden_states,
  1644. outputs.attentions,
  1645. ]
  1646. if v is not None
  1647. )
  1648. return LukeQuestionAnsweringModelOutput(
  1649. loss=total_loss,
  1650. start_logits=start_logits,
  1651. end_logits=end_logits,
  1652. hidden_states=outputs.hidden_states,
  1653. entity_hidden_states=outputs.entity_hidden_states,
  1654. attentions=outputs.attentions,
  1655. )
  1656. @auto_docstring
  1657. class LukeForMultipleChoice(LukePreTrainedModel):
  1658. def __init__(self, config):
  1659. super().__init__(config)
  1660. self.luke = LukeModel(config)
  1661. self.dropout = nn.Dropout(
  1662. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1663. )
  1664. self.classifier = nn.Linear(config.hidden_size, 1)
  1665. # Initialize weights and apply final processing
  1666. self.post_init()
  1667. @auto_docstring
  1668. def forward(
  1669. self,
  1670. input_ids: torch.LongTensor | None = None,
  1671. attention_mask: torch.FloatTensor | None = None,
  1672. token_type_ids: torch.LongTensor | None = None,
  1673. position_ids: torch.LongTensor | None = None,
  1674. entity_ids: torch.LongTensor | None = None,
  1675. entity_attention_mask: torch.FloatTensor | None = None,
  1676. entity_token_type_ids: torch.LongTensor | None = None,
  1677. entity_position_ids: torch.LongTensor | None = None,
  1678. inputs_embeds: torch.FloatTensor | None = None,
  1679. labels: torch.FloatTensor | None = None,
  1680. output_attentions: bool | None = None,
  1681. output_hidden_states: bool | None = None,
  1682. return_dict: bool | None = None,
  1683. **kwargs,
  1684. ) -> tuple | LukeMultipleChoiceModelOutput:
  1685. r"""
  1686. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1687. Indices of input sequence tokens in the vocabulary.
  1688. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1689. [`PreTrainedTokenizer.__call__`] for details.
  1690. [What are input IDs?](../glossary#input-ids)
  1691. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1692. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1693. 1]`:
  1694. - 0 corresponds to a *sentence A* token,
  1695. - 1 corresponds to a *sentence B* token.
  1696. [What are token type IDs?](../glossary#token-type-ids)
  1697. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1698. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1699. config.max_position_embeddings - 1]`.
  1700. [What are position IDs?](../glossary#position-ids)
  1701. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1702. Indices of entity tokens in the entity vocabulary.
  1703. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1704. [`PreTrainedTokenizer.__call__`] for details.
  1705. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1706. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1707. - 1 for entity tokens that are **not masked**,
  1708. - 0 for entity tokens that are **masked**.
  1709. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1710. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1711. selected in `[0, 1]`:
  1712. - 0 corresponds to a *portion A* entity token,
  1713. - 1 corresponds to a *portion B* entity token.
  1714. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1715. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1716. config.max_position_embeddings - 1]`.
  1717. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1718. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1719. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1720. model's internal embedding lookup matrix.
  1721. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1722. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1723. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1724. `input_ids` above)
  1725. """
  1726. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1727. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1728. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1729. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1730. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1731. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1732. inputs_embeds = (
  1733. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1734. if inputs_embeds is not None
  1735. else None
  1736. )
  1737. entity_ids = entity_ids.view(-1, entity_ids.size(-1)) if entity_ids is not None else None
  1738. entity_attention_mask = (
  1739. entity_attention_mask.view(-1, entity_attention_mask.size(-1))
  1740. if entity_attention_mask is not None
  1741. else None
  1742. )
  1743. entity_token_type_ids = (
  1744. entity_token_type_ids.view(-1, entity_token_type_ids.size(-1))
  1745. if entity_token_type_ids is not None
  1746. else None
  1747. )
  1748. entity_position_ids = (
  1749. entity_position_ids.view(-1, entity_position_ids.size(-2), entity_position_ids.size(-1))
  1750. if entity_position_ids is not None
  1751. else None
  1752. )
  1753. outputs = self.luke(
  1754. input_ids=input_ids,
  1755. attention_mask=attention_mask,
  1756. token_type_ids=token_type_ids,
  1757. position_ids=position_ids,
  1758. entity_ids=entity_ids,
  1759. entity_attention_mask=entity_attention_mask,
  1760. entity_token_type_ids=entity_token_type_ids,
  1761. entity_position_ids=entity_position_ids,
  1762. inputs_embeds=inputs_embeds,
  1763. output_attentions=output_attentions,
  1764. output_hidden_states=output_hidden_states,
  1765. return_dict=True,
  1766. )
  1767. pooled_output = outputs.pooler_output
  1768. pooled_output = self.dropout(pooled_output)
  1769. logits = self.classifier(pooled_output)
  1770. reshaped_logits = logits.view(-1, num_choices)
  1771. loss = None
  1772. if labels is not None:
  1773. # move labels to correct device
  1774. labels = labels.to(reshaped_logits.device)
  1775. loss_fct = CrossEntropyLoss()
  1776. loss = loss_fct(reshaped_logits, labels)
  1777. if not return_dict:
  1778. return tuple(
  1779. v
  1780. for v in [
  1781. loss,
  1782. reshaped_logits,
  1783. outputs.hidden_states,
  1784. outputs.entity_hidden_states,
  1785. outputs.attentions,
  1786. ]
  1787. if v is not None
  1788. )
  1789. return LukeMultipleChoiceModelOutput(
  1790. loss=loss,
  1791. logits=reshaped_logits,
  1792. hidden_states=outputs.hidden_states,
  1793. entity_hidden_states=outputs.entity_hidden_states,
  1794. attentions=outputs.attentions,
  1795. )
  1796. __all__ = [
  1797. "LukeForEntityClassification",
  1798. "LukeForEntityPairClassification",
  1799. "LukeForEntitySpanClassification",
  1800. "LukeForMultipleChoice",
  1801. "LukeForQuestionAnswering",
  1802. "LukeForSequenceClassification",
  1803. "LukeForTokenClassification",
  1804. "LukeForMaskedLM",
  1805. "LukeModel",
  1806. "LukePreTrainedModel",
  1807. ]