glue.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """GLUE processors and helpers"""
  16. import os
  17. import warnings
  18. from enum import Enum
  19. from ...tokenization_python import PreTrainedTokenizer
  20. from ...utils import logging
  21. from .utils import DataProcessor, InputExample, InputFeatures
# Module-level logger shared by the conversion helper and the processors below.
logger = logging.get_logger(__name__)
  23. DEPRECATION_WARNING = (
  24. "This {0} will be removed from the library soon, preprocessing should be handled with the Hugging Face Datasets "
  25. "library. You can have a look at this example script for pointers: "
  26. "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
  27. )
  28. def glue_convert_examples_to_features(
  29. examples: list[InputExample],
  30. tokenizer: PreTrainedTokenizer,
  31. max_length: int | None = None,
  32. task=None,
  33. label_list=None,
  34. output_mode=None,
  35. ):
  36. """
  37. Loads a data file into a list of `InputFeatures`
  38. Args:
  39. examples: List of `InputExamples` containing the examples.
  40. tokenizer: Instance of a tokenizer that will tokenize the examples
  41. max_length: Maximum example length. Defaults to the tokenizer's max_len
  42. task: GLUE task
  43. label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
  44. output_mode: String indicating the output mode. Either `regression` or `classification`
  45. Returns:
  46. Will return a list of task-specific `InputFeatures` which can be fed to the model.
  47. """
  48. warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
  49. return _glue_convert_examples_to_features(
  50. examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
  51. )
  52. def _glue_convert_examples_to_features(
  53. examples: list[InputExample],
  54. tokenizer: PreTrainedTokenizer,
  55. max_length: int | None = None,
  56. task=None,
  57. label_list=None,
  58. output_mode=None,
  59. ):
  60. if max_length is None:
  61. max_length = tokenizer.model_max_length
  62. if task is not None:
  63. processor = glue_processors[task]()
  64. if label_list is None:
  65. label_list = processor.get_labels()
  66. logger.info(f"Using label list {label_list} for task {task}")
  67. if output_mode is None:
  68. output_mode = glue_output_modes[task]
  69. logger.info(f"Using output mode {output_mode} for task {task}")
  70. label_map = {label: i for i, label in enumerate(label_list)}
  71. def label_from_example(example: InputExample) -> int | float | None:
  72. if example.label is None:
  73. return None
  74. if output_mode == "classification":
  75. return label_map[example.label]
  76. elif output_mode == "regression":
  77. return float(example.label)
  78. raise KeyError(output_mode)
  79. labels = [label_from_example(example) for example in examples]
  80. batch_encoding = tokenizer(
  81. [(example.text_a, example.text_b) for example in examples],
  82. max_length=max_length,
  83. padding="max_length",
  84. truncation=True,
  85. )
  86. features = []
  87. for i in range(len(examples)):
  88. inputs = {k: batch_encoding[k][i] for k in batch_encoding}
  89. feature = InputFeatures(**inputs, label=labels[i])
  90. features.append(feature)
  91. for i, example in enumerate(examples[:5]):
  92. logger.info("*** Example ***")
  93. logger.info(f"guid: {example.guid}")
  94. logger.info(f"features: {features[i]}")
  95. return features
  96. class OutputMode(Enum):
  97. classification = "classification"
  98. regression = "regression"
  99. class MrpcProcessor(DataProcessor):
  100. """Processor for the MRPC data set (GLUE version)."""
  101. def __init__(self, *args, **kwargs):
  102. super().__init__(*args, **kwargs)
  103. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  104. def get_example_from_tensor_dict(self, tensor_dict):
  105. """See base class."""
  106. return InputExample(
  107. tensor_dict["idx"].numpy(),
  108. tensor_dict["sentence1"].numpy().decode("utf-8"),
  109. tensor_dict["sentence2"].numpy().decode("utf-8"),
  110. str(tensor_dict["label"].numpy()),
  111. )
  112. def get_train_examples(self, data_dir):
  113. """See base class."""
  114. logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
  115. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  116. def get_dev_examples(self, data_dir):
  117. """See base class."""
  118. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  119. def get_test_examples(self, data_dir):
  120. """See base class."""
  121. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  122. def get_labels(self):
  123. """See base class."""
  124. return ["0", "1"]
  125. def _create_examples(self, lines, set_type):
  126. """Creates examples for the training, dev and test sets."""
  127. examples = []
  128. for i, line in enumerate(lines):
  129. if i == 0:
  130. continue
  131. guid = f"{set_type}-{i}"
  132. text_a = line[3]
  133. text_b = line[4]
  134. label = None if set_type == "test" else line[0]
  135. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  136. return examples
  137. class MnliProcessor(DataProcessor):
  138. """Processor for the MultiNLI data set (GLUE version)."""
  139. def __init__(self, *args, **kwargs):
  140. super().__init__(*args, **kwargs)
  141. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  142. def get_example_from_tensor_dict(self, tensor_dict):
  143. """See base class."""
  144. return InputExample(
  145. tensor_dict["idx"].numpy(),
  146. tensor_dict["premise"].numpy().decode("utf-8"),
  147. tensor_dict["hypothesis"].numpy().decode("utf-8"),
  148. str(tensor_dict["label"].numpy()),
  149. )
  150. def get_train_examples(self, data_dir):
  151. """See base class."""
  152. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  153. def get_dev_examples(self, data_dir):
  154. """See base class."""
  155. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
  156. def get_test_examples(self, data_dir):
  157. """See base class."""
  158. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
  159. def get_labels(self):
  160. """See base class."""
  161. return ["contradiction", "entailment", "neutral"]
  162. def _create_examples(self, lines, set_type):
  163. """Creates examples for the training, dev and test sets."""
  164. examples = []
  165. for i, line in enumerate(lines):
  166. if i == 0:
  167. continue
  168. guid = f"{set_type}-{line[0]}"
  169. text_a = line[8]
  170. text_b = line[9]
  171. label = None if set_type.startswith("test") else line[-1]
  172. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  173. return examples
  174. class MnliMismatchedProcessor(MnliProcessor):
  175. """Processor for the MultiNLI Mismatched data set (GLUE version)."""
  176. def __init__(self, *args, **kwargs):
  177. super().__init__(*args, **kwargs)
  178. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  179. def get_dev_examples(self, data_dir):
  180. """See base class."""
  181. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
  182. def get_test_examples(self, data_dir):
  183. """See base class."""
  184. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
  185. class ColaProcessor(DataProcessor):
  186. """Processor for the CoLA data set (GLUE version)."""
  187. def __init__(self, *args, **kwargs):
  188. super().__init__(*args, **kwargs)
  189. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  190. def get_example_from_tensor_dict(self, tensor_dict):
  191. """See base class."""
  192. return InputExample(
  193. tensor_dict["idx"].numpy(),
  194. tensor_dict["sentence"].numpy().decode("utf-8"),
  195. None,
  196. str(tensor_dict["label"].numpy()),
  197. )
  198. def get_train_examples(self, data_dir):
  199. """See base class."""
  200. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  201. def get_dev_examples(self, data_dir):
  202. """See base class."""
  203. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  204. def get_test_examples(self, data_dir):
  205. """See base class."""
  206. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  207. def get_labels(self):
  208. """See base class."""
  209. return ["0", "1"]
  210. def _create_examples(self, lines, set_type):
  211. """Creates examples for the training, dev and test sets."""
  212. test_mode = set_type == "test"
  213. if test_mode:
  214. lines = lines[1:]
  215. text_index = 1 if test_mode else 3
  216. examples = []
  217. for i, line in enumerate(lines):
  218. guid = f"{set_type}-{i}"
  219. text_a = line[text_index]
  220. label = None if test_mode else line[1]
  221. examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
  222. return examples
  223. class Sst2Processor(DataProcessor):
  224. """Processor for the SST-2 data set (GLUE version)."""
  225. def __init__(self, *args, **kwargs):
  226. super().__init__(*args, **kwargs)
  227. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  228. def get_example_from_tensor_dict(self, tensor_dict):
  229. """See base class."""
  230. return InputExample(
  231. tensor_dict["idx"].numpy(),
  232. tensor_dict["sentence"].numpy().decode("utf-8"),
  233. None,
  234. str(tensor_dict["label"].numpy()),
  235. )
  236. def get_train_examples(self, data_dir):
  237. """See base class."""
  238. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  239. def get_dev_examples(self, data_dir):
  240. """See base class."""
  241. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  242. def get_test_examples(self, data_dir):
  243. """See base class."""
  244. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  245. def get_labels(self):
  246. """See base class."""
  247. return ["0", "1"]
  248. def _create_examples(self, lines, set_type):
  249. """Creates examples for the training, dev and test sets."""
  250. examples = []
  251. text_index = 1 if set_type == "test" else 0
  252. for i, line in enumerate(lines):
  253. if i == 0:
  254. continue
  255. guid = f"{set_type}-{i}"
  256. text_a = line[text_index]
  257. label = None if set_type == "test" else line[1]
  258. examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
  259. return examples
  260. class StsbProcessor(DataProcessor):
  261. """Processor for the STS-B data set (GLUE version)."""
  262. def __init__(self, *args, **kwargs):
  263. super().__init__(*args, **kwargs)
  264. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  265. def get_example_from_tensor_dict(self, tensor_dict):
  266. """See base class."""
  267. return InputExample(
  268. tensor_dict["idx"].numpy(),
  269. tensor_dict["sentence1"].numpy().decode("utf-8"),
  270. tensor_dict["sentence2"].numpy().decode("utf-8"),
  271. str(tensor_dict["label"].numpy()),
  272. )
  273. def get_train_examples(self, data_dir):
  274. """See base class."""
  275. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  276. def get_dev_examples(self, data_dir):
  277. """See base class."""
  278. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  279. def get_test_examples(self, data_dir):
  280. """See base class."""
  281. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  282. def get_labels(self):
  283. """See base class."""
  284. return [None]
  285. def _create_examples(self, lines, set_type):
  286. """Creates examples for the training, dev and test sets."""
  287. examples = []
  288. for i, line in enumerate(lines):
  289. if i == 0:
  290. continue
  291. guid = f"{set_type}-{line[0]}"
  292. text_a = line[7]
  293. text_b = line[8]
  294. label = None if set_type == "test" else line[-1]
  295. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  296. return examples
  297. class QqpProcessor(DataProcessor):
  298. """Processor for the QQP data set (GLUE version)."""
  299. def __init__(self, *args, **kwargs):
  300. super().__init__(*args, **kwargs)
  301. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  302. def get_example_from_tensor_dict(self, tensor_dict):
  303. """See base class."""
  304. return InputExample(
  305. tensor_dict["idx"].numpy(),
  306. tensor_dict["question1"].numpy().decode("utf-8"),
  307. tensor_dict["question2"].numpy().decode("utf-8"),
  308. str(tensor_dict["label"].numpy()),
  309. )
  310. def get_train_examples(self, data_dir):
  311. """See base class."""
  312. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  313. def get_dev_examples(self, data_dir):
  314. """See base class."""
  315. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  316. def get_test_examples(self, data_dir):
  317. """See base class."""
  318. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  319. def get_labels(self):
  320. """See base class."""
  321. return ["0", "1"]
  322. def _create_examples(self, lines, set_type):
  323. """Creates examples for the training, dev and test sets."""
  324. test_mode = set_type == "test"
  325. q1_index = 1 if test_mode else 3
  326. q2_index = 2 if test_mode else 4
  327. examples = []
  328. for i, line in enumerate(lines):
  329. if i == 0:
  330. continue
  331. guid = f"{set_type}-{line[0]}"
  332. try:
  333. text_a = line[q1_index]
  334. text_b = line[q2_index]
  335. label = None if test_mode else line[5]
  336. except IndexError:
  337. continue
  338. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  339. return examples
  340. class QnliProcessor(DataProcessor):
  341. """Processor for the QNLI data set (GLUE version)."""
  342. def __init__(self, *args, **kwargs):
  343. super().__init__(*args, **kwargs)
  344. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  345. def get_example_from_tensor_dict(self, tensor_dict):
  346. """See base class."""
  347. return InputExample(
  348. tensor_dict["idx"].numpy(),
  349. tensor_dict["question"].numpy().decode("utf-8"),
  350. tensor_dict["sentence"].numpy().decode("utf-8"),
  351. str(tensor_dict["label"].numpy()),
  352. )
  353. def get_train_examples(self, data_dir):
  354. """See base class."""
  355. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  356. def get_dev_examples(self, data_dir):
  357. """See base class."""
  358. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  359. def get_test_examples(self, data_dir):
  360. """See base class."""
  361. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  362. def get_labels(self):
  363. """See base class."""
  364. return ["entailment", "not_entailment"]
  365. def _create_examples(self, lines, set_type):
  366. """Creates examples for the training, dev and test sets."""
  367. examples = []
  368. for i, line in enumerate(lines):
  369. if i == 0:
  370. continue
  371. guid = f"{set_type}-{line[0]}"
  372. text_a = line[1]
  373. text_b = line[2]
  374. label = None if set_type == "test" else line[-1]
  375. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  376. return examples
  377. class RteProcessor(DataProcessor):
  378. """Processor for the RTE data set (GLUE version)."""
  379. def __init__(self, *args, **kwargs):
  380. super().__init__(*args, **kwargs)
  381. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  382. def get_example_from_tensor_dict(self, tensor_dict):
  383. """See base class."""
  384. return InputExample(
  385. tensor_dict["idx"].numpy(),
  386. tensor_dict["sentence1"].numpy().decode("utf-8"),
  387. tensor_dict["sentence2"].numpy().decode("utf-8"),
  388. str(tensor_dict["label"].numpy()),
  389. )
  390. def get_train_examples(self, data_dir):
  391. """See base class."""
  392. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  393. def get_dev_examples(self, data_dir):
  394. """See base class."""
  395. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  396. def get_test_examples(self, data_dir):
  397. """See base class."""
  398. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  399. def get_labels(self):
  400. """See base class."""
  401. return ["entailment", "not_entailment"]
  402. def _create_examples(self, lines, set_type):
  403. """Creates examples for the training, dev and test sets."""
  404. examples = []
  405. for i, line in enumerate(lines):
  406. if i == 0:
  407. continue
  408. guid = f"{set_type}-{line[0]}"
  409. text_a = line[1]
  410. text_b = line[2]
  411. label = None if set_type == "test" else line[-1]
  412. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  413. return examples
  414. class WnliProcessor(DataProcessor):
  415. """Processor for the WNLI data set (GLUE version)."""
  416. def __init__(self, *args, **kwargs):
  417. super().__init__(*args, **kwargs)
  418. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  419. def get_example_from_tensor_dict(self, tensor_dict):
  420. """See base class."""
  421. return InputExample(
  422. tensor_dict["idx"].numpy(),
  423. tensor_dict["sentence1"].numpy().decode("utf-8"),
  424. tensor_dict["sentence2"].numpy().decode("utf-8"),
  425. str(tensor_dict["label"].numpy()),
  426. )
  427. def get_train_examples(self, data_dir):
  428. """See base class."""
  429. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  430. def get_dev_examples(self, data_dir):
  431. """See base class."""
  432. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  433. def get_test_examples(self, data_dir):
  434. """See base class."""
  435. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  436. def get_labels(self):
  437. """See base class."""
  438. return ["0", "1"]
  439. def _create_examples(self, lines, set_type):
  440. """Creates examples for the training, dev and test sets."""
  441. examples = []
  442. for i, line in enumerate(lines):
  443. if i == 0:
  444. continue
  445. guid = f"{set_type}-{line[0]}"
  446. text_a = line[1]
  447. text_b = line[2]
  448. label = None if set_type == "test" else line[-1]
  449. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  450. return examples
  451. glue_tasks_num_labels = {
  452. "cola": 2,
  453. "mnli": 3,
  454. "mrpc": 2,
  455. "sst-2": 2,
  456. "sts-b": 1,
  457. "qqp": 2,
  458. "qnli": 2,
  459. "rte": 2,
  460. "wnli": 2,
  461. }
  462. glue_processors = {
  463. "cola": ColaProcessor,
  464. "mnli": MnliProcessor,
  465. "mnli-mm": MnliMismatchedProcessor,
  466. "mrpc": MrpcProcessor,
  467. "sst-2": Sst2Processor,
  468. "sts-b": StsbProcessor,
  469. "qqp": QqpProcessor,
  470. "qnli": QnliProcessor,
  471. "rte": RteProcessor,
  472. "wnli": WnliProcessor,
  473. }
  474. glue_output_modes = {
  475. "cola": "classification",
  476. "mnli": "classification",
  477. "mnli-mm": "classification",
  478. "mrpc": "classification",
  479. "sst-2": "classification",
  480. "sts-b": "regression",
  481. "qqp": "classification",
  482. "qnli": "classification",
  483. "rte": "classification",
  484. "wnli": "classification",
  485. }