| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- # Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """RAG model configuration"""
- from huggingface_hub.dataclasses import strict
- from ...configuration_utils import PreTrainedConfig
- from ...utils import auto_docstring
- from ..auto.configuration_auto import AutoConfig
- @auto_docstring(checkpoint="")
- @strict
- class RagConfig(PreTrainedConfig):
- r"""
- prefix (`str`, *optional*):
- A string prefix prepended to every input before passing to the generator model.
- title_sep (`str`, *optional*, defaults to `" / "`):
- Separator inserted between the title and the text of the retrieved document when calling [`RagRetriever`].
- doc_sep (`str`, *optional*, defaults to `" // "`):
- Separator inserted between the text of the retrieved document and the original input when calling
- [`RagRetriever`].
- n_docs (`int`, *optional*, defaults to 5):
- Number of documents to retrieve.
- max_combined_length (`int`, *optional*, defaults to 300):
- Max length of contextualized input returned by [`~RagRetriever.__call__`].
- retrieval_vector_size (`int`, *optional*, defaults to 768):
- Dimensionality of the document embeddings indexed by [`RagRetriever`].
- retrieval_batch_size (`int`, *optional*, defaults to 8):
- Retrieval batch size, defined as the number of queries issues concurrently to the faiss index encapsulated
- [`RagRetriever`].
- dataset (`str`, *optional*, defaults to `"wiki_dpr"`):
- A dataset identifier of the indexed dataset in HuggingFace Datasets (list all available datasets and ids
- using `datasets.list_datasets()`).
- dataset_split (`str`, *optional*, defaults to `"train"`):
- Which split of the `dataset` to load.
- index_name (`str`, *optional*, defaults to `"compressed"`):
- The index name of the index associated with the `dataset`. One can choose between `"legacy"`, `"exact"` and
- `"compressed"`.
- index_path (`str`, *optional*):
- The path to the serialized faiss index on disk.
- passages_path (`str`, *optional*):
- A path to text passages compatible with the faiss index. Required if using
- [`~models.rag.retrieval_rag.LegacyIndex`]
- use_dummy_dataset (`bool`, *optional*, defaults to `False`):
- Whether to load a "dummy" variant of the dataset specified by `dataset`.
- reduce_loss (`bool`, *optional*, defaults to `False`):
- Whether or not to reduce the NLL loss using the `torch.Tensor.sum` operation.
- label_smoothing (`float`, *optional*, defaults to 0.0):
- Only relevant if `return_loss` is set to `True`. Controls the `epsilon` parameter value for label smoothing
- in the loss calculation. If set to 0, no label smoothing is performed.
- do_deduplication (`bool`, *optional*, defaults to `True`):
- Whether or not to deduplicate the generations from different context documents for a given input. Has to be
- set to `False` if used while training with distributed backend.
- exclude_bos_score (`bool`, *optional*, defaults to `False`):
- Whether or not to disregard the BOS token when computing the loss.
- do_marginalize (`bool`, *optional*, defaults to `False`):
- If `True`, the logits are marginalized over all documents by making use of
- `torch.nn.functional.log_softmax`.
- output_retrieved (`bool`, *optional*, defaults to `False`):
- If set to `True`, `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
- `context_attention_mask` are returned. See returned tensors for more detail.
- dataset_revision (`str`, *optional*,):
- The revision (commit hash, tag, or branch) of the Hugging Face dataset used for retrieval.
- """
- model_type = "rag"
- has_no_defaults_at_init = True
- vocab_size: int | None = None
- is_encoder_decoder: bool = True
- prefix: str | None = None
- bos_token_id: int | None = None
- pad_token_id: int | None = None
- eos_token_id: int | list[int] | None = None
- decoder_start_token_id: int | None = None
- title_sep: str = " / "
- doc_sep: str = " // "
- n_docs: int = 5
- max_combined_length: int = 300
- retrieval_vector_size: int = 768
- retrieval_batch_size: int = 8
- dataset: str = "wiki_dpr"
- dataset_split: str = "train"
- index_name: str = "compressed"
- index_path: str | None = None
- passages_path: str | None = None
- use_dummy_dataset: bool = False
- reduce_loss: bool = False
- label_smoothing: float = 0.0
- do_deduplication: bool = True
- exclude_bos_score: bool = False
- do_marginalize: bool = False
- output_retrieved: bool = False
- use_cache: bool = True
- dataset_revision: str | None = None
- def __post_init__(self, **kwargs):
- if "question_encoder" not in kwargs or "generator" not in kwargs:
- raise ValueError(
- f"A configuration of type {self.model_type} cannot be instantiated because not both `question_encoder` and"
- f" `generator` sub-configurations are passed, but only {kwargs}"
- )
- question_encoder_config = kwargs.pop("question_encoder")
- question_encoder_model_type = question_encoder_config.pop("model_type")
- decoder_config = kwargs.pop("generator")
- decoder_model_type = decoder_config.pop("model_type")
- self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
- self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
- super().__post_init__(**kwargs)
- @classmethod
- def from_question_encoder_generator_configs(
- cls, question_encoder_config: PreTrainedConfig, generator_config: PreTrainedConfig, **kwargs
- ) -> PreTrainedConfig:
- r"""
- Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
- decoder model configuration.
- Returns:
- [`EncoderDecoderConfig`]: An instance of a configuration object
- """
- return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
- __all__ = ["RagConfig"]
|