retrieval_rag.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. # Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """RAG Retriever model implementation."""
  15. import os
  16. import pickle
  17. import time
  18. from collections.abc import Iterable
  19. import numpy as np
  20. from ...tokenization_python import PreTrainedTokenizer
  21. from ...tokenization_utils_base import BatchEncoding
  22. from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
  23. from .configuration_rag import RagConfig
  24. from .tokenization_rag import RagTokenizer
  25. if is_datasets_available():
  26. from datasets import Dataset, load_dataset, load_from_disk
  27. if is_faiss_available():
  28. import faiss
# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# Fallback location of the legacy index files; `RagRetriever._build_index` uses it when
# `config.index_name == "legacy"` and `config.index_path` is not set.
LEGACY_INDEX_PATH = "https://storage.googleapis.com/huggingface-nlp/datasets/wiki_dpr/"
  31. class Index:
  32. """
  33. A base class for the Indices encapsulated by the [`RagRetriever`].
  34. """
  35. def get_doc_dicts(self, doc_ids: np.ndarray) -> list[dict]:
  36. """
  37. Returns a list of dictionaries, containing titles and text of the retrieved documents.
  38. Args:
  39. doc_ids (`np.ndarray` of shape `(batch_size, n_docs)`):
  40. A tensor of document indices.
  41. """
  42. raise NotImplementedError
  43. def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> tuple[np.ndarray, np.ndarray]:
  44. """
  45. For each query in the batch, retrieves `n_docs` documents.
  46. Args:
  47. question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
  48. An array of query vectors.
  49. n_docs (`int`):
  50. The number of docs retrieved per query.
  51. Returns:
  52. `np.ndarray` of shape `(batch_size, n_docs)`: A tensor of indices of retrieved documents. `np.ndarray` of
  53. shape `(batch_size, vector_size)`: A tensor of vector representations of retrieved documents.
  54. """
  55. raise NotImplementedError
  56. def is_initialized(self):
  57. """
  58. Returns `True` if index is already initialized.
  59. """
  60. raise NotImplementedError
  61. def init_index(self):
  62. """
  63. A function responsible for loading the index into memory. Should be called only once per training run of a RAG
  64. model. E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load
  65. the index.
  66. """
  67. raise NotImplementedError
  68. class LegacyIndex(Index):
  69. """
  70. An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR. We use
  71. default faiss index parameters as specified in that repository.
  72. Args:
  73. vector_size (`int`):
  74. The dimension of indexed vectors.
  75. index_path (`str`):
  76. A path to a *directory* containing index files compatible with [`~models.rag.retrieval_rag.LegacyIndex`]
  77. """
  78. INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
  79. PASSAGE_FILENAME = "psgs_w100.tsv.pkl"
  80. def __init__(self, vector_size, index_path):
  81. requires_backends(self, ["faiss"])
  82. self.index_id_to_db_id = []
  83. self.index_path = index_path
  84. self.passages = self._load_passages()
  85. self.vector_size = vector_size
  86. self.index = None
  87. self._index_initialized = False
  88. def _resolve_path(self, index_path, filename):
  89. is_local = os.path.isdir(index_path)
  90. try:
  91. # Load from URL or cache if already cached
  92. resolved_archive_file = cached_file(index_path, filename)
  93. except OSError:
  94. msg = (
  95. f"Can't load '{filename}'. Make sure that:\n\n"
  96. f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
  97. f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
  98. )
  99. raise OSError(msg)
  100. if is_local:
  101. logger.info(f"loading file {resolved_archive_file}")
  102. else:
  103. logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
  104. return resolved_archive_file
  105. def _load_passages(self):
  106. logger.info(f"Loading passages from {self.index_path}")
  107. passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
  108. if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
  109. raise ValueError(
  110. "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
  111. "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
  112. "that could have been tampered with. If you already verified the pickle data and decided to use it, "
  113. "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
  114. )
  115. with open(passages_path, "rb") as passages_file:
  116. passages = pickle.load(passages_file)
  117. return passages
  118. def _deserialize_index(self):
  119. logger.info(f"Loading index from {self.index_path}")
  120. resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
  121. self.index = faiss.read_index(resolved_index_path)
  122. resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
  123. if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
  124. raise ValueError(
  125. "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
  126. "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
  127. "that could have been tampered with. If you already verified the pickle data and decided to use it, "
  128. "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
  129. )
  130. with open(resolved_meta_path, "rb") as metadata_file:
  131. self.index_id_to_db_id = pickle.load(metadata_file)
  132. assert len(self.index_id_to_db_id) == self.index.ntotal, (
  133. "Deserialized index_id_to_db_id should match faiss index size"
  134. )
  135. def is_initialized(self):
  136. return self._index_initialized
  137. def init_index(self):
  138. index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
  139. index.hnsw.efSearch = 128
  140. index.hnsw.efConstruction = 200
  141. self.index = index
  142. self._deserialize_index()
  143. self._index_initialized = True
  144. def get_doc_dicts(self, doc_ids: np.ndarray):
  145. doc_list = []
  146. for doc_ids_i in doc_ids:
  147. ids = [str(int(doc_id)) for doc_id in doc_ids_i]
  148. docs = [self.passages[doc_id] for doc_id in ids]
  149. doc_list.append(docs)
  150. doc_dicts = []
  151. for docs in doc_list:
  152. doc_dict = {}
  153. doc_dict["title"] = [doc[1] for doc in docs]
  154. doc_dict["text"] = [doc[0] for doc in docs]
  155. doc_dicts.append(doc_dict)
  156. return doc_dicts
  157. def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> tuple[np.ndarray, np.ndarray]:
  158. aux_dim = np.zeros(len(question_hidden_states), dtype="float32").reshape(-1, 1)
  159. query_nhsw_vectors = np.hstack((question_hidden_states, aux_dim))
  160. _, docs_ids = self.index.search(query_nhsw_vectors, n_docs)
  161. vectors = [[self.index.reconstruct(int(doc_id))[:-1] for doc_id in doc_ids] for doc_ids in docs_ids]
  162. ids = [[int(self.index_id_to_db_id[doc_id]) for doc_id in doc_ids] for doc_ids in docs_ids]
  163. return np.array(ids), np.array(vectors)
  164. class HFIndexBase(Index):
  165. def __init__(self, vector_size, dataset, index_initialized=False):
  166. requires_backends(self, ["faiss"])
  167. self.vector_size = vector_size
  168. self.dataset = dataset
  169. self._index_initialized = index_initialized
  170. self._check_dataset_format(with_index=index_initialized)
  171. dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
  172. def _check_dataset_format(self, with_index: bool):
  173. if not isinstance(self.dataset, Dataset):
  174. raise TypeError(f"Dataset should be a datasets.Dataset object, but got {type(self.dataset)}")
  175. if len({"title", "text", "embeddings"} - set(self.dataset.column_names)) > 0:
  176. raise ValueError(
  177. "Dataset should be a dataset with the following columns: "
  178. "title (str), text (str) and embeddings (arrays of dimension vector_size), "
  179. f"but got columns {self.dataset.column_names}"
  180. )
  181. if with_index and "embeddings" not in self.dataset.list_indexes():
  182. raise ValueError(
  183. "Missing faiss index in the dataset. Make sure you called `dataset.add_faiss_index` to compute it "
  184. "or `dataset.load_faiss_index` to load one from the disk."
  185. )
  186. def init_index(self):
  187. raise NotImplementedError()
  188. def is_initialized(self):
  189. return self._index_initialized
  190. def get_doc_dicts(self, doc_ids: np.ndarray) -> list[dict]:
  191. return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
  192. def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> tuple[np.ndarray, np.ndarray]:
  193. _, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
  194. docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
  195. vectors = [doc["embeddings"] for doc in docs]
  196. for i in range(len(vectors)):
  197. if len(vectors[i]) < n_docs:
  198. vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
  199. return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
  200. class CanonicalHFIndex(HFIndexBase):
  201. """
  202. A wrapper around an instance of [`~datasets.Datasets`]. If `index_path` is set to `None`, we load the pre-computed
  203. index available with the [`~datasets.arrow_dataset.Dataset`], otherwise, we load the index from the indicated path
  204. on disk.
  205. Args:
  206. vector_size (`int`): the dimension of the passages embeddings used by the index
  207. dataset_name (`str`, optional, defaults to `wiki_dpr`):
  208. A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids
  209. with `datasets.list_datasets()`).
  210. dataset_split (`str`, optional, defaults to `train`)
  211. Which split of the `dataset` to load.
  212. index_name (`str`, optional, defaults to `train`)
  213. The index_name of the index associated with the `dataset`. The index loaded from `index_path` will be saved
  214. under this name.
  215. index_path (`str`, optional, defaults to `None`)
  216. The path to the serialized faiss index on disk.
  217. use_dummy_dataset (`bool`, optional, defaults to `False`):
  218. If True, use the dummy configuration of the dataset for tests.
  219. """
  220. def __init__(
  221. self,
  222. vector_size: int,
  223. dataset_name: str = "wiki_dpr",
  224. dataset_split: str = "train",
  225. index_name: str | None = None,
  226. index_path: str | None = None,
  227. use_dummy_dataset=False,
  228. dataset_revision=None,
  229. ):
  230. requires_backends(self, ["faiss"])
  231. if int(index_path is None) + int(index_name is None) != 1:
  232. raise ValueError("Please provide `index_name` or `index_path`.")
  233. self.dataset_name = dataset_name
  234. self.dataset_split = dataset_split
  235. self.index_name = index_name
  236. self.index_path = index_path
  237. self.use_dummy_dataset = use_dummy_dataset
  238. self.dataset_revision = dataset_revision
  239. logger.info(f"Loading passages from {self.dataset_name}")
  240. dataset = load_dataset(
  241. self.dataset_name,
  242. with_index=False,
  243. split=self.dataset_split,
  244. dummy=self.use_dummy_dataset,
  245. revision=dataset_revision,
  246. )
  247. super().__init__(vector_size, dataset, index_initialized=False)
  248. def init_index(self):
  249. if self.index_path is not None:
  250. logger.info(f"Loading index from {self.index_path}")
  251. self.dataset.load_faiss_index("embeddings", file=self.index_path)
  252. else:
  253. logger.info(f"Loading index from {self.dataset_name} with index name {self.index_name}")
  254. self.dataset = load_dataset(
  255. self.dataset_name,
  256. with_embeddings=True,
  257. with_index=True,
  258. split=self.dataset_split,
  259. index_name=self.index_name,
  260. dummy=self.use_dummy_dataset,
  261. revision=self.dataset_revision,
  262. )
  263. self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
  264. self._index_initialized = True
  265. class CustomHFIndex(HFIndexBase):
  266. """
  267. A wrapper around an instance of [`~datasets.Datasets`]. The dataset and the index are both loaded from the
  268. indicated paths on disk.
  269. Args:
  270. vector_size (`int`): the dimension of the passages embeddings used by the index
  271. dataset_path (`str`):
  272. The path to the serialized dataset on disk. The dataset should have 3 columns: title (str), text (str) and
  273. embeddings (arrays of dimension vector_size)
  274. index_path (`str`)
  275. The path to the serialized faiss index on disk.
  276. """
  277. def __init__(self, vector_size: int, dataset, index_path=None):
  278. requires_backends(self, ["faiss"])
  279. super().__init__(vector_size, dataset, index_initialized=index_path is None)
  280. self.index_path = index_path
  281. @classmethod
  282. def load_from_disk(cls, vector_size, dataset_path, index_path):
  283. logger.info(f"Loading passages from {dataset_path}")
  284. if dataset_path is None or index_path is None:
  285. raise ValueError(
  286. "Please provide `dataset_path` and `index_path` after calling `dataset.save_to_disk(dataset_path)` "
  287. "and `dataset.get_index('embeddings').save(index_path)`."
  288. )
  289. dataset = load_from_disk(dataset_path)
  290. return cls(vector_size=vector_size, dataset=dataset, index_path=index_path)
  291. def init_index(self):
  292. if not self.is_initialized():
  293. logger.info(f"Loading index from {self.index_path}")
  294. self.dataset.load_faiss_index("embeddings", file=self.index_path)
  295. self._index_initialized = True
  296. class RagRetriever:
  297. """
  298. Retriever used to get documents from vector queries. It retrieves the documents embeddings as well as the documents
  299. contents, and it formats them to be used with a RagModel.
  300. Args:
  301. config ([`RagConfig`]):
  302. The configuration of the RAG model this Retriever is used with. Contains parameters indicating which
  303. `Index` to build. You can load your own custom dataset with `config.index_name="custom"` or use a canonical
  304. one (default) from the datasets library with `config.index_name="wiki_dpr"` for example.
  305. question_encoder_tokenizer ([`PreTrainedTokenizer`]):
  306. The tokenizer that was used to tokenize the question. It is used to decode the question and then use the
  307. generator_tokenizer.
  308. generator_tokenizer ([`PreTrainedTokenizer`]):
  309. The tokenizer used for the generator part of the RagModel.
  310. index ([`~models.rag.retrieval_rag.Index`], optional, defaults to the one defined by the configuration):
  311. If specified, use this index instead of the one built using the configuration
  312. Examples:
  313. ```python
  314. >>> # To load the default "wiki_dpr" dataset with 21M passages from wikipedia (index name is 'compressed' or 'exact')
  315. >>> from transformers import RagRetriever
  316. >>> retriever = RagRetriever.from_pretrained(
  317. ... "facebook/dpr-ctx_encoder-single-nq-base", dataset="wiki_dpr", index_name="compressed"
  318. ... )
  319. >>> # To load your own indexed dataset built with the datasets library. More info on how to build the indexed dataset in examples/rag/use_own_knowledge_dataset.py
  320. >>> from transformers import RagRetriever
  321. >>> dataset = (
  322. ... ...
  323. ... ) # dataset must be a datasets.Datasets object with columns "title", "text" and "embeddings", and it must have a supported index (e.g., Faiss or other index types depending on your setup)
  324. >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", indexed_dataset=dataset)
  325. >>> # To load your own indexed dataset built with the datasets library that was saved on disk. More info in examples/rag/use_own_knowledge_dataset.py
  326. >>> from transformers import RagRetriever
  327. >>> dataset_path = "path/to/my/dataset" # dataset saved via *dataset.save_to_disk(...)*
  328. >>> index_path = "path/to/my/index" # index saved via *dataset.get_index("embeddings").save(...)*
  329. >>> retriever = RagRetriever.from_pretrained(
  330. ... "facebook/dpr-ctx_encoder-single-nq-base",
  331. ... index_name="custom",
  332. ... passages_path=dataset_path,
  333. ... index_path=index_path,
  334. ... )
  335. >>> # To load the legacy index built originally for Rag's paper
  336. >>> from transformers import RagRetriever
  337. >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", index_name="legacy")
  338. ```"""
  339. def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
  340. self._init_retrieval = init_retrieval
  341. requires_backends(self, ["datasets"])
  342. super().__init__()
  343. self.index = index or self._build_index(config)
  344. self.generator_tokenizer = generator_tokenizer
  345. self.question_encoder_tokenizer = question_encoder_tokenizer
  346. self.n_docs = config.n_docs
  347. self.batch_size = config.retrieval_batch_size
  348. self.config = config
  349. if self._init_retrieval:
  350. self.init_retrieval()
  351. self.ctx_encoder_tokenizer = None
  352. self.return_tokenized_docs = False
  353. @staticmethod
  354. def _build_index(config):
  355. if config.index_name == "legacy":
  356. return LegacyIndex(
  357. config.retrieval_vector_size,
  358. config.index_path or LEGACY_INDEX_PATH,
  359. )
  360. elif config.index_name == "custom":
  361. return CustomHFIndex.load_from_disk(
  362. vector_size=config.retrieval_vector_size,
  363. dataset_path=config.passages_path,
  364. index_path=config.index_path,
  365. )
  366. else:
  367. return CanonicalHFIndex(
  368. vector_size=config.retrieval_vector_size,
  369. dataset_name=config.dataset,
  370. dataset_split=config.dataset_split,
  371. index_name=config.index_name,
  372. index_path=config.index_path,
  373. use_dummy_dataset=config.use_dummy_dataset,
  374. dataset_revision=config.dataset_revision,
  375. )
  376. @classmethod
  377. def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
  378. requires_backends(cls, ["datasets"])
  379. config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
  380. rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
  381. question_encoder_tokenizer = rag_tokenizer.question_encoder
  382. generator_tokenizer = rag_tokenizer.generator
  383. if indexed_dataset is not None:
  384. config.index_name = "custom"
  385. index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
  386. else:
  387. index = cls._build_index(config)
  388. return cls(
  389. config,
  390. question_encoder_tokenizer=question_encoder_tokenizer,
  391. generator_tokenizer=generator_tokenizer,
  392. index=index,
  393. )
  394. def save_pretrained(self, save_directory):
  395. if isinstance(self.index, CustomHFIndex):
  396. if self.config.index_path is None:
  397. index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
  398. self.index.dataset.get_index("embeddings").save(index_path)
  399. self.config.index_path = index_path
  400. if self.config.passages_path is None:
  401. passages_path = os.path.join(save_directory, "hf_dataset")
  402. # datasets don't support save_to_disk with indexes right now
  403. faiss_index = self.index.dataset._indexes.pop("embeddings")
  404. self.index.dataset.save_to_disk(passages_path)
  405. self.index.dataset._indexes["embeddings"] = faiss_index
  406. self.config.passages_path = passages_path
  407. self.config.save_pretrained(save_directory)
  408. rag_tokenizer = RagTokenizer(
  409. question_encoder=self.question_encoder_tokenizer,
  410. generator=self.generator_tokenizer,
  411. )
  412. rag_tokenizer.save_pretrained(save_directory)
  413. def init_retrieval(self):
  414. """
  415. Retriever initialization function. It loads the index into memory.
  416. """
  417. logger.info("initializing retrieval")
  418. self.index.init_index()
  419. def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None):
  420. r"""
  421. Postprocessing retrieved `docs` and combining them with `input_strings`.
  422. Args:
  423. docs (`dict`):
  424. Retrieved documents.
  425. input_strings (`str`):
  426. Input strings decoded by `preprocess_query`.
  427. prefix (`str`):
  428. Prefix added at the beginning of each input, typically used with T5-based models.
  429. Return:
  430. `tuple(tensors)`: a tuple consisting of two elements: contextualized `input_ids` and a compatible
  431. `attention_mask`.
  432. """
  433. def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
  434. # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
  435. # TODO(piktus): better handling of truncation
  436. doc_title = doc_title.removeprefix('"').removesuffix('"')
  437. if prefix is None:
  438. prefix = ""
  439. out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
  440. " ", " "
  441. )
  442. return out
  443. rag_input_strings = [
  444. cat_input_and_doc(
  445. docs[i]["title"][j],
  446. docs[i]["text"][j],
  447. input_strings[i],
  448. prefix,
  449. )
  450. for i in range(len(docs))
  451. for j in range(n_docs)
  452. ]
  453. contextualized_inputs = self.generator_tokenizer(
  454. rag_input_strings,
  455. max_length=self.config.max_combined_length,
  456. return_tensors=return_tensors,
  457. padding="max_length",
  458. truncation=True,
  459. )
  460. return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]
  461. def _chunk_tensor(self, t: Iterable, chunk_size: int) -> list[Iterable]:
  462. return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)]
  463. def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> tuple[np.ndarray, np.ndarray]:
  464. question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size)
  465. ids_batched = []
  466. vectors_batched = []
  467. for question_hidden_states in question_hidden_states_batched:
  468. start_time = time.time()
  469. ids, vectors = self.index.get_top_docs(question_hidden_states, n_docs)
  470. logger.debug(
  471. f"index search time: {time.time() - start_time} sec, batch size {question_hidden_states.shape}"
  472. )
  473. ids_batched.extend(ids)
  474. vectors_batched.extend(vectors)
  475. return (
  476. np.array(ids_batched),
  477. np.array(vectors_batched),
  478. ) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
  479. def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> tuple[np.ndarray, np.ndarray, list[dict]]:
  480. """
  481. Retrieves documents for specified `question_hidden_states`.
  482. Args:
  483. question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
  484. A batch of query vectors to retrieve with.
  485. n_docs (`int`):
  486. The number of docs retrieved per query.
  487. Return:
  488. `tuple[np.ndarray, np.ndarray, list[dict]]`: A tuple with the following objects:
  489. - **retrieved_doc_embeds** (`np.ndarray` of shape `(batch_size, n_docs, dim)`) -- The retrieval embeddings
  490. of the retrieved docs per query.
  491. - **doc_ids** (`np.ndarray` of shape `(batch_size, n_docs)`) -- The ids of the documents in the index
  492. - **doc_dicts** (`list[dict]`): The `retrieved_doc_embeds` examples per query.
  493. """
  494. doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
  495. return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
  496. def set_ctx_encoder_tokenizer(self, ctx_encoder_tokenizer: PreTrainedTokenizer):
  497. # used in end2end retriever training
  498. self.ctx_encoder_tokenizer = ctx_encoder_tokenizer
  499. self.return_tokenized_docs = True
  500. def __call__(
  501. self,
  502. question_input_ids: list[list[int]],
  503. question_hidden_states: np.ndarray,
  504. prefix=None,
  505. n_docs=None,
  506. return_tensors=None,
  507. ) -> BatchEncoding:
  508. """
  509. Retrieves documents for specified `question_hidden_states`.
  510. Args:
  511. question_input_ids (`list[list[int]]`) batch of input ids
  512. question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`:
  513. A batch of query vectors to retrieve with.
  514. prefix (`str`, *optional*):
  515. The prefix used by the generator's tokenizer.
  516. n_docs (`int`, *optional*):
  517. The number of docs retrieved per query.
  518. return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to "pt"):
  519. If set, will return tensors instead of list of python integers. Acceptable values are:
  520. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  521. - `'np'`: Return Numpy `np.ndarray` objects.
  522. Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
  523. - **context_input_ids** -- List of token ids to be fed to a model.
  524. [What are input IDs?](../glossary#input-ids)
  525. - **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model
  526. (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
  527. [What are attention masks?](../glossary#attention-mask)
  528. - **retrieved_doc_embeds** -- List of embeddings of the retrieved documents
  529. - **doc_ids** -- List of ids of the retrieved documents
  530. """
  531. n_docs = n_docs if n_docs is not None else self.n_docs
  532. prefix = prefix if prefix is not None else getattr(self.config.generator, "prefix", None)
  533. retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs)
  534. input_strings = self.question_encoder_tokenizer.decode(question_input_ids, skip_special_tokens=True)
  535. context_input_ids, context_attention_mask = self.postprocess_docs(
  536. docs, input_strings, prefix, n_docs, return_tensors=return_tensors
  537. )
  538. if self.return_tokenized_docs:
  539. retrieved_doc_text = []
  540. retrieved_doc_title = []
  541. for b_idx in range(len(docs)):
  542. for doc_idx in range(n_docs):
  543. retrieved_doc_text.append(docs[b_idx]["text"][doc_idx])
  544. retrieved_doc_title.append(docs[b_idx]["title"][doc_idx])
  545. tokenized_docs = self.ctx_encoder_tokenizer(
  546. retrieved_doc_title,
  547. retrieved_doc_text,
  548. truncation=True,
  549. padding="longest",
  550. return_tensors=return_tensors,
  551. )
  552. return BatchEncoding(
  553. {
  554. "context_input_ids": context_input_ids,
  555. "context_attention_mask": context_attention_mask,
  556. "retrieved_doc_embeds": retrieved_doc_embeds,
  557. "doc_ids": doc_ids,
  558. "tokenized_doc_ids": tokenized_docs["input_ids"],
  559. "tokenized_doc_attention_mask": tokenized_docs["attention_mask"],
  560. },
  561. tensor_type=return_tensors,
  562. )
  563. else:
  564. return BatchEncoding(
  565. {
  566. "context_input_ids": context_input_ids,
  567. "context_attention_mask": context_attention_mask,
  568. "retrieved_doc_embeds": retrieved_doc_embeds,
  569. "doc_ids": doc_ids,
  570. },
  571. tensor_type=return_tensors,
  572. )
# Public API of this module: only the retriever is exported; the index classes stay internal.
__all__ = ["RagRetriever"]