modeling_dpr.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
# Copyright 2018 DPR Authors, The Hugging Face Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """PyTorch DPR model for Open Domain Question Answering."""
  15. from dataclasses import dataclass
  16. import torch
  17. from torch import Tensor, nn
  18. from ...modeling_outputs import BaseModelOutputWithPooling
  19. from ...modeling_utils import PreTrainedModel
  20. from ...utils import (
  21. ModelOutput,
  22. auto_docstring,
  23. logging,
  24. )
  25. from ..bert.modeling_bert import BertModel
  26. from .configuration_dpr import DPRConfig
  27. logger = logging.get_logger(__name__)
  28. ##########
  29. # Outputs
  30. ##########
  31. @dataclass
  32. @auto_docstring(
  33. custom_intro="""
  34. Class for outputs of [`DPRQuestionEncoder`].
  35. """
  36. )
  37. class DPRContextEncoderOutput(ModelOutput):
  38. r"""
  39. pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
  40. The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer
  41. hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
  42. This output is to be used to embed contexts for nearest neighbors queries with questions embeddings.
  43. """
  44. pooler_output: torch.FloatTensor
  45. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  46. attentions: tuple[torch.FloatTensor, ...] | None = None
  47. @dataclass
  48. @auto_docstring(
  49. custom_intro="""
  50. Class for outputs of [`DPRQuestionEncoder`].
  51. """
  52. )
  53. class DPRQuestionEncoderOutput(ModelOutput):
  54. r"""
  55. pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
  56. The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer
  57. hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
  58. This output is to be used to embed questions for nearest neighbors queries with context embeddings.
  59. """
  60. pooler_output: torch.FloatTensor
  61. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  62. attentions: tuple[torch.FloatTensor, ...] | None = None
  63. @dataclass
  64. @auto_docstring(
  65. custom_intro="""
  66. Class for outputs of [`DPRQuestionEncoder`].
  67. """
  68. )
  69. class DPRReaderOutput(ModelOutput):
  70. r"""
  71. start_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
  72. Logits of the start index of the span for each passage.
  73. end_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
  74. Logits of the end index of the span for each passage.
  75. relevance_logits (`torch.FloatTensor` of shape `(n_passages, )`):
  76. Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the
  77. question, compared to all the other passages.
  78. """
  79. start_logits: torch.FloatTensor
  80. end_logits: torch.FloatTensor | None = None
  81. relevance_logits: torch.FloatTensor | None = None
  82. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  83. attentions: tuple[torch.FloatTensor, ...] | None = None
  84. @auto_docstring
  85. class DPRPreTrainedModel(PreTrainedModel):
  86. _supports_sdpa = True
  87. class DPREncoder(DPRPreTrainedModel):
  88. base_model_prefix = "bert_model"
  89. def __init__(self, config: DPRConfig):
  90. super().__init__(config)
  91. self.bert_model = BertModel(config, add_pooling_layer=False)
  92. if self.bert_model.config.hidden_size <= 0:
  93. raise ValueError("Encoder hidden_size can't be zero")
  94. self.projection_dim = config.projection_dim
  95. if self.projection_dim > 0:
  96. self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
  97. # Initialize weights and apply final processing
  98. self.post_init()
  99. def forward(
  100. self,
  101. input_ids: Tensor,
  102. attention_mask: Tensor | None = None,
  103. token_type_ids: Tensor | None = None,
  104. inputs_embeds: Tensor | None = None,
  105. output_attentions: bool = False,
  106. output_hidden_states: bool = False,
  107. return_dict: bool = False,
  108. **kwargs,
  109. ) -> BaseModelOutputWithPooling | tuple[Tensor, ...]:
  110. outputs = self.bert_model(
  111. input_ids=input_ids,
  112. attention_mask=attention_mask,
  113. token_type_ids=token_type_ids,
  114. inputs_embeds=inputs_embeds,
  115. output_attentions=output_attentions,
  116. output_hidden_states=output_hidden_states,
  117. return_dict=return_dict,
  118. )
  119. sequence_output = outputs[0]
  120. pooled_output = sequence_output[:, 0, :]
  121. if self.projection_dim > 0:
  122. pooled_output = self.encode_proj(pooled_output)
  123. if not return_dict:
  124. return (sequence_output, pooled_output) + outputs[2:]
  125. return BaseModelOutputWithPooling(
  126. last_hidden_state=sequence_output,
  127. pooler_output=pooled_output,
  128. hidden_states=outputs.hidden_states,
  129. attentions=outputs.attentions,
  130. )
  131. @property
  132. def embeddings_size(self) -> int:
  133. if self.projection_dim > 0:
  134. return self.encode_proj.out_features
  135. return self.bert_model.config.hidden_size
  136. class DPRSpanPredictor(DPRPreTrainedModel):
  137. base_model_prefix = "encoder"
  138. def __init__(self, config: DPRConfig):
  139. super().__init__(config)
  140. self.encoder = DPREncoder(config)
  141. self.qa_outputs = nn.Linear(self.encoder.embeddings_size, 2)
  142. self.qa_classifier = nn.Linear(self.encoder.embeddings_size, 1)
  143. # Initialize weights and apply final processing
  144. self.post_init()
  145. def forward(
  146. self,
  147. input_ids: Tensor,
  148. attention_mask: Tensor,
  149. inputs_embeds: Tensor | None = None,
  150. output_attentions: bool = False,
  151. output_hidden_states: bool = False,
  152. return_dict: bool = False,
  153. **kwargs,
  154. ) -> DPRReaderOutput | tuple[Tensor, ...]:
  155. # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
  156. n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
  157. # feed encoder
  158. outputs = self.encoder(
  159. input_ids,
  160. attention_mask=attention_mask,
  161. inputs_embeds=inputs_embeds,
  162. output_attentions=output_attentions,
  163. output_hidden_states=output_hidden_states,
  164. return_dict=return_dict,
  165. )
  166. sequence_output = outputs[0]
  167. # compute logits
  168. logits = self.qa_outputs(sequence_output)
  169. start_logits, end_logits = logits.split(1, dim=-1)
  170. start_logits = start_logits.squeeze(-1).contiguous()
  171. end_logits = end_logits.squeeze(-1).contiguous()
  172. relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
  173. # resize
  174. start_logits = start_logits.view(n_passages, sequence_length)
  175. end_logits = end_logits.view(n_passages, sequence_length)
  176. relevance_logits = relevance_logits.view(n_passages)
  177. if not return_dict:
  178. return (start_logits, end_logits, relevance_logits) + outputs[2:]
  179. return DPRReaderOutput(
  180. start_logits=start_logits,
  181. end_logits=end_logits,
  182. relevance_logits=relevance_logits,
  183. hidden_states=outputs.hidden_states,
  184. attentions=outputs.attentions,
  185. )
  186. ##################
  187. # PreTrainedModel
  188. ##################
  189. class DPRPretrainedContextEncoder(DPRPreTrainedModel):
  190. """
  191. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  192. models.
  193. """
  194. config: DPRConfig
  195. base_model_prefix = "ctx_encoder"
  196. class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
  197. """
  198. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  199. models.
  200. """
  201. config: DPRConfig
  202. base_model_prefix = "question_encoder"
  203. class DPRPretrainedReader(DPRPreTrainedModel):
  204. """
  205. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  206. models.
  207. """
  208. config: DPRConfig
  209. base_model_prefix = "span_predictor"
  210. ###############
  211. # Actual Models
  212. ###############
  213. @auto_docstring(
  214. custom_intro="""
  215. The bare DPRContextEncoder transformer outputting pooler outputs as context representations.
  216. """
  217. )
  218. class DPRContextEncoder(DPRPretrainedContextEncoder):
  219. def __init__(self, config: DPRConfig):
  220. super().__init__(config)
  221. self.config = config
  222. self.ctx_encoder = DPREncoder(config)
  223. # Initialize weights and apply final processing
  224. self.post_init()
  225. @auto_docstring
  226. def forward(
  227. self,
  228. input_ids: Tensor | None = None,
  229. attention_mask: Tensor | None = None,
  230. token_type_ids: Tensor | None = None,
  231. inputs_embeds: Tensor | None = None,
  232. output_attentions: bool | None = None,
  233. output_hidden_states: bool | None = None,
  234. return_dict: bool | None = None,
  235. **kwargs,
  236. ) -> DPRContextEncoderOutput | tuple[Tensor, ...]:
  237. r"""
  238. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  239. Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
  240. formatted with [CLS] and [SEP] tokens as follows:
  241. (a) For sequence pairs (for a pair title+text for example):
  242. ```
  243. tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
  244. token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
  245. ```
  246. (b) For single sequences (for a question for example):
  247. ```
  248. tokens: [CLS] the dog is hairy . [SEP]
  249. token_type_ids: 0 0 0 0 0 0 0
  250. ```
  251. DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
  252. rather than the left.
  253. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  254. [`PreTrainedTokenizer.__call__`] for details.
  255. [What are input IDs?](../glossary#input-ids)
  256. Examples:
  257. ```python
  258. >>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
  259. >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
  260. >>> model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
  261. >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
  262. >>> embeddings = model(input_ids).pooler_output
  263. ```"""
  264. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  265. output_hidden_states = (
  266. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  267. )
  268. return_dict = return_dict if return_dict is not None else self.config.return_dict
  269. if input_ids is not None and inputs_embeds is not None:
  270. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  271. elif input_ids is not None:
  272. input_shape = input_ids.size()
  273. elif inputs_embeds is not None:
  274. input_shape = inputs_embeds.size()[:-1]
  275. else:
  276. raise ValueError("You have to specify either input_ids or inputs_embeds")
  277. device = input_ids.device if input_ids is not None else inputs_embeds.device
  278. if attention_mask is None:
  279. attention_mask = (
  280. torch.ones(input_shape, device=device)
  281. if input_ids is None
  282. else (input_ids != self.config.pad_token_id)
  283. )
  284. if token_type_ids is None:
  285. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  286. outputs = self.ctx_encoder(
  287. input_ids=input_ids,
  288. attention_mask=attention_mask,
  289. token_type_ids=token_type_ids,
  290. inputs_embeds=inputs_embeds,
  291. output_attentions=output_attentions,
  292. output_hidden_states=output_hidden_states,
  293. return_dict=return_dict,
  294. )
  295. if not return_dict:
  296. return outputs[1:]
  297. return DPRContextEncoderOutput(
  298. pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  299. )
  300. @auto_docstring(
  301. custom_intro="""
  302. The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.
  303. """
  304. )
  305. class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
  306. def __init__(self, config: DPRConfig):
  307. super().__init__(config)
  308. self.config = config
  309. self.question_encoder = DPREncoder(config)
  310. # Initialize weights and apply final processing
  311. self.post_init()
  312. @auto_docstring
  313. def forward(
  314. self,
  315. input_ids: Tensor | None = None,
  316. attention_mask: Tensor | None = None,
  317. token_type_ids: Tensor | None = None,
  318. inputs_embeds: Tensor | None = None,
  319. output_attentions: bool | None = None,
  320. output_hidden_states: bool | None = None,
  321. return_dict: bool | None = None,
  322. **kwargs,
  323. ) -> DPRQuestionEncoderOutput | tuple[Tensor, ...]:
  324. r"""
  325. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  326. Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
  327. formatted with [CLS] and [SEP] tokens as follows:
  328. (a) For sequence pairs (for a pair title+text for example):
  329. ```
  330. tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
  331. token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
  332. ```
  333. (b) For single sequences (for a question for example):
  334. ```
  335. tokens: [CLS] the dog is hairy . [SEP]
  336. token_type_ids: 0 0 0 0 0 0 0
  337. ```
  338. DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
  339. rather than the left.
  340. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  341. [`PreTrainedTokenizer.__call__`] for details.
  342. [What are input IDs?](../glossary#input-ids)
  343. Examples:
  344. ```python
  345. >>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
  346. >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
  347. >>> model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
  348. >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
  349. >>> embeddings = model(input_ids).pooler_output
  350. ```
  351. """
  352. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  353. output_hidden_states = (
  354. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  355. )
  356. return_dict = return_dict if return_dict is not None else self.config.return_dict
  357. if input_ids is not None and inputs_embeds is not None:
  358. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  359. elif input_ids is not None:
  360. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  361. input_shape = input_ids.size()
  362. elif inputs_embeds is not None:
  363. input_shape = inputs_embeds.size()[:-1]
  364. else:
  365. raise ValueError("You have to specify either input_ids or inputs_embeds")
  366. device = input_ids.device if input_ids is not None else inputs_embeds.device
  367. if attention_mask is None:
  368. attention_mask = (
  369. torch.ones(input_shape, device=device)
  370. if input_ids is None
  371. else (input_ids != self.config.pad_token_id)
  372. )
  373. if token_type_ids is None:
  374. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  375. outputs = self.question_encoder(
  376. input_ids=input_ids,
  377. attention_mask=attention_mask,
  378. token_type_ids=token_type_ids,
  379. inputs_embeds=inputs_embeds,
  380. output_attentions=output_attentions,
  381. output_hidden_states=output_hidden_states,
  382. return_dict=return_dict,
  383. )
  384. if not return_dict:
  385. return outputs[1:]
  386. return DPRQuestionEncoderOutput(
  387. pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  388. )
  389. @auto_docstring(
  390. custom_intro="""
  391. The bare DPRReader transformer outputting span predictions.
  392. """
  393. )
  394. class DPRReader(DPRPretrainedReader):
  395. def __init__(self, config: DPRConfig):
  396. super().__init__(config)
  397. self.config = config
  398. self.span_predictor = DPRSpanPredictor(config)
  399. # Initialize weights and apply final processing
  400. self.post_init()
  401. @auto_docstring
  402. def forward(
  403. self,
  404. input_ids: Tensor | None = None,
  405. attention_mask: Tensor | None = None,
  406. inputs_embeds: Tensor | None = None,
  407. output_attentions: bool | None = None,
  408. output_hidden_states: bool | None = None,
  409. return_dict: bool | None = None,
  410. **kwargs,
  411. ) -> DPRReaderOutput | tuple[Tensor, ...]:
  412. r"""
  413. input_ids (`tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):
  414. Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question
  415. and 2) the passages titles and 3) the passages texts To match pretraining, DPR `input_ids` sequence should
  416. be formatted with [CLS] and [SEP] with the format:
  417. `[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>`
  418. DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
  419. rather than the left.
  420. Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details.
  421. [What are input IDs?](../glossary#input-ids)
  422. inputs_embeds (`torch.FloatTensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*):
  423. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  424. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  425. model's internal embedding lookup matrix.
  426. Examples:
  427. ```python
  428. >>> from transformers import DPRReader, DPRReaderTokenizer
  429. >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
  430. >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
  431. >>> encoded_inputs = tokenizer(
  432. ... questions=["What is love ?"],
  433. ... titles=["Haddaway"],
  434. ... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
  435. ... return_tensors="pt",
  436. ... )
  437. >>> outputs = model(**encoded_inputs)
  438. >>> start_logits = outputs.start_logits
  439. >>> end_logits = outputs.end_logits
  440. >>> relevance_logits = outputs.relevance_logits
  441. ```
  442. """
  443. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  444. output_hidden_states = (
  445. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  446. )
  447. return_dict = return_dict if return_dict is not None else self.config.return_dict
  448. if input_ids is not None and inputs_embeds is not None:
  449. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  450. elif input_ids is not None:
  451. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  452. input_shape = input_ids.size()
  453. elif inputs_embeds is not None:
  454. input_shape = inputs_embeds.size()[:-1]
  455. else:
  456. raise ValueError("You have to specify either input_ids or inputs_embeds")
  457. device = input_ids.device if input_ids is not None else inputs_embeds.device
  458. if attention_mask is None:
  459. attention_mask = torch.ones(input_shape, device=device)
  460. return self.span_predictor(
  461. input_ids,
  462. attention_mask,
  463. inputs_embeds=inputs_embeds,
  464. output_attentions=output_attentions,
  465. output_hidden_states=output_hidden_states,
  466. return_dict=return_dict,
  467. )
  468. __all__ = [
  469. "DPRContextEncoder",
  470. "DPRPretrainedContextEncoder",
  471. "DPRPreTrainedModel",
  472. "DPRPretrainedQuestionEncoder",
  473. "DPRPretrainedReader",
  474. "DPRQuestionEncoder",
  475. "DPRReader",
  476. ]