configuration_rag.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """RAG model configuration"""
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...utils import auto_docstring
  18. from ..auto.configuration_auto import AutoConfig
  19. @auto_docstring(checkpoint="")
  20. @strict
  21. class RagConfig(PreTrainedConfig):
  22. r"""
  23. prefix (`str`, *optional*):
  24. A string prefix prepended to every input before passing to the generator model.
  25. title_sep (`str`, *optional*, defaults to `" / "`):
  26. Separator inserted between the title and the text of the retrieved document when calling [`RagRetriever`].
  27. doc_sep (`str`, *optional*, defaults to `" // "`):
  28. Separator inserted between the text of the retrieved document and the original input when calling
  29. [`RagRetriever`].
  30. n_docs (`int`, *optional*, defaults to 5):
  31. Number of documents to retrieve.
  32. max_combined_length (`int`, *optional*, defaults to 300):
  33. Max length of contextualized input returned by [`~RagRetriever.__call__`].
  34. retrieval_vector_size (`int`, *optional*, defaults to 768):
  35. Dimensionality of the document embeddings indexed by [`RagRetriever`].
  36. retrieval_batch_size (`int`, *optional*, defaults to 8):
  37. Retrieval batch size, defined as the number of queries issues concurrently to the faiss index encapsulated
  38. [`RagRetriever`].
  39. dataset (`str`, *optional*, defaults to `"wiki_dpr"`):
  40. A dataset identifier of the indexed dataset in HuggingFace Datasets (list all available datasets and ids
  41. using `datasets.list_datasets()`).
  42. dataset_split (`str`, *optional*, defaults to `"train"`):
  43. Which split of the `dataset` to load.
  44. index_name (`str`, *optional*, defaults to `"compressed"`):
  45. The index name of the index associated with the `dataset`. One can choose between `"legacy"`, `"exact"` and
  46. `"compressed"`.
  47. index_path (`str`, *optional*):
  48. The path to the serialized faiss index on disk.
  49. passages_path (`str`, *optional*):
  50. A path to text passages compatible with the faiss index. Required if using
  51. [`~models.rag.retrieval_rag.LegacyIndex`]
  52. use_dummy_dataset (`bool`, *optional*, defaults to `False`):
  53. Whether to load a "dummy" variant of the dataset specified by `dataset`.
  54. reduce_loss (`bool`, *optional*, defaults to `False`):
  55. Whether or not to reduce the NLL loss using the `torch.Tensor.sum` operation.
  56. label_smoothing (`float`, *optional*, defaults to 0.0):
  57. Only relevant if `return_loss` is set to `True`. Controls the `epsilon` parameter value for label smoothing
  58. in the loss calculation. If set to 0, no label smoothing is performed.
  59. do_deduplication (`bool`, *optional*, defaults to `True`):
  60. Whether or not to deduplicate the generations from different context documents for a given input. Has to be
  61. set to `False` if used while training with distributed backend.
  62. exclude_bos_score (`bool`, *optional*, defaults to `False`):
  63. Whether or not to disregard the BOS token when computing the loss.
  64. do_marginalize (`bool`, *optional*, defaults to `False`):
  65. If `True`, the logits are marginalized over all documents by making use of
  66. `torch.nn.functional.log_softmax`.
  67. output_retrieved (`bool`, *optional*, defaults to `False`):
  68. If set to `True`, `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
  69. `context_attention_mask` are returned. See returned tensors for more detail.
  70. dataset_revision (`str`, *optional*,):
  71. The revision (commit hash, tag, or branch) of the Hugging Face dataset used for retrieval.
  72. """
  73. model_type = "rag"
  74. has_no_defaults_at_init = True
  75. vocab_size: int | None = None
  76. is_encoder_decoder: bool = True
  77. prefix: str | None = None
  78. bos_token_id: int | None = None
  79. pad_token_id: int | None = None
  80. eos_token_id: int | list[int] | None = None
  81. decoder_start_token_id: int | None = None
  82. title_sep: str = " / "
  83. doc_sep: str = " // "
  84. n_docs: int = 5
  85. max_combined_length: int = 300
  86. retrieval_vector_size: int = 768
  87. retrieval_batch_size: int = 8
  88. dataset: str = "wiki_dpr"
  89. dataset_split: str = "train"
  90. index_name: str = "compressed"
  91. index_path: str | None = None
  92. passages_path: str | None = None
  93. use_dummy_dataset: bool = False
  94. reduce_loss: bool = False
  95. label_smoothing: float = 0.0
  96. do_deduplication: bool = True
  97. exclude_bos_score: bool = False
  98. do_marginalize: bool = False
  99. output_retrieved: bool = False
  100. use_cache: bool = True
  101. dataset_revision: str | None = None
  102. def __post_init__(self, **kwargs):
  103. if "question_encoder" not in kwargs or "generator" not in kwargs:
  104. raise ValueError(
  105. f"A configuration of type {self.model_type} cannot be instantiated because not both `question_encoder` and"
  106. f" `generator` sub-configurations are passed, but only {kwargs}"
  107. )
  108. question_encoder_config = kwargs.pop("question_encoder")
  109. question_encoder_model_type = question_encoder_config.pop("model_type")
  110. decoder_config = kwargs.pop("generator")
  111. decoder_model_type = decoder_config.pop("model_type")
  112. self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
  113. self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
  114. super().__post_init__(**kwargs)
  115. @classmethod
  116. def from_question_encoder_generator_configs(
  117. cls, question_encoder_config: PreTrainedConfig, generator_config: PreTrainedConfig, **kwargs
  118. ) -> PreTrainedConfig:
  119. r"""
  120. Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
  121. decoder model configuration.
  122. Returns:
  123. [`EncoderDecoderConfig`]: An instance of a configuration object
  124. """
  125. return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
  126. __all__ = ["RagConfig"]