# configuration_tapas.py (7.4 KB)
# Copyright 2020 Google Research and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TAPAS configuration. Based on the BERT configuration with added parameters.

Hyperparameters are taken from run_task_main.py and hparam_utils.py of the original implementation. URLs:

- https://github.com/google-research/tapas/blob/master/tapas/run_task_main.py
- https://github.com/google-research/tapas/blob/master/tapas/utils/hparam_utils.py
"""
  20. from huggingface_hub.dataclasses import strict
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...utils import auto_docstring
  23. @auto_docstring(checkpoint="google/tapas-base-finetuned-sqa")
  24. @strict
  25. class TapasConfig(PreTrainedConfig):
  26. r"""
  27. type_vocab_sizes (`list[int]`, *optional*, defaults to `[3, 256, 256, 2, 256, 256, 10]`):
  28. The vocabulary sizes of the `token_type_ids` passed when calling [`TapasModel`].
  29. positive_label_weight (`float`, *optional*, defaults to 10.0):
  30. Weight for positive labels.
  31. num_aggregation_labels (`int`, *optional*, defaults to 0):
  32. The number of aggregation operators to predict.
  33. aggregation_loss_weight (`float`, *optional*, defaults to 1.0):
  34. Importance weight for the aggregation loss.
  35. use_answer_as_supervision (`bool`, *optional*):
  36. Whether to use the answer as the only supervision for aggregation examples.
  37. answer_loss_importance (`float`, *optional*, defaults to 1.0):
  38. Importance weight for the regression loss.
  39. use_normalized_answer_loss (`bool`, *optional*, defaults to `False`):
  40. Whether to normalize the answer loss by the maximum of the predicted and expected value.
  41. huber_loss_delta (`float`, *optional*):
  42. Delta parameter used to calculate the regression loss.
  43. temperature (`float`, *optional*, defaults to 1.0):
  44. Value used to control (OR change) the skewness of cell logits probabilities.
  45. aggregation_temperature (`float`, *optional*, defaults to 1.0):
  46. Scales aggregation logits to control the skewness of probabilities.
  47. use_gumbel_for_cells (`bool`, *optional*, defaults to `False`):
  48. Whether to apply Gumbel-Softmax to cell selection.
  49. use_gumbel_for_aggregation (`bool`, *optional*, defaults to `False`):
  50. Whether to apply Gumbel-Softmax to aggregation selection.
  51. average_approximation_function (`string`, *optional*, defaults to `"ratio"`):
  52. Method to calculate the expected average of cells in the weak supervision case. One of `"ratio"`,
  53. `"first_order"` or `"second_order"`.
  54. cell_selection_preference (`float`, *optional*):
  55. Preference for cell selection in ambiguous cases. Only applicable in case of weak supervision for
  56. aggregation (WTQ, WikiSQL). If the total mass of the aggregation probabilities (excluding the "NONE"
  57. operator) is higher than this hyperparameter, then aggregation is predicted for an example.
  58. answer_loss_cutoff (`float`, *optional*):
  59. Ignore examples with answer loss larger than cutoff.
  60. max_num_rows (`int`, *optional*, defaults to 64):
  61. Maximum number of rows.
  62. max_num_columns (`int`, *optional*, defaults to 32):
  63. Maximum number of columns.
  64. average_logits_per_cell (`bool`, *optional*, defaults to `False`):
  65. Whether to average logits per cell.
  66. select_one_column (`bool`, *optional*, defaults to `True`):
  67. Whether to constrain the model to only select cells from a single column.
  68. allow_empty_column_selection (`bool`, *optional*, defaults to `False`):
  69. Whether to allow not to select any column.
  70. init_cell_selection_weights_to_zero (`bool`, *optional*, defaults to `False`):
  71. Whether to initialize cell selection weights to 0 so that the initial probabilities are 50%.
  72. reset_position_index_per_cell (`bool`, *optional*, defaults to `True`):
  73. Whether to restart position indexes at every cell (i.e. use relative position embeddings).
  74. disable_per_token_loss (`bool`, *optional*, defaults to `False`):
  75. Whether to disable any (strong or weak) supervision on cells.
  76. aggregation_labels (`dict[int, label]`, *optional*):
  77. The aggregation labels used to aggregate the results. For example, the WTQ models have the following
  78. aggregation labels: `{0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}`
  79. no_aggregation_label_index (`int`, *optional*):
  80. If the aggregation labels are defined and one of these labels represents "No aggregation", this should be
  81. set to its index. For example, the WTQ models have the "NONE" aggregation label at index 0, so that value
  82. should be set to 0 for these models.
  83. Example:
  84. ```python
  85. >>> from transformers import TapasModel, TapasConfig
  86. >>> # Initializing a default (SQA) Tapas configuration
  87. >>> configuration = TapasConfig()
  88. >>> # Initializing a model from the configuration
  89. >>> model = TapasModel(configuration)
  90. >>> # Accessing the model configuration
  91. >>> configuration = model.config
  92. ```"""
  93. model_type = "tapas"
  94. vocab_size: int = 30522
  95. hidden_size: int = 768
  96. num_hidden_layers: int = 12
  97. num_attention_heads: int = 12
  98. intermediate_size: int = 3072
  99. hidden_act: str = "gelu"
  100. hidden_dropout_prob: float | int = 0.1
  101. attention_probs_dropout_prob: float | int = 0.1
  102. max_position_embeddings: int = 1024
  103. type_vocab_sizes: list[int] | tuple[int, ...] = (3, 256, 256, 2, 256, 256, 10)
  104. initializer_range: float = 0.02
  105. layer_norm_eps: float = 1e-12
  106. pad_token_id: int | None = 0
  107. bos_token_id: int | None = None
  108. eos_token_id: int | list[int] | None = None
  109. positive_label_weight: float = 10.0
  110. num_aggregation_labels: int = 0
  111. aggregation_loss_weight: float = 1.0
  112. use_answer_as_supervision: bool | None = None
  113. answer_loss_importance: float = 1.0
  114. use_normalized_answer_loss: bool = False
  115. huber_loss_delta: float | None = None
  116. temperature: float = 1.0
  117. aggregation_temperature: float = 1.0
  118. use_gumbel_for_cells: bool = False
  119. use_gumbel_for_aggregation: bool = False
  120. average_approximation_function: str = "ratio"
  121. cell_selection_preference: float | None = None
  122. answer_loss_cutoff: float | int | None = None
  123. max_num_rows: int = 64
  124. max_num_columns: int = 32
  125. average_logits_per_cell: bool = False
  126. select_one_column: bool = True
  127. allow_empty_column_selection: bool = False
  128. init_cell_selection_weights_to_zero: bool = False
  129. reset_position_index_per_cell: bool = True
  130. disable_per_token_loss: bool = False
  131. aggregation_labels: dict | None = None
  132. no_aggregation_label_index: int | None = None
  133. is_decoder: bool = False
  134. add_cross_attention: bool = False
  135. tie_word_embeddings: bool = True
  136. def __post_init__(self, **kwargs):
  137. if isinstance(self.aggregation_labels, dict):
  138. self.aggregation_labels = {int(k): v for k, v in self.aggregation_labels.items()}
  139. super().__post_init__(**kwargs)
# Public API of this module: only the configuration class is exported.
__all__ = ["TapasConfig"]