squad.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. from dataclasses import dataclass, field
  17. from enum import Enum
  18. import torch
  19. from filelock import FileLock
  20. from torch.utils.data import Dataset
  21. from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
  22. from ...tokenization_python import PreTrainedTokenizer
  23. from ...utils import check_torch_load_is_safe, logging
  24. from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
  logger = logging.get_logger(__name__)

  # Keys of the question-answering auto-mapping (named as config classes here); each key's
  # ``model_type`` string is listed in the help text of ``SquadDataTrainingArguments.model_type``.
  MODEL_CONFIG_CLASSES = [*MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()]
  MODEL_TYPES = tuple([config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES])
  @dataclass
  class SquadDataTrainingArguments:
      """
      Arguments pertaining to what data we are going to input our model for training and eval.

      Each field's ``metadata["help"]`` string documents it for command-line use; presumably these are
      consumed by an argument parser such as ``HfArgumentParser`` (verify against callers).
      """

      # Architecture name; ``SquadDataset.__getitem__`` uses it to decide which model inputs to emit.
      model_type: str = field(
          default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
      )
      # Directory the SQuAD processors read from; also the default location of the features cache.
      data_dir: str = field(
          default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
      )
      # Part of the features-cache file name, so changing it triggers a rebuild.
      max_seq_length: int = field(
          default=128,
          metadata={
              "help": (
                  "The maximum total input sequence length after tokenization. Sequences longer "
                  "than this will be truncated, sequences shorter will be padded."
              )
          },
      )
      doc_stride: int = field(
          default=128,
          metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
      )
      max_query_length: int = field(
          default=64,
          metadata={
              "help": (
                  "The maximum number of tokens for the question. Questions longer than this will "
                  "be truncated to this length."
              )
          },
      )
      max_answer_length: int = field(
          default=30,
          metadata={
              "help": (
                  "The maximum length of an answer that can be generated. This is needed because the start "
                  "and end predictions are not conditioned on one another."
              )
          },
      )
      # When true, an existing features cache is ignored and rebuilt.
      overwrite_cache: bool = field(
          default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
      )
      # Selects the SQuAD v2 processor (unanswerable questions) instead of v1; also part of the cache key.
      version_2_with_negative: bool = field(
          default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
      )
      null_score_diff_threshold: float = field(
          default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
      )
      # Previously carried the null_score_diff_threshold help text by copy-paste mistake.
      n_best_size: int = field(
          default=20, metadata={"help": "The total number of n-best predictions to generate."}
      )
      # Fills the ``langs`` tensor when the dataset is built as language sensitive.
      lang_id: int = field(
          default=0,
          metadata={
              "help": (
                  "language id of input for language-specific xlm models (see"
                  " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
              )
          },
      )
      # Forwarded as ``threads`` to ``squad_convert_examples_to_features``.
      threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})
  class Split(Enum):
      """Which SQuAD split to load.

      Member names are also the strings accepted for ``mode`` by ``SquadDataset``,
      which resolves them with ``Split[name]``.
      """

      train = "train"  # training examples; features carry answer positions
      dev = "dev"  # evaluation examples
  class SquadDataset(Dataset):
      """Torch ``Dataset`` of SQuAD features, built from ``args.data_dir`` or loaded from an on-disk cache.

      The cache is read/written under a ``FileLock`` so that, in distributed runs, only one process
      converts the examples and the others reuse its cache. Items are dicts of tensors whose keys depend
      on ``args.model_type``, ``args.version_2_with_negative``, ``is_language_sensitive`` and the split.
      """

      args: SquadDataTrainingArguments
      features: list[SquadFeatures]
      mode: Split
      is_language_sensitive: bool

      def __init__(
          self,
          args: SquadDataTrainingArguments,
          tokenizer: PreTrainedTokenizer,
          limit_length: int | None = None,
          mode: str | Split = Split.train,
          is_language_sensitive: bool = False,
          cache_dir: str | None = None,
          dataset_format: str = "pt",
      ):
          """Load (or build and cache) the features for one split.

          Args:
              args: Data/preprocessing arguments (paths, lengths, SQuAD version, cache policy).
              tokenizer: Used to convert examples to features; its class name is part of the cache key.
              limit_length: Accepted for API compatibility; not used by this implementation.
              mode: Split to load, either a ``Split`` member or its name (``"train"`` / ``"dev"``).
              is_language_sensitive: If true, ``__getitem__`` also returns a ``langs`` tensor.
              cache_dir: Directory for the features cache; defaults to ``args.data_dir``.
              dataset_format: Forwarded as ``return_dataset`` to ``squad_convert_examples_to_features``.

          Raises:
              KeyError: If ``mode`` is a string that is not the name of a ``Split`` member.
          """
          self.args = args
          self.is_language_sensitive = is_language_sensitive
          self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
          if isinstance(mode, str):
              try:
                  mode = Split[mode]
              except KeyError:
                  raise KeyError("mode is not a valid split name") from None
          self.mode = mode

          # Load data features from cache or dataset file.
          # NOTE(review): doc_stride and max_query_length are not part of the cache key, so changing
          # them silently reuses a stale cache unless overwrite_cache is set.
          version_tag = "v2" if args.version_2_with_negative else "v1"
          cached_features_file = os.path.join(
              cache_dir if cache_dir is not None else args.data_dir,
              f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
          )

          # Make sure only the first process in distributed training processes the dataset,
          # and the others will use the cache.
          lock_path = cached_features_file + ".lock"
          with FileLock(lock_path):
              if os.path.exists(cached_features_file) and not args.overwrite_cache:
                  start = time.time()
                  check_torch_load_is_safe()
                  self.old_features = self._load_cached_features(cached_features_file)
                  # Legacy cache files have only features, while new cache files
                  # will have dataset and examples also.
                  self.features = self.old_features["features"]
                  self.dataset = self.old_features.get("dataset", None)
                  self.examples = self.old_features.get("examples", None)
                  # Lazy %-args only: mixing an f-string with %-placeholders breaks if the path contains '%'.
                  logger.info(
                      "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
                  )

                  if self.dataset is None or self.examples is None:
                      logger.warning(
                          f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in"
                          " future run"
                      )
              else:
                  if mode == Split.dev:
                      self.examples = self.processor.get_dev_examples(args.data_dir)
                  else:
                      self.examples = self.processor.get_train_examples(args.data_dir)

                  self.features, self.dataset = squad_convert_examples_to_features(
                      examples=self.examples,
                      tokenizer=tokenizer,
                      max_seq_length=args.max_seq_length,
                      doc_stride=args.doc_stride,
                      max_query_length=args.max_query_length,
                      is_training=mode == Split.train,
                      threads=args.threads,
                      return_dataset=dataset_format,
                  )

                  start = time.time()
                  torch.save(
                      {"features": self.features, "dataset": self.dataset, "examples": self.examples},
                      cached_features_file,
                  )
                  # TODO: saving the full features/examples lists can be slow; investigate.
                  logger.info(
                      "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
                  )

      @staticmethod
      def _load_cached_features(cached_features_file: str) -> dict:
          """Read a features cache written by ``__init__`` using ``torch.load(weights_only=True)``.

          The cache pickles ``SquadFeatures`` objects, the processor's examples (assumed to be
          ``SquadExample`` instances) and the converted dataset (assumed to be a ``TensorDataset`` for
          ``dataset_format="pt"`` — verify against ``squad_convert_examples_to_features``). The
          weights-only unpickler rejects classes that are not allowlisted, so without this allowlist
          every cache hit would fail. The classes are allowlisted only for the duration of this load.
          """
          from torch.utils.data import TensorDataset

          from ..processors.squad import SquadExample

          with torch.serialization.safe_globals([SquadFeatures, SquadExample, TensorDataset]):
              return torch.load(cached_features_file, weights_only=True)

      def __len__(self):
          """Number of features (one per document span, possibly several per example)."""
          return len(self.features)

      def __getitem__(self, i) -> dict[str, torch.Tensor]:
          """Convert feature ``i`` into a dict of model inputs.

          Always contains ``input_ids`` and ``attention_mask``. Also contains:
          ``token_type_ids`` unless ``model_type`` is xlm/roberta/distilbert/camembert;
          ``cls_index``/``p_mask`` for xlnet/xlm; ``is_impossible`` for SQuAD v2;
          ``langs`` when language sensitive; ``start_positions``/``end_positions`` for the train split.
          """
          feature = self.features[i]
          model_type = self.args.model_type

          input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
          inputs = {
              "input_ids": input_ids,
              "attention_mask": torch.tensor(feature.attention_mask, dtype=torch.long),
          }
          # These architectures take no segment ids.
          if model_type not in ("xlm", "roberta", "distilbert", "camembert"):
              inputs["token_type_ids"] = torch.tensor(feature.token_type_ids, dtype=torch.long)
          if model_type in ("xlnet", "xlm"):
              inputs["cls_index"] = torch.tensor(feature.cls_index, dtype=torch.long)
              inputs["p_mask"] = torch.tensor(feature.p_mask, dtype=torch.float)
          if self.args.version_2_with_negative:
              inputs["is_impossible"] = torch.tensor(feature.is_impossible, dtype=torch.float)
          if self.is_language_sensitive:
              # Same language id for every token of the sequence.
              inputs["langs"] = torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id
          if self.mode == Split.train:
              inputs["start_positions"] = torch.tensor(feature.start_position, dtype=torch.long)
              inputs["end_positions"] = torch.tensor(feature.end_position, dtype=torch.long)

          return inputs