utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import csv
  16. import dataclasses
  17. import json
  18. from dataclasses import dataclass
  19. from ...utils import is_torch_available, logging
# Module-level logger from the package's own `logging` utilities (imported from `...utils`); used by
# `SingleSentenceClassificationProcessor.get_features` to report progress.
logger = logging.get_logger(__name__)
  21. @dataclass
  22. class InputExample:
  23. """
  24. A single training/test example for simple sequence classification.
  25. Args:
  26. guid: Unique id for the example.
  27. text_a: string. The untokenized text of the first sequence. For single
  28. sequence tasks, only this sequence must be specified.
  29. text_b: (Optional) string. The untokenized text of the second sequence.
  30. Only must be specified for sequence pair tasks.
  31. label: (Optional) string. The label of the example. This should be
  32. specified for train and dev examples, but not for test examples.
  33. """
  34. guid: str
  35. text_a: str
  36. text_b: str | None = None
  37. label: str | None = None
  38. def to_json_string(self):
  39. """Serializes this instance to a JSON string."""
  40. return json.dumps(dataclasses.asdict(self), indent=2) + "\n"
  41. @dataclass(frozen=True)
  42. class InputFeatures:
  43. """
  44. A single set of features of data. Property names are the same names as the corresponding inputs to a model.
  45. Args:
  46. input_ids: Indices of input sequence tokens in the vocabulary.
  47. attention_mask: Mask to avoid performing attention on padding token indices.
  48. Mask values selected in `[0, 1]`: Usually `1` for tokens that are NOT MASKED, `0` for MASKED (padded)
  49. tokens.
  50. token_type_ids: (Optional) Segment token indices to indicate first and second
  51. portions of the inputs. Only some models use them.
  52. label: (Optional) Label corresponding to the input. Int for classification problems,
  53. float for regression problems.
  54. """
  55. input_ids: list[int]
  56. attention_mask: list[int] | None = None
  57. token_type_ids: list[int] | None = None
  58. label: int | float | None = None
  59. def to_json_string(self):
  60. """Serializes this instance to a JSON string."""
  61. return json.dumps(dataclasses.asdict(self)) + "\n"
  62. class DataProcessor:
  63. """Base class for data converters for sequence classification data sets."""
  64. def get_example_from_tensor_dict(self, tensor_dict):
  65. """
  66. Gets an example from a dict.
  67. Args:
  68. tensor_dict: Keys and values should match the corresponding Glue
  69. tensorflow_dataset examples.
  70. """
  71. raise NotImplementedError()
  72. def get_train_examples(self, data_dir):
  73. """Gets a collection of [`InputExample`] for the train set."""
  74. raise NotImplementedError()
  75. def get_dev_examples(self, data_dir):
  76. """Gets a collection of [`InputExample`] for the dev set."""
  77. raise NotImplementedError()
  78. def get_test_examples(self, data_dir):
  79. """Gets a collection of [`InputExample`] for the test set."""
  80. raise NotImplementedError()
  81. def get_labels(self):
  82. """Gets the list of labels for this data set."""
  83. raise NotImplementedError()
  84. def tfds_map(self, example):
  85. """
  86. Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. This method converts
  87. examples to the correct format.
  88. """
  89. if len(self.get_labels()) > 1:
  90. example.label = self.get_labels()[int(example.label)]
  91. return example
  92. @classmethod
  93. def _read_tsv(cls, input_file, quotechar=None):
  94. """Reads a tab separated value file."""
  95. with open(input_file, "r", encoding="utf-8-sig") as f:
  96. return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
  97. class SingleSentenceClassificationProcessor(DataProcessor):
  98. """Generic processor for a single sentence classification data set."""
  99. def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
  100. self.labels = [] if labels is None else labels
  101. self.examples = [] if examples is None else examples
  102. self.mode = mode
  103. self.verbose = verbose
  104. def __len__(self):
  105. return len(self.examples)
  106. def __getitem__(self, idx):
  107. if isinstance(idx, slice):
  108. return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
  109. return self.examples[idx]
  110. @classmethod
  111. def create_from_csv(
  112. cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
  113. ):
  114. processor = cls(**kwargs)
  115. processor.add_examples_from_csv(
  116. file_name,
  117. split_name=split_name,
  118. column_label=column_label,
  119. column_text=column_text,
  120. column_id=column_id,
  121. skip_first_row=skip_first_row,
  122. overwrite_labels=True,
  123. overwrite_examples=True,
  124. )
  125. return processor
  126. @classmethod
  127. def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
  128. processor = cls(**kwargs)
  129. processor.add_examples(texts_or_text_and_labels, labels=labels)
  130. return processor
  131. def add_examples_from_csv(
  132. self,
  133. file_name,
  134. split_name="",
  135. column_label=0,
  136. column_text=1,
  137. column_id=None,
  138. skip_first_row=False,
  139. overwrite_labels=False,
  140. overwrite_examples=False,
  141. ):
  142. lines = self._read_tsv(file_name)
  143. if skip_first_row:
  144. lines = lines[1:]
  145. texts = []
  146. labels = []
  147. ids = []
  148. for i, line in enumerate(lines):
  149. texts.append(line[column_text])
  150. labels.append(line[column_label])
  151. if column_id is not None:
  152. ids.append(line[column_id])
  153. else:
  154. guid = f"{split_name}-{i}" if split_name else str(i)
  155. ids.append(guid)
  156. return self.add_examples(
  157. texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
  158. )
  159. def add_examples(
  160. self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
  161. ):
  162. if labels is not None and len(texts_or_text_and_labels) != len(labels):
  163. raise ValueError(
  164. f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
  165. )
  166. if ids is not None and len(texts_or_text_and_labels) != len(ids):
  167. raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
  168. if ids is None:
  169. ids = [None] * len(texts_or_text_and_labels)
  170. if labels is None:
  171. labels = [None] * len(texts_or_text_and_labels)
  172. examples = []
  173. added_labels = set()
  174. for text_or_text_and_label, label, guid in zip(texts_or_text_and_labels, labels, ids):
  175. if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
  176. text, label = text_or_text_and_label
  177. else:
  178. text = text_or_text_and_label
  179. added_labels.add(label)
  180. examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
  181. # Update examples
  182. if overwrite_examples:
  183. self.examples = examples
  184. else:
  185. self.examples.extend(examples)
  186. # Update labels
  187. if overwrite_labels:
  188. self.labels = list(added_labels)
  189. else:
  190. self.labels = list(set(self.labels).union(added_labels))
  191. return self.examples
  192. def get_features(
  193. self,
  194. tokenizer,
  195. max_length=None,
  196. pad_on_left=False,
  197. pad_token=0,
  198. mask_padding_with_zero=True,
  199. return_tensors=None,
  200. ):
  201. """
  202. Convert examples in a list of `InputFeatures`
  203. Args:
  204. tokenizer: Instance of a tokenizer that will tokenize the examples
  205. max_length: Maximum example length
  206. pad_on_left: If set to `True`, the examples will be padded on the left rather than on the right (default)
  207. pad_token: Padding token
  208. mask_padding_with_zero: If set to `True`, the attention mask will be filled by `1` for actual values
  209. and by `0` for padded values. If set to `False`, inverts it (`1` for padded values, `0` for actual
  210. values)
  211. Returns:
  212. Will return a list of task-specific `InputFeatures` which can be fed to the model.
  213. """
  214. if max_length is None:
  215. max_length = tokenizer.max_len
  216. label_map = {label: i for i, label in enumerate(self.labels)}
  217. all_input_ids = []
  218. for ex_index, example in enumerate(self.examples):
  219. if ex_index % 10000 == 0:
  220. logger.info(f"Tokenizing example {ex_index}")
  221. input_ids = tokenizer.encode(
  222. example.text_a,
  223. add_special_tokens=True,
  224. max_length=min(max_length, tokenizer.max_len),
  225. )
  226. all_input_ids.append(input_ids)
  227. batch_length = max(len(input_ids) for input_ids in all_input_ids)
  228. features = []
  229. for ex_index, (input_ids, example) in enumerate(zip(all_input_ids, self.examples)):
  230. if ex_index % 10000 == 0:
  231. logger.info(f"Writing example {ex_index}/{len(self.examples)}")
  232. # The mask has 1 for real tokens and 0 for padding tokens. Only real
  233. # tokens are attended to.
  234. attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
  235. # Zero-pad up to the sequence length.
  236. padding_length = batch_length - len(input_ids)
  237. if pad_on_left:
  238. input_ids = ([pad_token] * padding_length) + input_ids
  239. attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
  240. else:
  241. input_ids = input_ids + ([pad_token] * padding_length)
  242. attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
  243. if len(input_ids) != batch_length:
  244. raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
  245. if len(attention_mask) != batch_length:
  246. raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")
  247. if self.mode == "classification":
  248. label = label_map[example.label]
  249. elif self.mode == "regression":
  250. label = float(example.label)
  251. else:
  252. raise ValueError(self.mode)
  253. if ex_index < 5 and self.verbose:
  254. logger.info("*** Example ***")
  255. logger.info(f"guid: {example.guid}")
  256. logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
  257. logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
  258. logger.info(f"label: {example.label} (id = {label})")
  259. features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
  260. if return_tensors is None:
  261. return features
  262. elif return_tensors == "pt":
  263. if not is_torch_available():
  264. raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
  265. import torch
  266. from torch.utils.data import TensorDataset
  267. all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
  268. all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
  269. if self.mode == "classification":
  270. all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
  271. elif self.mode == "regression":
  272. all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
  273. dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
  274. return dataset
  275. else:
  276. raise ValueError("return_tensors should be `'pt'` or `None`")