modeling_rag.py 86 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665
  1. # Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """RAG model implementation."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ...cache_utils import Cache, EncoderDecoderCache
  20. from ...configuration_utils import PreTrainedConfig
  21. from ...generation import GenerationConfig, GenerationMixin, GenerationMode, LogitsProcessorList, StoppingCriteriaList
  22. from ...generation.utils import GENERATION_MODES_MAPPING
  23. from ...modeling_outputs import ModelOutput
  24. from ...modeling_utils import PreTrainedModel
  25. from ...utils import auto_docstring, logging
  26. from .configuration_rag import RagConfig
  27. from .retrieval_rag import RagRetriever
# Module-level logger obtained from the library's `logging` utilities, named after this module.
logger = logging.get_logger(__name__)
  29. @dataclass
  30. @auto_docstring(
  31. custom_intro="""
  32. Base class for retriever augmented marginalized models outputs.
  33. """
  34. )
  35. class RetrievAugLMMarginOutput(ModelOutput):
  36. r"""
  37. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  38. Language modeling loss.
  39. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  40. Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
  41. each vocabulary token.
  42. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  43. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  44. `question_encoder_last_hidden_state`.
  45. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  46. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  47. Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
  48. (see `past_key_values` input) to speed up sequential decoding.
  49. retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
  50. Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
  51. the `doc_scores`.
  52. retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
  53. The indexes of the embedded documents retrieved by the retriever.
  54. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  55. Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
  56. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  57. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  58. retriever.
  59. question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  60. Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
  61. model.
  62. question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  63. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  64. shape `(batch_size, sequence_length, hidden_size)`.
  65. Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
  66. question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  67. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  68. sequence_length)`.
  69. Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
  70. average in the self-attention heads.
  71. generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  72. Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
  73. generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  74. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  75. shape `(batch_size, sequence_length, hidden_size)`.
  76. Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
  77. generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  78. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  79. sequence_length)`.
  80. Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
  81. average in the self-attention heads.
  82. generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  83. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  84. shape `(batch_size, sequence_length, hidden_size)`.
  85. Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
  86. generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  87. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  88. sequence_length)`.
  89. Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
  90. average in the self-attention heads.
  91. generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  92. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  93. sequence_length)`.
  94. Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
  95. weighted average in the cross-attention heads.
  96. """
  97. loss: torch.FloatTensor | None = None
  98. logits: torch.FloatTensor | None = None
  99. doc_scores: torch.FloatTensor | None = None
  100. past_key_values: Cache | None = None
  101. retrieved_doc_embeds: torch.FloatTensor | None = None
  102. retrieved_doc_ids: torch.LongTensor | None = None
  103. context_input_ids: torch.LongTensor | None = None
  104. context_attention_mask: torch.LongTensor | None = None
  105. question_encoder_last_hidden_state: torch.FloatTensor | None = None
  106. question_enc_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  107. question_enc_attentions: tuple[torch.FloatTensor, ...] | None = None
  108. generator_enc_last_hidden_state: torch.FloatTensor | None = None
  109. generator_enc_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  110. generator_enc_attentions: tuple[torch.FloatTensor, ...] | None = None
  111. generator_dec_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  112. generator_dec_attentions: tuple[torch.FloatTensor, ...] | None = None
  113. generator_cross_attentions: tuple[torch.FloatTensor, ...] | None = None
  114. @dataclass
  115. @auto_docstring
  116. class RetrievAugLMOutput(ModelOutput):
  117. r"""
  118. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  119. Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
  120. each vocabulary token.
  121. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  122. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  123. `question_encoder_last_hidden_state`.
  124. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  125. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  126. Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
  127. (see `past_key_values` input) to speed up sequential decoding.
  128. retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
  129. Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
  130. the `doc_scores`.
  131. retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
  132. The indexes of the embedded documents retrieved by the retriever.
  133. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  134. Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
  135. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  136. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  137. retriever.
  138. question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  139. Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
  140. model.
  141. question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  142. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  143. shape `(batch_size, sequence_length, hidden_size)`.
  144. Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
  145. question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  146. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  147. sequence_length)`.
  148. Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
  149. average in the self-attention heads.
  150. generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  151. Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
  152. generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  153. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  154. shape `(batch_size, sequence_length, hidden_size)`.
  155. Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
  156. generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  157. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  158. sequence_length)`.
  159. Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
  160. average in the self-attention heads.
  161. generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  162. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  163. shape `(batch_size, sequence_length, hidden_size)`.
  164. Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
  165. generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  166. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  167. sequence_length)`.
  168. Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
  169. average in the self-attention heads.
  170. generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  171. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  172. sequence_length)`.
  173. Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
  174. weighted average in the cross-attention heads.
  175. """
  176. logits: torch.FloatTensor | None = None
  177. doc_scores: torch.FloatTensor | None = None
  178. past_key_values: Cache | None = None
  179. retrieved_doc_embeds: torch.FloatTensor | None = None
  180. retrieved_doc_ids: torch.LongTensor | None = None
  181. context_input_ids: torch.LongTensor | None = None
  182. context_attention_mask: torch.LongTensor | None = None
  183. question_encoder_last_hidden_state: torch.FloatTensor | None = None
  184. question_enc_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  185. question_enc_attentions: tuple[torch.FloatTensor, ...] | None = None
  186. generator_enc_last_hidden_state: torch.FloatTensor | None = None
  187. generator_enc_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  188. generator_enc_attentions: tuple[torch.FloatTensor, ...] | None = None
  189. generator_dec_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  190. generator_dec_attentions: tuple[torch.FloatTensor, ...] | None = None
  191. generator_cross_attentions: tuple[torch.FloatTensor, ...] | None = None
  192. @auto_docstring(
  193. custom_intro="""
  194. RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
  195. Tasks](https://huggingface.co/papers/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.
  196. RAG is a retriever augmented model and encapsulate three components: a question encoder, a dataset retriever and a
  197. generator, the encoder and generator are trainable while the retriever is just an indexed dataset.
  198. """
  199. )
  200. @auto_docstring
  201. class RagPreTrainedModel(PreTrainedModel):
  202. config: RagConfig
  203. base_model_prefix = "rag"
  204. _supports_flash_attn = True
  205. _supports_sdpa = True
  206. @classmethod
  207. def from_pretrained_question_encoder_generator(
  208. cls,
  209. question_encoder_pretrained_model_name_or_path: str | None = None,
  210. generator_pretrained_model_name_or_path: str | None = None,
  211. retriever: RagRetriever | None = None,
  212. **kwargs,
  213. ) -> PreTrainedModel:
  214. r"""
  215. Instantiates an question encoder and a generator from one or two base classes of the library from pretrained
  216. model checkpoints.
  217. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
  218. the model, you need to first set it back in training mode with `model.train()`.
  219. Params:
  220. question_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
  221. Information necessary to initiate the question encoder. Can be either:
  222. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  223. - A path to a *directory* containing model weights saved using
  224. [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  225. generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
  226. Information necessary to initiate the generator. Can be either:
  227. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  228. - A path to a *directory* containing model weights saved using
  229. [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  230. model_args (remaining positional arguments, *optional*):
  231. All remaining positional arguments will be passed to the underlying model's `__init__` method.
  232. retriever ([`RagRetriever`], *optional*):
  233. The retriever to use.
  234. kwwargs (remaining dictionary of keyword arguments, *optional*):
  235. Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
  236. `output_attentions=True`).
  237. - To update the question_encoder configuration, use the prefix *question_encoder_* for each
  238. configuration parameter.
  239. - To update the generator configuration, use the prefix *generator_* for each configuration parameter.
  240. - To update the parent model configuration, do not use a prefix for each configuration parameter.
  241. Behaves differently depending on whether a `config` is provided or automatically loaded.
  242. Example:
  243. ```python
  244. >>> from transformers import RagModel
  245. >>> # initialize a RAG from two pretrained models.
  246. >>> model = RagModel.from_pretrained_question_encoder_generator(
  247. ... "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small"
  248. ... )
  249. >>> # saving model after fine-tuning
  250. >>> model.save_pretrained("./rag")
  251. >>> # load fine-tuned model
  252. >>> model = RagModel.from_pretrained("./rag")
  253. ```"""
  254. kwargs_question_encoder = {
  255. argument[len("question_encoder_") :]: value
  256. for argument, value in kwargs.items()
  257. if argument.startswith("question_encoder_")
  258. }
  259. kwargs_generator = {
  260. argument[len("generator_") :]: value
  261. for argument, value in kwargs.items()
  262. if argument.startswith("generator_")
  263. }
  264. # remove question_encoder, generator kwargs from kwargs
  265. for key in kwargs_question_encoder:
  266. del kwargs["question_encoder_" + key]
  267. for key in kwargs_generator:
  268. del kwargs["generator_" + key]
  269. # Load and initialize the question_encoder and generator
  270. # The distinction between question_encoder and generator at the model level is made
  271. # by the value of the flag `is_generator` that we need to set correctly.
  272. question_encoder = kwargs_question_encoder.pop("model", None)
  273. if question_encoder is None:
  274. assert question_encoder_pretrained_model_name_or_path is not None, (
  275. "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
  276. " be defined"
  277. )
  278. from ..auto.modeling_auto import AutoModel
  279. if "config" not in kwargs_question_encoder:
  280. from ..auto.configuration_auto import AutoConfig
  281. question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
  282. question_encoder_pretrained_model_name_or_path,
  283. **kwargs_question_encoder,
  284. return_unused_kwargs=True,
  285. )
  286. kwargs_question_encoder["config"] = question_encoder_config
  287. question_encoder = AutoModel.from_pretrained(
  288. question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
  289. )
  290. generator = kwargs_generator.pop("model", None)
  291. if generator is None:
  292. assert generator_pretrained_model_name_or_path is not None, (
  293. "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
  294. " to be defined"
  295. )
  296. from ..auto.modeling_auto import AutoModelForSeq2SeqLM
  297. if "config" not in kwargs_generator:
  298. from ..auto.configuration_auto import AutoConfig
  299. generator_config, kwargs_generator = AutoConfig.from_pretrained(
  300. generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
  301. )
  302. kwargs_generator["config"] = generator_config
  303. generator = AutoModelForSeq2SeqLM.from_pretrained(
  304. generator_pretrained_model_name_or_path, **kwargs_generator
  305. )
  306. # instantiate config with corresponding kwargs
  307. config = kwargs.get("config")
  308. if config is None:
  309. config = RagConfig.from_question_encoder_generator_configs(
  310. question_encoder.config, generator.config, **kwargs
  311. )
  312. return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever)
  313. @auto_docstring
  314. class RagModel(RagPreTrainedModel):
  315. def __init__(
  316. self,
  317. config: PreTrainedConfig | None = None,
  318. question_encoder: PreTrainedModel | None = None,
  319. generator: PreTrainedModel | None = None,
  320. retriever: RagRetriever | None = None, # or maybe just use a `set_retriever(...)` method
  321. **kwargs,
  322. ):
  323. r"""
  324. question_encoder (`PreTrainedModel`, *optional*):
  325. The model responsible for encoding the question into hidden states for retrieval.
  326. generator (`PreTrainedModel`, *optional*):
  327. The model responsible for generating text based on retrieved documents.
  328. retriever (`RagRetriever`, *optional*):
  329. The component responsible for retrieving documents from a knowledge base given the encoded question.
  330. """
  331. assert config is not None or (question_encoder is not None and generator is not None), (
  332. "Either a configuration or an question_encoder and a generator has to be provided."
  333. )
  334. if config is None:
  335. config = RagConfig.from_question_encoder_generator_configs(
  336. question_encoder.config, generator.config, **kwargs
  337. )
  338. else:
  339. assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
  340. super().__init__(config)
  341. if question_encoder is None:
  342. from ..auto.modeling_auto import AutoModel
  343. question_encoder = AutoModel.from_config(config.question_encoder)
  344. if generator is None:
  345. from ..auto.modeling_auto import AutoModelForSeq2SeqLM
  346. generator = AutoModelForSeq2SeqLM.from_config(config.generator)
  347. self.retriever = retriever
  348. if self.retriever is not None:
  349. assert isinstance(retriever, RagRetriever), (
  350. f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
  351. )
  352. self.retriever = retriever
  353. self.question_encoder = question_encoder
  354. self.generator = generator
  355. self.ctx_encoder = None
  356. self.context_encoder_training = False
  357. self.post_init()
  358. @auto_docstring
  359. def forward(
  360. self,
  361. input_ids: torch.LongTensor | None = None,
  362. attention_mask: torch.Tensor | None = None,
  363. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  364. decoder_input_ids: torch.LongTensor | None = None,
  365. decoder_attention_mask: torch.BoolTensor | None = None,
  366. past_key_values: Cache | None = None,
  367. doc_scores: torch.FloatTensor | None = None,
  368. context_input_ids: torch.LongTensor | None = None,
  369. context_attention_mask: torch.LongTensor | None = None,
  370. use_cache: bool | None = None,
  371. output_attentions: bool | None = None,
  372. output_hidden_states: bool | None = None,
  373. output_retrieved: bool | None = None,
  374. n_docs: int | None = None,
  375. **kwargs,
  376. ) -> tuple[torch.Tensor] | RetrievAugLMOutput:
  377. r"""
  378. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  379. Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
  380. which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
  381. obtain the indices.
  382. [What are input IDs?](../glossary#input-ids)
  383. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
  384. Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
  385. *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
  386. sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
  387. generator's encoder.
  388. Used by the ([`RagModel`]) model during decoding.
  389. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  390. Provide for generation tasks. `None` by default, construct as per instructions for the generator model
  391. you're using with your RAG instance.
  392. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  393. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  394. be used by default.
  395. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  396. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  397. `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
  398. has to be provided to the forward pass. `doc_scores` can be computed via
  399. `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
  400. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  401. Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
  402. retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
  403. the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
  404. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
  405. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  406. retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
  407. provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
  408. output_retrieved (`bool`, *optional*):
  409. Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
  410. `context_attention_mask`. See returned tensors for more detail.
  411. n_docs (`int`, *optional*):
  412. The number of documents to retrieve.
  413. Example:
  414. ```python
  415. >>> from transformers import AutoTokenizer, RagRetriever, RagModel
  416. >>> import torch
  417. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base")
  418. >>> retriever = RagRetriever.from_pretrained(
  419. ... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
  420. ... )
  421. >>> # initialize with RagRetriever to do everything in one forward call
  422. >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)
  423. >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
  424. >>> outputs = model(input_ids=inputs["input_ids"])
  425. ```"""
  426. n_docs = n_docs if n_docs is not None else self.config.n_docs
  427. use_cache = use_cache if use_cache is not None else self.config.use_cache
  428. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  429. output_hidden_states = (
  430. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  431. )
  432. output_retrieved = output_retrieved if output_retrieved is not None else self.config.output_retrieved
  433. # whether retriever has to be used
  434. has_to_retrieve = (
  435. self.retriever is not None
  436. and (context_input_ids is None or context_attention_mask is None or doc_scores is None)
  437. and encoder_outputs is None
  438. )
  439. # encoder_outputs are pre-computed during RAG-token generation
  440. if encoder_outputs is None:
  441. if has_to_retrieve:
  442. question_enc_outputs = self.question_encoder(
  443. input_ids, attention_mask=attention_mask, return_dict=True
  444. )
  445. question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder
  446. retriever_outputs = self.retriever(
  447. input_ids,
  448. question_encoder_last_hidden_state.detach().to(device="cpu", dtype=torch.float32).numpy(),
  449. prefix=getattr(self.generator.config, "prefix", None),
  450. n_docs=n_docs,
  451. return_tensors="pt",
  452. )
  453. if self.context_encoder_training:
  454. (
  455. context_input_ids,
  456. context_attention_mask,
  457. retrieved_doc_embeds,
  458. retrieved_doc_input_ids,
  459. retrieved_doc_attention_mask,
  460. retrieved_doc_ids,
  461. ) = (
  462. retriever_outputs["context_input_ids"],
  463. retriever_outputs["context_attention_mask"],
  464. retriever_outputs["retrieved_doc_embeds"],
  465. retriever_outputs["tokenized_doc_ids"],
  466. retriever_outputs["tokenized_doc_attention_mask"],
  467. retriever_outputs["doc_ids"],
  468. )
  469. context_input_ids = context_input_ids.to(input_ids)
  470. context_attention_mask = context_attention_mask.to(input_ids)
  471. retrieved_doc_input_ids = retrieved_doc_input_ids.to(input_ids)
  472. retrieved_doc_attention_mask = retrieved_doc_attention_mask.to(input_ids)
  473. retrieved_doc_embeds = self.ctx_encoder(
  474. retrieved_doc_input_ids, attention_mask=retrieved_doc_attention_mask, return_dict=True
  475. ).pooler_output
  476. retrieved_doc_embeds = retrieved_doc_embeds.view(
  477. -1, n_docs, question_encoder_last_hidden_state.shape[1]
  478. ) # reshaping
  479. # compute doc_scores involving ctx_encoder
  480. doc_scores = torch.bmm(
  481. question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
  482. ).squeeze(1)
  483. else:
  484. context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
  485. retriever_outputs["context_input_ids"],
  486. retriever_outputs["context_attention_mask"],
  487. retriever_outputs["retrieved_doc_embeds"],
  488. retriever_outputs["doc_ids"],
  489. )
  490. # set to correct device
  491. retrieved_doc_embeds = retrieved_doc_embeds.to(question_encoder_last_hidden_state)
  492. context_input_ids = context_input_ids.to(input_ids)
  493. context_attention_mask = context_attention_mask.to(input_ids)
  494. # compute doc_scores
  495. doc_scores = torch.bmm(
  496. question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
  497. ).squeeze(1)
  498. else:
  499. assert context_input_ids is not None, (
  500. "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
  501. " set a retriever using the `set_retriever(...)` function."
  502. )
  503. assert context_attention_mask is not None, (
  504. "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
  505. " can set a retriever using the `set_retriever(...)` function."
  506. )
  507. assert doc_scores is not None, (
  508. "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
  509. " retriever using the `set_retriever(...)` function."
  510. )
  511. assert doc_scores is not None, (
  512. "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
  513. )
  514. assert (doc_scores.shape[1] % n_docs) == 0, (
  515. f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
  516. f" {context_input_ids.shape[0]}."
  517. )
  518. # Decoder input without context documents
  519. if decoder_input_ids is not None:
  520. decoder_input_ids = decoder_input_ids.repeat_interleave(n_docs, dim=0)
  521. if decoder_attention_mask is not None:
  522. decoder_attention_mask = decoder_attention_mask.repeat_interleave(n_docs, dim=0)
  523. gen_outputs = self.generator(
  524. input_ids=context_input_ids,
  525. attention_mask=context_attention_mask,
  526. encoder_outputs=encoder_outputs,
  527. decoder_input_ids=decoder_input_ids,
  528. decoder_attention_mask=decoder_attention_mask,
  529. past_key_values=past_key_values,
  530. use_cache=use_cache,
  531. output_attentions=output_attentions,
  532. return_dict=True,
  533. )
  534. if not has_to_retrieve:
  535. question_encoder_last_hidden_state = None
  536. question_enc_hidden_states = None
  537. question_enc_attentions = None
  538. retrieved_doc_embeds = None
  539. retrieved_doc_ids = None
  540. else:
  541. question_enc_hidden_states = question_enc_outputs.hidden_states
  542. question_enc_attentions = question_enc_outputs.attentions
  543. if not has_to_retrieve or not output_retrieved:
  544. # don't output retrieved docs
  545. context_input_ids = (None,)
  546. context_attention_mask = None
  547. retrieved_doc_embeds = None
  548. retrieved_doc_ids = None
  549. return RetrievAugLMOutput(
  550. logits=gen_outputs.logits,
  551. doc_scores=doc_scores,
  552. past_key_values=gen_outputs.past_key_values,
  553. context_input_ids=context_input_ids,
  554. context_attention_mask=context_attention_mask,
  555. retrieved_doc_embeds=retrieved_doc_embeds,
  556. retrieved_doc_ids=retrieved_doc_ids,
  557. question_encoder_last_hidden_state=question_encoder_last_hidden_state,
  558. question_enc_hidden_states=question_enc_hidden_states,
  559. question_enc_attentions=question_enc_attentions,
  560. generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state,
  561. generator_enc_hidden_states=gen_outputs.encoder_hidden_states,
  562. generator_enc_attentions=gen_outputs.encoder_attentions,
  563. generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
  564. generator_dec_attentions=gen_outputs.decoder_attentions,
  565. generator_cross_attentions=gen_outputs.cross_attentions,
  566. )
  567. @auto_docstring(
  568. custom_intro="""
  569. A RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
  570. """
  571. )
  572. class RagSequenceForGeneration(RagPreTrainedModel):
  573. def __init__(
  574. self,
  575. config: PreTrainedConfig | None = None,
  576. question_encoder: PreTrainedModel | None = None,
  577. generator: PreTrainedModel | None = None,
  578. retriever: RagRetriever | None = None,
  579. **kwargs,
  580. ):
  581. r"""
  582. question_encoder (`PreTrainedModel`, *optional*):
  583. The model responsible for encoding the question into hidden states for retrieval.
  584. generator (`PreTrainedModel`, *optional*):
  585. The model responsible for generating text based on retrieved documents.
  586. retriever (`RagRetriever`, *optional*):
  587. The component responsible for retrieving documents from a knowledge base given the encoded question.
  588. """
  589. assert config is not None or (question_encoder is not None and generator is not None), (
  590. "Either a configuration or an encoder and a generator has to be provided."
  591. )
  592. if config is None:
  593. config = RagConfig.from_question_encoder_generator_configs(
  594. question_encoder.config, generator.config, **kwargs
  595. )
  596. super().__init__(config)
  597. # instantiate model
  598. self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)
  599. self.post_init()
  600. def set_retriever(self, retriever: RagRetriever):
  601. self.rag.retriever = retriever
  602. def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
  603. self.rag.context_encoder_training = True
  604. self.rag.ctx_encoder = ctx_encoder
  605. @auto_docstring
  606. def forward(
  607. self,
  608. input_ids: torch.LongTensor | None = None,
  609. attention_mask: torch.Tensor | None = None,
  610. encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
  611. decoder_input_ids: torch.LongTensor | None = None,
  612. decoder_attention_mask: torch.BoolTensor | None = None,
  613. past_key_values: Cache | None = None,
  614. context_input_ids: torch.LongTensor | None = None,
  615. context_attention_mask: torch.LongTensor | None = None,
  616. doc_scores: torch.FloatTensor | None = None,
  617. use_cache: bool | None = None,
  618. output_attentions: bool | None = None,
  619. output_hidden_states: bool | None = None,
  620. output_retrieved: bool | None = None,
  621. exclude_bos_score: bool | None = None,
  622. reduce_loss: bool | None = None,
  623. labels: torch.LongTensor | None = None,
  624. n_docs: int | None = None,
  625. **kwargs, # needs kwargs for generation
  626. ) -> RetrievAugLMMarginOutput:
  627. r"""
  628. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  629. Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
  630. which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
  631. obtain the indices.
  632. [What are input IDs?](../glossary#input-ids)
  633. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
  634. Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
  635. *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
  636. sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
  637. generator's encoder.
  638. Used by the ([`RagModel`]) model during decoding.
  639. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  640. Provide for generation tasks. `None` by default, construct as per instructions for the generator model
  641. you're using with your RAG instance.
  642. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  643. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  644. be used by default.
  645. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  646. Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
  647. retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
  648. the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
  649. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
  650. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  651. retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
  652. provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
  653. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  654. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  655. `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
  656. has to be provided to the forward pass. `doc_scores` can be computed via
  657. `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
  658. output_retrieved (`bool`, *optional*):
  659. Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
  660. `context_attention_mask`. See returned tensors for more detail.
  661. exclude_bos_score (`bool`, *optional*):
  662. Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
  663. the loss.
  664. reduce_loss (`bool`, *optional*):
  665. Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
  666. operation.
  667. n_docs (`int`, *optional*):
  668. The number of documents to retrieve.
  669. Example:
  670. ```python
  671. >>> from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration
  672. >>> import torch
  673. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
  674. >>> retriever = RagRetriever.from_pretrained(
  675. ... "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
  676. ... )
  677. >>> # initialize with RagRetriever to do everything in one forward call
  678. >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
  679. >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
  680. >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
  681. >>> input_ids = inputs["input_ids"]
  682. >>> labels = targets["input_ids"]
  683. >>> outputs = model(input_ids=input_ids, labels=labels)
  684. >>> # or use retriever separately
  685. >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
  686. >>> # 1. Encode
  687. >>> question_hidden_states = model.question_encoder(input_ids)[0]
  688. >>> # 2. Retrieve
  689. >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
  690. >>> doc_scores = torch.bmm(
  691. ... question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
  692. ... ).squeeze(1)
  693. >>> # 3. Forward to generator
  694. >>> outputs = model(
  695. ... context_input_ids=docs_dict["context_input_ids"],
  696. ... context_attention_mask=docs_dict["context_attention_mask"],
  697. ... doc_scores=doc_scores,
  698. ... decoder_input_ids=labels,
  699. ... )
  700. ```"""
  701. n_docs = n_docs if n_docs is not None else self.config.n_docs
  702. exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
  703. reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss
  704. if labels is not None:
  705. if decoder_input_ids is None:
  706. decoder_input_ids = labels
  707. use_cache = False
  708. outputs = self.rag(
  709. input_ids=input_ids,
  710. attention_mask=attention_mask,
  711. encoder_outputs=encoder_outputs,
  712. decoder_input_ids=decoder_input_ids,
  713. decoder_attention_mask=decoder_attention_mask,
  714. context_input_ids=context_input_ids,
  715. context_attention_mask=context_attention_mask,
  716. doc_scores=doc_scores,
  717. past_key_values=past_key_values,
  718. use_cache=use_cache,
  719. output_attentions=output_attentions,
  720. output_hidden_states=output_hidden_states,
  721. output_retrieved=output_retrieved,
  722. n_docs=n_docs,
  723. )
  724. loss = None
  725. if labels is not None:
  726. loss = self.get_nll(
  727. outputs.logits,
  728. outputs.doc_scores,
  729. decoder_input_ids,
  730. reduce_loss=reduce_loss,
  731. epsilon=self.config.label_smoothing,
  732. exclude_bos_score=exclude_bos_score,
  733. n_docs=n_docs,
  734. )
  735. return RetrievAugLMMarginOutput(
  736. loss=loss,
  737. logits=outputs.logits,
  738. doc_scores=outputs.doc_scores,
  739. past_key_values=outputs.past_key_values,
  740. context_input_ids=outputs.context_input_ids,
  741. context_attention_mask=outputs.context_attention_mask,
  742. retrieved_doc_embeds=outputs.retrieved_doc_embeds,
  743. retrieved_doc_ids=outputs.retrieved_doc_ids,
  744. question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
  745. question_enc_hidden_states=outputs.question_enc_hidden_states,
  746. question_enc_attentions=outputs.question_enc_attentions,
  747. generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
  748. generator_enc_hidden_states=outputs.generator_enc_hidden_states,
  749. generator_enc_attentions=outputs.generator_enc_attentions,
  750. generator_dec_hidden_states=outputs.generator_dec_hidden_states,
  751. generator_dec_attentions=outputs.generator_dec_attentions,
  752. generator_cross_attentions=outputs.generator_cross_attentions,
  753. )
  754. @property
  755. def retriever(self):
  756. return self.rag.retriever
  757. @property
  758. def generator(self):
  759. return self.rag.generator
  760. @property
  761. def question_encoder(self):
  762. return self.rag.question_encoder
  763. @torch.no_grad()
  764. def generate(
  765. self,
  766. input_ids: torch.LongTensor | None = None,
  767. attention_mask: torch.LongTensor | None = None,
  768. context_input_ids: torch.LongTensor | None = None,
  769. context_attention_mask: torch.LongTensor | None = None,
  770. doc_scores: torch.FloatTensor | None = None,
  771. do_deduplication: bool | None = None, # defaults to True
  772. num_return_sequences: int | None = None, # defaults to 1
  773. num_beams: int | None = None, # defaults to 1
  774. n_docs: int | None = None,
  775. **model_kwargs,
  776. ) -> torch.LongTensor:
  777. """
  778. Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
  779. for more information on how to set other generate input parameters.
  780. Args:
  781. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  782. The sequence used as a prompt for the generation. If `input_ids` is not passed, then
  783. `context_input_ids` has to be provided.
  784. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  785. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  786. - 1 for tokens that are **not masked**,
  787. - 0 for tokens that are **masked**.
  788. [What are attention masks?](../glossary#attention-mask)
  789. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  790. Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
  791. retriever.
  792. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  793. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  794. retriever.
  795. If the model is not initialized with a `retriever` or `input_ids` is not given, `context_input_ids` and
  796. `context_attention_mask` have to be provided to the forward pass. They are returned by
  797. [`~RagRetriever.__call__`].
  798. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  799. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  800. `question_encoder_last_hidden_state`.
  801. If the model is not initialized with a `retriever` or `input_ids` is not given, `doc_scores` has to be
  802. provided to the forward pass. `doc_scores` are returned by [`~RagRetriever.__call__`].
  803. do_deduplication (`bool`, *optional*):
  804. Whether or not to deduplicate the generations from different context documents for a given input. Has
  805. to be set to `False` if used while training with distributed backend.
  806. num_return_sequences(`int`, *optional*, defaults to 1):
  807. The number of independently computed returned sequences for each element in the batch. Note that this
  808. is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
  809. where we set `num_return_sequences` to `num_beams`.
  810. num_beams (`int`, *optional*, defaults to 1):
  811. Number of beams for beam search. 1 means no beam search.
  812. n_docs (`int`, *optional*, defaults to `config.n_docs`)
  813. Number of documents to retrieve and/or number of documents for which to generate an answer.
  814. kwargs (`dict[str, Any]`, *optional*):
  815. Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].
  816. Return:
  817. `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
  818. sequences. The second dimension (sequence length) is either equal to `max_length` or shorter if all batches
  819. finished early due to the `eos_token_id`.
  820. """
  821. n_docs = n_docs if n_docs is not None else self.config.n_docs
  822. do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication
  823. num_doc_return_sequences = (
  824. num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
  825. )
  826. num_beams = num_beams if num_beams is not None else self.config.num_beams
  827. assert input_ids is not None or context_input_ids is not None, (
  828. " At least one of input_ids or context_input_ids must be given"
  829. )
  830. if self.retriever is not None and context_input_ids is None:
  831. question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
  832. context_input_ids = self.retriever(
  833. input_ids,
  834. question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
  835. prefix=getattr(self.generator.config, "prefix", None),
  836. n_docs=n_docs,
  837. return_tensors="pt",
  838. )["context_input_ids"]
  839. # set to correct device
  840. context_input_ids = context_input_ids.to(input_ids)
  841. hypos = []
  842. model_kwargs["num_beams"] = num_beams
  843. model_kwargs["num_return_sequences"] = num_beams
  844. model_kwargs["attention_mask"] = None
  845. batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs
  846. for index in range(batch_size):
  847. # first, generate beams from documents:
  848. generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs] # (n_docs, max_len)
  849. output_sequences = self.generator.generate(
  850. generator_input_ids,
  851. **model_kwargs,
  852. ) # n_docs * n_beam, tgt_len
  853. if do_deduplication:
  854. # do_deduplication, max_output_len
  855. output_sequences = torch.stack(list({str(k.tolist()): k for k in output_sequences}.values()))
  856. num_candidates = output_sequences.shape[
  857. 0
  858. ] # after deduplication, this number can be less than n_docs*n_beam
  859. # then, run model forwards to get nll scores:
  860. if input_ids is not None:
  861. new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
  862. outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
  863. else: # input_ids is None, need context_input_ids/mask and doc_scores
  864. assert context_attention_mask is not None, (
  865. "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
  866. " can set a retriever using the `set_retriever(...)` function."
  867. )
  868. assert doc_scores is not None, (
  869. "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
  870. " retriever using the `set_retriever(...)` function."
  871. )
  872. individual_input_ids = generator_input_ids.repeat(
  873. num_candidates, 1
  874. ) # (num_candidates*n_docs, max_len)
  875. individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs]
  876. individual_attention_mask = individual_attention_mask.repeat(num_candidates, 1)
  877. individual_doc_scores = doc_scores[index : (index + 1), :] # doc_scores.shape = [batch, n_docs]
  878. individual_doc_scores = individual_doc_scores.repeat(num_candidates, 1) # [num_candidates, n_docs]
  879. outputs = self(
  880. context_input_ids=individual_input_ids,
  881. context_attention_mask=individual_attention_mask,
  882. doc_scores=individual_doc_scores,
  883. labels=output_sequences,
  884. exclude_bos_score=True,
  885. )
  886. top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
  887. # add hypothesis
  888. hypos.append(output_sequences[top_cand_inds])
  889. return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id)
  890. def get_nll(
  891. self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
  892. ):
  893. # shift tokens left
  894. target = torch.cat(
  895. [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
  896. )
  897. n_docs = n_docs if n_docs is not None else self.config.n_docs
  898. # bos_token_id is None for T5
  899. bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
  900. use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()
  901. def _mask_pads(ll, smooth_obj):
  902. pad_mask = target.eq(self.config.generator.pad_token_id)
  903. if pad_mask.any():
  904. ll.masked_fill_(pad_mask, 0.0)
  905. smooth_obj.masked_fill_(pad_mask, 0.0)
  906. return ll.squeeze(-1), smooth_obj.squeeze(-1)
  907. # seq_logits dim = (batch*n_docs, tgt_len , #vocabs)
  908. seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
  909. seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
  910. ) # batch_size x n_docs x tgt_len x #vocab_size
  911. doc_logprobs = nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)
  912. # RAG-sequence marginalization
  913. first_token_scores = seq_logprobs[:, :, :1, :]
  914. second_token_scores = seq_logprobs[:, :, 1:2, :]
  915. remainder = seq_logprobs[:, :, 2:, :]
  916. rag_logprobs = torch.cat([first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2)
  917. # calculate loss
  918. target = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
  919. assert target.dim() == rag_logprobs.dim()
  920. ll = rag_logprobs.gather(dim=-1, index=target)
  921. smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True) # total sum of all (normalised) logits
  922. ll, smooth_obj = _mask_pads(ll, smooth_obj)
  923. # sum over tokens, exclude bos while scoring
  924. ll = ll[:, :, 1:].sum(2) if exclude_bos_score and use_bos else ll.sum(2)
  925. smooth_obj = smooth_obj.sum(2)
  926. ll = ll.logsumexp(1) # logsumexp over docs
  927. smooth_obj = smooth_obj.logsumexp(1)
  928. nll_loss = -ll
  929. smooth_loss = -smooth_obj
  930. if reduce_loss:
  931. nll_loss = nll_loss.sum()
  932. smooth_loss = smooth_loss.sum()
  933. eps_i = epsilon / rag_logprobs.size(-1)
  934. loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
  935. return loss
  936. @staticmethod
  937. def _cat_and_pad(tensors, pad_token_id):
  938. output = tensors[0].new(sum(t.shape[0] for t in tensors), max(t.shape[1] for t in tensors)).fill_(pad_token_id)
  939. ind = 0
  940. for t in tensors:
  941. output[ind : ind + t.shape[0], : t.shape[1]] = t
  942. ind += t.shape[0]
  943. return output
  944. @auto_docstring(
  945. custom_intro="""
  946. A RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
  947. """
  948. )
  949. class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
  950. def __init__(
  951. self,
  952. config: PreTrainedConfig | None = None,
  953. question_encoder: PreTrainedModel | None = None,
  954. generator: PreTrainedModel | None = None,
  955. retriever: RagRetriever | None = None,
  956. **kwargs,
  957. ):
  958. r"""
  959. question_encoder (`PreTrainedModel`, *optional*):
  960. The model responsible for encoding the question into hidden states for retrieval.
  961. generator (`PreTrainedModel`, *optional*):
  962. The model responsible for generating text based on retrieved documents.
  963. retriever (`RagRetriever`, *optional*):
  964. The component responsible for retrieving documents from a knowledge base given the encoded question.
  965. """
  966. assert config is not None or (question_encoder is not None and generator is not None), (
  967. "Either a configuration or an encoder and a generator has to be provided."
  968. )
  969. if config is None:
  970. config = RagConfig.from_question_encoder_generator_configs(
  971. question_encoder.config, generator.config, **kwargs
  972. )
  973. super().__init__(config)
  974. # instantiate model
  975. self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)
  976. self.post_init()
  977. def set_retriever(self, retriever: RagRetriever):
  978. self.rag.retriever = retriever
  979. def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
  980. self.rag.context_encoder_training = True
  981. self.rag.ctx_encoder = ctx_encoder
  def prepare_inputs_for_generation(
      self,
      decoder_input_ids,
      past_key_values=None,
      attention_mask=None,
      use_cache=None,
      encoder_outputs=None,
      doc_scores=None,
      n_docs=None,
      **kwargs,
  ):
      """
      Build the keyword arguments for one decoding step of `forward`.

      Overwritten -- `do_marginalize` is explicitly set in the output. `attention_mask` is forwarded as
      `context_attention_mask`, `input_ids` is always `None`, and extra keyword arguments are ignored.
      """
      # Once a cache exists, only the newest decoder token is fed to the model.
      step_ids = decoder_input_ids if past_key_values is None else decoder_input_ids[:, -1:]
      step_inputs = dict(
          input_ids=None,
          encoder_outputs=encoder_outputs,
          doc_scores=doc_scores,
          context_attention_mask=attention_mask,
          decoder_input_ids=step_ids,
          past_key_values=past_key_values,
          use_cache=use_cache,
          do_marginalize=True,
          n_docs=n_docs,
      )
      return step_inputs
  1008. @property
  1009. def retriever(self):
  1010. return self.rag.retriever
  1011. @property
  1012. def generator(self):
  1013. return self.rag.generator
  1014. @property
  1015. def question_encoder(self):
  1016. return self.rag.question_encoder
  @staticmethod
  def _reorder_cache(past_key_values, beam_idx):
      """
      Reorder a generation cache so it follows the beams selected at the current step.

      BART-inspired, but each cached tensor stacks several document rows per beam along its first
      dimension. Those rows are therefore moved as whole groups.

      Args:
          past_key_values: an `EncoderDecoderCache` (self- and cross-attention layers) or a cache that exposes
              `.layers[i].keys` / `.layers[i].values`.
          beam_idx (`torch.LongTensor`): for every new beam position, the index of the beam it is taken from.

      Returns:
          A new object of the same type as `past_key_values`. It is built from one tuple of reordered tensors
          per layer: 4 tensors (self k/v, cross k/v) for `EncoderDecoderCache`, otherwise 2 (k/v).
      """

      def _select_doc_groups(states, order):
          # View the first dim as (groups, docs_per_group), pick whole groups by `order`, then flatten again.
          docs_per_group = states.shape[0] // order.shape[0]
          grouped = states.view(-1, docs_per_group, *states.shape[1:])
          picked = grouped.index_select(0, order)
          return picked.view(-1, *picked.shape[2:])

      # The cache type does not change between layers, so decide the layout once.
      is_encoder_decoder = isinstance(past_key_values, EncoderDecoderCache)

      def _layer_tensors(layer_idx):
          if is_encoder_decoder:
              self_layer = past_key_values.self_attention_cache.layers[layer_idx]
              cross_layer = past_key_values.cross_attention_cache.layers[layer_idx]
              return (self_layer.keys, self_layer.values, cross_layer.keys, cross_layer.values)
          layer = past_key_values.layers[layer_idx]
          return (layer.keys, layer.values)

      reordered_layers = tuple(
          tuple(_select_doc_groups(tensor, beam_idx.to(tensor.device)) for tensor in _layer_tensors(layer_idx))
          for layer_idx in range(len(past_key_values))
      )
      return type(past_key_values)(reordered_layers)
  def marginalize(self, seq_logits, doc_scores, n_docs=None):
      """
      RAG-token marginalization: mix the per-document token distributions into one distribution.

      Args:
          seq_logits (`torch.Tensor`): generator logits whose first dimension stacks `n_docs` rows per example.
              After a log-softmax over the last dim they are viewed as `(batch, n_docs, L, vocab)`, where `L`
              gathers the remaining (sequence) positions.
          doc_scores (`torch.Tensor`): retrieval scores, log-softmax-normalized over dim 1 (the document axis).
          n_docs (`int`, *optional*): documents per example; defaults to `self.config.n_docs`.

      Returns:
          `torch.Tensor` of shape `(batch, L, vocab)` holding `log sum_d p(d) * p(token | d)`.
      """
      if n_docs is None:
          n_docs = self.config.n_docs
      vocab_size = seq_logits.shape[-1]
      batch_size = seq_logits.shape[0] // n_docs
      token_logprobs = seq_logits.log_softmax(dim=-1).view(batch_size, n_docs, -1, vocab_size)
      # Append two singleton axes so the document prior broadcasts over positions and vocabulary.
      doc_logprobs = doc_scores.log_softmax(dim=1)[..., None, None]
      # Mixture in log space over the document axis.
      return torch.logsumexp(token_logprobs + doc_logprobs, dim=1)
  @auto_docstring
  def forward(
      self,
      input_ids: torch.LongTensor | None = None,
      attention_mask: torch.FloatTensor | None = None,
      encoder_outputs: tuple[tuple[torch.Tensor]] | None = None,
      decoder_input_ids: torch.LongTensor | None = None,
      decoder_attention_mask: torch.BoolTensor | None = None,
      past_key_values: Cache | None = None,
      context_input_ids: torch.LongTensor | None = None,
      context_attention_mask: torch.LongTensor | None = None,
      doc_scores: torch.FloatTensor | None = None,
      use_cache: bool | None = None,
      output_attentions: bool | None = None,
      output_hidden_states: bool | None = None,
      output_retrieved: bool | None = None,
      do_marginalize: bool | None = None,
      reduce_loss: bool | None = None,
      labels: torch.LongTensor | None = None,
      n_docs: int | None = None,
      **kwargs,  # needs kwargs for generation
  ) -> RetrievAugLMMarginOutput:
      r"""
      input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
          Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
          which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
          obtain the indices.

          [What are input IDs?](../glossary#input-ids)
      encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
          Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
          *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size * n_docs,
          sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
          generator's encoder.

          Used by the ([`RagModel`]) model during decoding.
      decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
          Provide for generation tasks. `None` by default, construct as per instructions for the generator model
          you're using with your RAG instance. When `labels` are passed and `decoder_input_ids` is `None`, the
          labels are used as decoder inputs.
      decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
          Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
          be used by default.
      context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
          Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
          retriever. If the model was not initialized with a `retriever`, `context_input_ids` has to be provided to
          the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
      context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
          Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
          retriever. If the model is not initialized with a `retriever`, `context_attention_mask` has to be
          provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
      doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
          Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
          `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
          has to be provided to the forward pass. `doc_scores` can be computed via
          `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
      output_retrieved (`bool`, *optional*):
          Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
          `context_attention_mask`. See returned tensors for more detail.
      do_marginalize (`bool`, *optional*):
          If `True`, the logits are marginalized over all documents by making use of
          `torch.nn.functional.log_softmax`.
      reduce_loss (`bool`, *optional*):
          Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
          operation.
      n_docs (`int`, *optional*):
          The number of documents to retrieve.

      Example:

      ```python
      >>> from transformers import AutoTokenizer, RagRetriever, RagTokenForGeneration
      >>> import torch

      >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
      >>> retriever = RagRetriever.from_pretrained(
      ...     "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
      ... )
      >>> # initialize with RagRetriever to do everything in one forward call
      >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

      >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
      >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
      >>> input_ids = inputs["input_ids"]
      >>> labels = targets["input_ids"]
      >>> outputs = model(input_ids=input_ids, labels=labels)

      >>> # or use retriever separately
      >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
      >>> # 1. Encode
      >>> question_hidden_states = model.question_encoder(input_ids)[0]
      >>> # 2. Retrieve
      >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
      >>> doc_scores = torch.bmm(
      ...     question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
      ... ).squeeze(1)
      >>> # 3. Forward to generator
      >>> outputs = model(
      ...     context_input_ids=docs_dict["context_input_ids"],
      ...     context_attention_mask=docs_dict["context_attention_mask"],
      ...     doc_scores=doc_scores,
      ...     decoder_input_ids=labels,
      ... )

      >>> # or directly generate
      >>> generated = model.generate(
      ...     context_input_ids=docs_dict["context_input_ids"],
      ...     context_attention_mask=docs_dict["context_attention_mask"],
      ...     doc_scores=doc_scores,
      ... )
      >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
      ```"""
      # Resolve per-call overrides against the model config.
      if n_docs is None:
          n_docs = self.config.n_docs
      if do_marginalize is None:
          do_marginalize = self.config.do_marginalize
      if reduce_loss is None:
          reduce_loss = self.config.reduce_loss

      has_labels = labels is not None
      if has_labels:
          # Labels double as decoder inputs unless explicit ones were given; caching is switched off on this path.
          decoder_input_ids = labels if decoder_input_ids is None else decoder_input_ids
          use_cache = False

      rag_out = self.rag(
          input_ids=input_ids,
          attention_mask=attention_mask,
          encoder_outputs=encoder_outputs,
          decoder_input_ids=decoder_input_ids,
          decoder_attention_mask=decoder_attention_mask,
          context_input_ids=context_input_ids,
          context_attention_mask=context_attention_mask,
          doc_scores=doc_scores,
          past_key_values=past_key_values,
          use_cache=use_cache,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          output_retrieved=output_retrieved,
          n_docs=n_docs,
      )

      # The loss is computed from the raw per-document logits; `get_nll` marginalizes them itself.
      loss = None
      if has_labels:
          assert decoder_input_ids is not None
          loss = self.get_nll(
              rag_out.logits,
              rag_out.doc_scores,
              labels,
              reduce_loss=reduce_loss,
              epsilon=self.config.label_smoothing,
              n_docs=n_docs,
          )

      if do_marginalize:
          logits = self.marginalize(rag_out.logits, rag_out.doc_scores, n_docs)
      else:
          logits = rag_out.logits

      return RetrievAugLMMarginOutput(
          loss=loss,
          logits=logits,
          doc_scores=rag_out.doc_scores,
          past_key_values=rag_out.past_key_values,
          context_input_ids=rag_out.context_input_ids,
          context_attention_mask=rag_out.context_attention_mask,
          retrieved_doc_embeds=rag_out.retrieved_doc_embeds,
          retrieved_doc_ids=rag_out.retrieved_doc_ids,
          question_encoder_last_hidden_state=rag_out.question_encoder_last_hidden_state,
          question_enc_hidden_states=rag_out.question_enc_hidden_states,
          question_enc_attentions=rag_out.question_enc_attentions,
          generator_enc_last_hidden_state=rag_out.generator_enc_last_hidden_state,
          generator_enc_hidden_states=rag_out.generator_enc_hidden_states,
          generator_enc_attentions=rag_out.generator_enc_attentions,
          generator_dec_hidden_states=rag_out.generator_dec_hidden_states,
          generator_dec_attentions=rag_out.generator_dec_attentions,
          generator_cross_attentions=rag_out.generator_cross_attentions,
      )
  @torch.no_grad()
  def generate(
      self,
      input_ids: torch.LongTensor | None = None,
      attention_mask: torch.LongTensor | None = None,
      context_input_ids: torch.LongTensor | None = None,
      context_attention_mask: torch.LongTensor | None = None,
      doc_scores: torch.FloatTensor | None = None,
      n_docs: int | None = None,
      generation_config: GenerationConfig | None = None,
      prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
      logits_processor: LogitsProcessorList | None = None,
      stopping_criteria: StoppingCriteriaList | None = None,
      **kwargs,
  ) -> torch.LongTensor:
      """
      Implements RAG token decoding.

      Args:
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              The sequence used as a prompt for the generation. If `input_ids` is not passed, then
              `context_input_ids` has to be provided.
          attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
              Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

              - 1 for tokens that are **not masked**,
              - 0 for tokens that are **masked**.

              [What are attention masks?](../glossary#attention-mask)
          context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
              Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
              retriever.

              If the model is not initialized with a `retriever`, `context_input_ids` has to be provided.
              `context_input_ids` are returned by [`~RagRetriever.__call__`].
          context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
              Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by
              the retriever.

              If the model is not initialized with a `retriever`, `context_attention_mask` has to be provided.
              `context_attention_mask` is returned by [`~RagRetriever.__call__`].
          doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
              Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
              `question_encoder_last_hidden_state`.

              Retrieval (and therefore the computation of `doc_scores`) only runs when a `retriever` is attached
              and `context_input_ids` is not passed. Whenever `context_input_ids` is passed explicitly,
              `doc_scores` has to be passed as well. It can be computed from
              `question_encoder_last_hidden_state` and `retrieved_doc_embeds`.
          n_docs (`int`, *optional*, defaults to `config.n_docs`):
              Number of documents to retrieve and/or number of documents for which to generate an answer.
          generation_config (`~generation.GenerationConfig`, *optional*):
              The generation configuration to be used as base parametrization for the generation call. `**kwargs`
              passed to generate matching the attributes of `generation_config` will override them. If
              `generation_config` is not provided, the default will be used, which has the following loading
              priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
              configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
              default values, whose documentation should be checked to parameterize generation.
          prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
              If provided, this function constrains the beam search to allowed tokens only at each step. If not
              provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID
              `batch_id`. It has to return a list with the allowed tokens for the next generation step conditioned on
              the previously generated tokens `inputs_ids` and the batch ID `batch_id`. This argument is useful for
              constrained generation conditioned on the prefix, as described in [Autoregressive Entity
              Retrieval](https://huggingface.co/papers/2010.00904).
          logits_processor (`LogitsProcessorList`, *optional*):
              Custom logits processors that complement the default logits processors built from arguments and a
              model's config. If a logit processor is passed that is already created with the arguments or a model's
              config an error is thrown. `None` (the default) means no custom processors.
          stopping_criteria (`StoppingCriteriaList`, *optional*):
              Custom stopping criteria that complement the default stopping criteria built from arguments and a
              model's config. If a stopping criteria is passed that is already created with the arguments or a
              model's config an error is thrown. `None` (the default) means no custom criteria.
          kwargs (`dict[str, Any]`, *optional*):
              Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
              forwarded to the `forward` function of the model.

      Return:
          `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
          sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
          finished early due to the `eos_token_id`.

      Raises:
          `ValueError`: if the generation mode is not one of sample / greedy / beam search / beam sample, if no
          `context_input_ids` are available (no retriever attached and none passed), or if `doc_scores` is
          missing while `context_input_ids` was passed explicitly.
      """
      # Fresh containers on every call: a default-argument instance would be shared across calls, and
      # normalizing here also makes the explicit `None` allowed by the annotations safe.
      if logits_processor is None:
          logits_processor = LogitsProcessorList()
      if stopping_criteria is None:
          stopping_criteria = StoppingCriteriaList()

      # Handle `generation_config` and kwargs that might update it
      generation_mode_kwargs = self._extract_generation_mode_kwargs(None, kwargs, False, None, None)
      generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
      generation_mode = generation_config.get_generation_mode()
      if generation_mode not in [
          GenerationMode.SAMPLE,
          GenerationMode.GREEDY_SEARCH,
          GenerationMode.BEAM_SEARCH,
          GenerationMode.BEAM_SAMPLE,
      ]:
          raise ValueError(
              f"RAG model is not compatible with {generation_mode} generation. Please check your generation parameters."
          )
      # type() required to access the unbound class-level method
      decoding_method = getattr(type(self), GENERATION_MODES_MAPPING[generation_mode])

      self._validate_model_kwargs(model_kwargs.copy())
      self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)

      kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
      self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)

      # set default parameters
      n_docs = n_docs if n_docs is not None else self.config.n_docs

      # retrieve docs
      if self.retriever is not None and context_input_ids is None:
          question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
          out = self.retriever(
              input_ids,
              question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
              prefix=getattr(self.generator.config, "prefix", None),
              n_docs=n_docs,
              return_tensors="pt",
          )
          context_input_ids, context_attention_mask, retrieved_doc_embeds = (
              out["context_input_ids"],
              out["context_attention_mask"],
              out["retrieved_doc_embeds"],
          )

          # set to correct device
          retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
          context_input_ids = context_input_ids.to(input_ids)
          context_attention_mask = context_attention_mask.to(input_ids)

          # compute doc_scores
          doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
              1
          )

      # Without retrieval nothing above fills these in. Fail early with a clear message instead of an
      # AttributeError on `None` after the encoder has already run.
      if context_input_ids is None:
          raise ValueError(
              "`context_input_ids` has to be provided when the model is not initialized with a `retriever`."
          )
      if doc_scores is None:
          raise ValueError(
              "`doc_scores` has to be provided whenever `context_input_ids` is passed to `generate` explicitly."
          )

      assert (context_input_ids.shape[0] % n_docs) == 0, (
          f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
          f" {context_input_ids.shape[0]}."
      )

      # batch_size
      batch_size = context_input_ids.shape[0] // n_docs

      encoder = self.rag.generator.get_encoder()
      encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)

      # Every beam starts from the decoder start token.
      input_ids = torch.full(
          (batch_size * generation_config.num_beams, 1),
          generation_config.decoder_start_token_id,
          dtype=torch.long,
          device=next(self.parameters()).device,
      )
      input_ids_seq_length = input_ids.shape[-1]
      last_hidden_state = encoder_outputs["last_hidden_state"]

      def extend_enc_output(tensor, num_beams=None):
          # split into `batch_size`, `num_beams`, `num_docs`
          tensor = tensor[None, None, :].reshape((batch_size, 1, n_docs) + tensor.shape[1:])
          # repeat same last hidden states over `num_beams` dimension
          tensor = tensor.expand((batch_size, num_beams, n_docs) + tensor.shape[3:])
          # merge `batch_size`, `num_beams`, `num_docs` dims again
          return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])

      # correctly extend last_hidden_state and attention mask
      context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
      encoder_outputs["last_hidden_state"] = extend_enc_output(
          last_hidden_state, num_beams=generation_config.num_beams
      )

      doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0)

      # define start_len & additional parameters
      model_kwargs["doc_scores"] = doc_scores
      model_kwargs["encoder_outputs"] = encoder_outputs
      model_kwargs["attention_mask"] = context_attention_mask
      model_kwargs["n_docs"] = n_docs
      model_kwargs["use_cache"] = generation_config.use_cache

      prepared_logits_processor = self._get_logits_processor(
          generation_config=generation_config,
          input_ids_seq_length=input_ids_seq_length,
          encoder_input_ids=context_input_ids,
          prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
          logits_processor=logits_processor,
          device=input_ids.device,
      )
      prepared_stopping_criteria = self._get_stopping_criteria(
          generation_config=generation_config, stopping_criteria=stopping_criteria
      )

      self._prepare_cache_for_generation(
          generation_config,
          model_kwargs,
          generation_mode=None,
          batch_size=input_ids.shape[0],
          max_cache_length=generation_config.max_length - 1,
      )

      return decoding_method(
          self,
          input_ids,
          logits_processor=prepared_logits_processor,
          stopping_criteria=prepared_stopping_criteria,
          generation_config=generation_config,
          **generation_mode_kwargs,
          **model_kwargs,
      )
  1394. # Auxiliary functions for beam search
  1395. def _temporary_reorder_cache(self, past_key_values, beam_idx):
  1396. # RAG should always use the legacy path even though the LM backbone (T5) uses new cache format
  1397. # because RAG expands input for doc-size internally. TODO: raushan, remove me when all models support
  1398. # new cache format
  1399. past_key_values = self._reorder_cache(past_key_values, beam_idx)
  1400. return past_key_values
  def get_input_embeddings(self):
      """Return the wrapped generator's input embeddings (whatever its `get_input_embeddings()` returns)."""
      generator_model = self.rag.generator
      return generator_model.get_input_embeddings()
  def get_output_embeddings(self):
      """Return the wrapped generator's output embeddings (whatever its `get_output_embeddings()` returns)."""
      generator_model = self.rag.generator
      return generator_model.get_output_embeddings()
  def set_output_embeddings(self, new_embeddings):
      """Hand `new_embeddings` to the wrapped generator's `set_output_embeddings` and return its result."""
      generator_model = self.rag.generator
      return generator_model.set_output_embeddings(new_embeddings)
  def shift_tokens_right(self, input_ids, start_token_id=None):
      """
      Shift input ids one token to the right, and pad with start_token_id.

      Args:
          input_ids (`torch.Tensor`): 2-D tensor of token ids; it is not modified.
          start_token_id (`int`, *optional*): value written into column 0; defaults to
              `self.config.decoder_start_token_id`.

      Returns:
          A new tensor with the same shape and dtype as `input_ids`.
      """
      if start_token_id is None:
          start_token_id = self.config.decoder_start_token_id
      shifted = input_ids.new_zeros(input_ids.shape)
      shifted[:, 0] = start_token_id
      # Every column moves one step right; the last original column is dropped.
      shifted[:, 1:] = input_ids[:, :-1].clone()
      return shifted
  1415. def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
  1416. n_docs = n_docs if n_docs is not None else self.config.n_docs
  1417. # shift tokens left
  1418. target = torch.cat(
  1419. [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
  1420. )
  1421. def _mask_pads(ll, smooth_obj):
  1422. pad_mask = target.eq(self.config.generator.pad_token_id)
  1423. if pad_mask.any():
  1424. ll.masked_fill_(pad_mask, 0.0)
  1425. smooth_obj.masked_fill_(pad_mask, 0.0)
  1426. return ll.squeeze(-1), smooth_obj.squeeze(-1)
  1427. rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
  1428. target = target.unsqueeze(-1)
  1429. assert target.dim() == rag_logprobs.dim()
  1430. ll = rag_logprobs.gather(dim=-1, index=target)
  1431. smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True) # total sum of all (normalised) logits
  1432. ll, smooth_obj = _mask_pads(ll, smooth_obj)
  1433. ll = ll.sum(1) # sum over tokens
  1434. smooth_obj = smooth_obj.sum(1)
  1435. nll_loss = -ll
  1436. smooth_loss = -smooth_obj
  1437. if reduce_loss:
  1438. nll_loss = nll_loss.sum()
  1439. smooth_loss = smooth_loss.sum()
  1440. eps_i = epsilon / rag_logprobs.size(-1)
  1441. loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
  1442. return loss
  # Public API of this module.
  __all__ = [
      "RagModel",
      "RagPreTrainedModel",
      "RagSequenceForGeneration",
      "RagTokenForGeneration",
  ]