squad_metrics.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was modified by XLNet authors to
  16. update `find_best_threshold` scripts for SQuAD V2.0
  17. In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an
  18. additional na_prob.json file is provided. This file is expected to map question ID's to the model's predicted
  19. probability that a question is unanswerable.
  20. """
  21. import collections
  22. import json
  23. import math
  24. import re
  25. import string
  26. from ...models.bert import BasicTokenizer
  27. from ...utils import logging
  28. logger = logging.get_logger(__name__)
  29. def normalize_answer(s):
  30. """Lower text and remove punctuation, articles and extra whitespace."""
  31. def remove_articles(text):
  32. regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
  33. return re.sub(regex, " ", text)
  34. def white_space_fix(text):
  35. return " ".join(text.split())
  36. def remove_punc(text):
  37. exclude = set(string.punctuation)
  38. return "".join(ch for ch in text if ch not in exclude)
  39. def lower(text):
  40. return text.lower()
  41. return white_space_fix(remove_articles(remove_punc(lower(s))))
  42. def get_tokens(s):
  43. if not s:
  44. return []
  45. return normalize_answer(s).split()
  46. def compute_exact(a_gold, a_pred):
  47. return int(normalize_answer(a_gold) == normalize_answer(a_pred))
  48. def compute_f1(a_gold, a_pred):
  49. gold_toks = get_tokens(a_gold)
  50. pred_toks = get_tokens(a_pred)
  51. common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  52. num_same = sum(common.values())
  53. if len(gold_toks) == 0 or len(pred_toks) == 0:
  54. # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
  55. return int(gold_toks == pred_toks)
  56. if num_same == 0:
  57. return 0
  58. precision = 1.0 * num_same / len(pred_toks)
  59. recall = 1.0 * num_same / len(gold_toks)
  60. f1 = (2 * precision * recall) / (precision + recall)
  61. return f1
  62. def get_raw_scores(examples, preds):
  63. """
  64. Computes the exact and f1 scores from the examples and the model predictions
  65. """
  66. exact_scores = {}
  67. f1_scores = {}
  68. for example in examples:
  69. qas_id = example.qas_id
  70. gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
  71. if not gold_answers:
  72. # For unanswerable questions, only correct answer is empty string
  73. gold_answers = [""]
  74. if qas_id not in preds:
  75. print(f"Missing prediction for {qas_id}")
  76. continue
  77. prediction = preds[qas_id]
  78. exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
  79. f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
  80. return exact_scores, f1_scores
  81. def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
  82. new_scores = {}
  83. for qid, s in scores.items():
  84. pred_na = na_probs[qid] > na_prob_thresh
  85. if pred_na:
  86. new_scores[qid] = float(not qid_to_has_ans[qid])
  87. else:
  88. new_scores[qid] = s
  89. return new_scores
  90. def make_eval_dict(exact_scores, f1_scores, qid_list=None):
  91. if not qid_list:
  92. total = len(exact_scores)
  93. return collections.OrderedDict(
  94. [
  95. ("exact", 100.0 * sum(exact_scores.values()) / total),
  96. ("f1", 100.0 * sum(f1_scores.values()) / total),
  97. ("total", total),
  98. ]
  99. )
  100. else:
  101. total = len(qid_list)
  102. return collections.OrderedDict(
  103. [
  104. ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
  105. ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
  106. ("total", total),
  107. ]
  108. )
  109. def merge_eval(main_eval, new_eval, prefix):
  110. for k in new_eval:
  111. main_eval[f"{prefix}_{k}"] = new_eval[k]
  112. def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
  113. num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  114. cur_score = num_no_ans
  115. best_score = cur_score
  116. best_thresh = 0.0
  117. qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  118. for qid in qid_list:
  119. if qid not in scores:
  120. continue
  121. if qid_to_has_ans[qid]:
  122. diff = scores[qid]
  123. else:
  124. if preds[qid]:
  125. diff = -1
  126. else:
  127. diff = 0
  128. cur_score += diff
  129. if cur_score > best_score:
  130. best_score = cur_score
  131. best_thresh = na_probs[qid]
  132. has_ans_score, has_ans_cnt = 0, 0
  133. for qid in qid_list:
  134. if not qid_to_has_ans[qid]:
  135. continue
  136. has_ans_cnt += 1
  137. if qid not in scores:
  138. continue
  139. has_ans_score += scores[qid]
  140. return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
  141. def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  142. best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
  143. best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
  144. main_eval["best_exact"] = best_exact
  145. main_eval["best_exact_thresh"] = exact_thresh
  146. main_eval["best_f1"] = best_f1
  147. main_eval["best_f1_thresh"] = f1_thresh
  148. main_eval["has_ans_exact"] = has_ans_exact
  149. main_eval["has_ans_f1"] = has_ans_f1
  150. def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
  151. num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  152. cur_score = num_no_ans
  153. best_score = cur_score
  154. best_thresh = 0.0
  155. qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  156. for _, qid in enumerate(qid_list):
  157. if qid not in scores:
  158. continue
  159. if qid_to_has_ans[qid]:
  160. diff = scores[qid]
  161. else:
  162. if preds[qid]:
  163. diff = -1
  164. else:
  165. diff = 0
  166. cur_score += diff
  167. if cur_score > best_score:
  168. best_score = cur_score
  169. best_thresh = na_probs[qid]
  170. return 100.0 * best_score / len(scores), best_thresh
  171. def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  172. best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
  173. best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
  174. main_eval["best_exact"] = best_exact
  175. main_eval["best_exact_thresh"] = exact_thresh
  176. main_eval["best_f1"] = best_f1
  177. main_eval["best_f1_thresh"] = f1_thresh
  178. def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
  179. qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
  180. has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
  181. no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
  182. if no_answer_probs is None:
  183. no_answer_probs = dict.fromkeys(preds, 0.0)
  184. exact, f1 = get_raw_scores(examples, preds)
  185. exact_threshold = apply_no_ans_threshold(
  186. exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
  187. )
  188. f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
  189. evaluation = make_eval_dict(exact_threshold, f1_threshold)
  190. if has_answer_qids:
  191. has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
  192. merge_eval(evaluation, has_ans_eval, "HasAns")
  193. if no_answer_qids:
  194. no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
  195. merge_eval(evaluation, no_ans_eval, "NoAns")
  196. if no_answer_probs:
  197. find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
  198. return evaluation
  199. def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
  200. """Project the tokenized prediction back to the original text."""
  201. # When we created the data, we kept track of the alignment between original
  202. # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
  203. # now `orig_text` contains the span of our original text corresponding to the
  204. # span that we predicted.
  205. #
  206. # However, `orig_text` may contain extra characters that we don't want in
  207. # our prediction.
  208. #
  209. # For example, let's say:
  210. # pred_text = steve smith
  211. # orig_text = Steve Smith's
  212. #
  213. # We don't want to return `orig_text` because it contains the extra "'s".
  214. #
  215. # We don't want to return `pred_text` because it's already been normalized
  216. # (the SQuAD eval script also does punctuation stripping/lower casing but
  217. # our tokenizer does additional normalization like stripping accent
  218. # characters).
  219. #
  220. # What we really want to return is "Steve Smith".
  221. #
  222. # Therefore, we have to apply a semi-complicated alignment heuristic between
  223. # `pred_text` and `orig_text` to get a character-to-character alignment. This
  224. # can fail in certain cases in which case we just return `orig_text`.
  225. def _strip_spaces(text):
  226. ns_chars = []
  227. ns_to_s_map = collections.OrderedDict()
  228. for i, c in enumerate(text):
  229. if c == " ":
  230. continue
  231. ns_to_s_map[len(ns_chars)] = i
  232. ns_chars.append(c)
  233. ns_text = "".join(ns_chars)
  234. return (ns_text, ns_to_s_map)
  235. # We first tokenize `orig_text`, strip whitespace from the result
  236. # and `pred_text`, and check if they are the same length. If they are
  237. # NOT the same length, the heuristic has failed. If they are the same
  238. # length, we assume the characters are one-to-one aligned.
  239. tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
  240. tok_text = " ".join(tokenizer.tokenize(orig_text))
  241. start_position = tok_text.find(pred_text)
  242. if start_position == -1:
  243. if verbose_logging:
  244. logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
  245. return orig_text
  246. end_position = start_position + len(pred_text) - 1
  247. (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
  248. (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
  249. if len(orig_ns_text) != len(tok_ns_text):
  250. if verbose_logging:
  251. logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
  252. return orig_text
  253. # We then project the characters in `pred_text` back to `orig_text` using
  254. # the character-to-character alignment.
  255. tok_s_to_ns_map = {}
  256. for i, tok_index in tok_ns_to_s_map.items():
  257. tok_s_to_ns_map[tok_index] = i
  258. orig_start_position = None
  259. if start_position in tok_s_to_ns_map:
  260. ns_start_position = tok_s_to_ns_map[start_position]
  261. if ns_start_position in orig_ns_to_s_map:
  262. orig_start_position = orig_ns_to_s_map[ns_start_position]
  263. if orig_start_position is None:
  264. if verbose_logging:
  265. logger.info("Couldn't map start position")
  266. return orig_text
  267. orig_end_position = None
  268. if end_position in tok_s_to_ns_map:
  269. ns_end_position = tok_s_to_ns_map[end_position]
  270. if ns_end_position in orig_ns_to_s_map:
  271. orig_end_position = orig_ns_to_s_map[ns_end_position]
  272. if orig_end_position is None:
  273. if verbose_logging:
  274. logger.info("Couldn't map end position")
  275. return orig_text
  276. output_text = orig_text[orig_start_position : (orig_end_position + 1)]
  277. return output_text
  278. def _get_best_indexes(logits, n_best_size):
  279. """Get the n-best logits from a list."""
  280. index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
  281. best_indexes = []
  282. for i in range(len(index_and_score)):
  283. if i >= n_best_size:
  284. break
  285. best_indexes.append(index_and_score[i][0])
  286. return best_indexes
  287. def _compute_softmax(scores):
  288. """Compute softmax probability over raw logits."""
  289. if not scores:
  290. return []
  291. max_score = None
  292. for score in scores:
  293. if max_score is None or score > max_score:
  294. max_score = score
  295. exp_scores = []
  296. total_sum = 0.0
  297. for score in scores:
  298. x = math.exp(score - max_score)
  299. exp_scores.append(x)
  300. total_sum += x
  301. probs = []
  302. for score in exp_scores:
  303. probs.append(score / total_sum)
  304. return probs
def compute_predictions_logits(
    all_examples,
    all_features,
    all_results,
    n_best_size,
    max_answer_length,
    do_lower_case,
    output_prediction_file,
    output_nbest_file,
    output_null_log_odds_file,
    verbose_logging,
    version_2_with_negative,
    null_score_diff_threshold,
    tokenizer,
):
    """
    Write final predictions to the json file and log-odds of null if needed.

    Args:
        all_examples: Examples, in the order referenced by `feature.example_index`. Each must expose `qas_id` and
            `doc_tokens`.
        all_features: Features exposing `example_index`, `unique_id`, `tokens`, `token_to_orig_map` and
            `token_is_max_context`.
        all_results: Model outputs exposing `unique_id`, `start_logits` and `end_logits`.
        n_best_size: Number of top start/end indexes considered per feature, and maximum number of candidates kept
            per example.
        max_answer_length: Candidate spans longer than this many tokens are discarded.
        do_lower_case: Forwarded to `get_final_text`.
        output_prediction_file: Path the predictions are written to as JSON; skipped when falsy.
        output_nbest_file: Path the n-best lists are written to as JSON; skipped when falsy.
        output_null_log_odds_file: Path the null score differences are written to as JSON; only written when it is
            truthy and `version_2_with_negative` is set.
        verbose_logging: Forwarded to `get_final_text`.
        version_2_with_negative: Whether an empty ("no answer") prediction is possible.
        null_score_diff_threshold: With `version_2_with_negative`, the empty string is predicted when the null score
            minus the score of the best non-empty candidate exceeds this value.
        tokenizer: Object whose `convert_tokens_to_string` turns feature tokens back into text.

    Returns:
        An `OrderedDict` mapping each `qas_id` to its predicted answer text.
    """
    if output_prediction_file:
        logger.info(f"Writing predictions to: {output_prediction_file}")
    if output_nbest_file:
        logger.info(f"Writing nbest to: {output_nbest_file}")
    if output_null_log_odds_file and version_2_with_negative:
        logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")

    # Group the features by the example they were built from.
    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for example_index, example in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for feature_index, feature in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )
        # The null candidate (start and end both at index 0) competes with the regular spans.
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        # Best candidates (highest start + end logit) first.
        prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"]
        )

        # Texts already emitted; lower-ranked candidates producing the same final text are dropped.
        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

                # tok_text = " ".join(tok_tokens)
                #
                # # De-tokenize WordPieces that have been split off.
                # tok_text = tok_text.replace(" ##", "")
                # tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                # A candidate starting at index 0 is the null prediction and maps to the empty string.
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))

            # In very rare edge cases we could only have single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        if len(nbest) < 1:
            raise ValueError("No valid predictions")

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            # Remember the first (i.e. highest ranked) entry with a non-empty text.
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for i, entry in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        if len(nbest_json) < 1:
            raise ValueError("No valid predictions")

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
        all_nbest_json[example.qas_id] = nbest_json

    if output_prediction_file:
        with open(output_prediction_file, "w") as writer:
            writer.write(json.dumps(all_predictions, indent=4) + "\n")

    if output_nbest_file:
        with open(output_nbest_file, "w") as writer:
            writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if output_null_log_odds_file and version_2_with_negative:
        with open(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    return all_predictions
  484. def compute_predictions_log_probs(
  485. all_examples,
  486. all_features,
  487. all_results,
  488. n_best_size,
  489. max_answer_length,
  490. output_prediction_file,
  491. output_nbest_file,
  492. output_null_log_odds_file,
  493. start_n_top,
  494. end_n_top,
  495. version_2_with_negative,
  496. tokenizer,
  497. verbose_logging,
  498. ):
  499. """
  500. XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of
  501. null if needed.
  502. Requires utils_squad_evaluate.py
  503. """
  504. _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
  505. "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
  506. )
  507. _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
  508. "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
  509. )
  510. logger.info(f"Writing predictions to: {output_prediction_file}")
  511. example_index_to_features = collections.defaultdict(list)
  512. for feature in all_features:
  513. example_index_to_features[feature.example_index].append(feature)
  514. unique_id_to_result = {}
  515. for result in all_results:
  516. unique_id_to_result[result.unique_id] = result
  517. all_predictions = collections.OrderedDict()
  518. all_nbest_json = collections.OrderedDict()
  519. scores_diff_json = collections.OrderedDict()
  520. for example_index, example in enumerate(all_examples):
  521. features = example_index_to_features[example_index]
  522. prelim_predictions = []
  523. # keep track of the minimum score of null start+end of position 0
  524. score_null = 1000000 # large and positive
  525. for feature_index, feature in enumerate(features):
  526. result = unique_id_to_result[feature.unique_id]
  527. cur_null_score = result.cls_logits
  528. # if we could have irrelevant answers, get the min score of irrelevant
  529. score_null = min(score_null, cur_null_score)
  530. for i in range(start_n_top):
  531. for j in range(end_n_top):
  532. start_log_prob = result.start_logits[i]
  533. start_index = result.start_top_index[i]
  534. j_index = i * end_n_top + j
  535. end_log_prob = result.end_logits[j_index]
  536. end_index = result.end_top_index[j_index]
  537. # We could hypothetically create invalid predictions, e.g., predict
  538. # that the start of the span is in the question. We throw out all
  539. # invalid predictions.
  540. if start_index >= feature.paragraph_len - 1:
  541. continue
  542. if end_index >= feature.paragraph_len - 1:
  543. continue
  544. if not feature.token_is_max_context.get(start_index, False):
  545. continue
  546. if end_index < start_index:
  547. continue
  548. length = end_index - start_index + 1
  549. if length > max_answer_length:
  550. continue
  551. prelim_predictions.append(
  552. _PrelimPrediction(
  553. feature_index=feature_index,
  554. start_index=start_index,
  555. end_index=end_index,
  556. start_log_prob=start_log_prob,
  557. end_log_prob=end_log_prob,
  558. )
  559. )
  560. prelim_predictions = sorted(
  561. prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
  562. )
  563. seen_predictions = {}
  564. nbest = []
  565. for pred in prelim_predictions:
  566. if len(nbest) >= n_best_size:
  567. break
  568. feature = features[pred.feature_index]
  569. # XLNet un-tokenizer
  570. # Let's keep it simple for now and see if we need all this later.
  571. #
  572. # tok_start_to_orig_index = feature.tok_start_to_orig_index
  573. # tok_end_to_orig_index = feature.tok_end_to_orig_index
  574. # start_orig_pos = tok_start_to_orig_index[pred.start_index]
  575. # end_orig_pos = tok_end_to_orig_index[pred.end_index]
  576. # paragraph_text = example.paragraph_text
  577. # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
  578. # Previously used Bert untokenizer
  579. tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
  580. orig_doc_start = feature.token_to_orig_map[pred.start_index]
  581. orig_doc_end = feature.token_to_orig_map[pred.end_index]
  582. orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
  583. tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
  584. # Clean whitespace
  585. tok_text = tok_text.strip()
  586. tok_text = " ".join(tok_text.split())
  587. orig_text = " ".join(orig_tokens)
  588. if hasattr(tokenizer, "do_lower_case"):
  589. do_lower_case = tokenizer.do_lower_case
  590. else:
  591. do_lower_case = tokenizer.do_lowercase_and_remove_accent
  592. final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
  593. if final_text in seen_predictions:
  594. continue
  595. seen_predictions[final_text] = True
  596. nbest.append(
  597. _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
  598. )
  599. # In very rare edge cases we could have no valid predictions. So we
  600. # just create a nonce prediction in this case to avoid failure.
  601. if not nbest:
  602. nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
  603. total_scores = []
  604. best_non_null_entry = None
  605. for entry in nbest:
  606. total_scores.append(entry.start_log_prob + entry.end_log_prob)
  607. if not best_non_null_entry:
  608. best_non_null_entry = entry
  609. probs = _compute_softmax(total_scores)
  610. nbest_json = []
  611. for i, entry in enumerate(nbest):
  612. output = collections.OrderedDict()
  613. output["text"] = entry.text
  614. output["probability"] = probs[i]
  615. output["start_log_prob"] = entry.start_log_prob
  616. output["end_log_prob"] = entry.end_log_prob
  617. nbest_json.append(output)
  618. if len(nbest_json) < 1:
  619. raise ValueError("No valid predictions")
  620. if best_non_null_entry is None:
  621. raise ValueError("No valid predictions")
  622. score_diff = score_null
  623. scores_diff_json[example.qas_id] = score_diff
  624. # note(zhiliny): always predict best_non_null_entry
  625. # and the evaluation script will search for the best threshold
  626. all_predictions[example.qas_id] = best_non_null_entry.text
  627. all_nbest_json[example.qas_id] = nbest_json
  628. with open(output_prediction_file, "w") as writer:
  629. writer.write(json.dumps(all_predictions, indent=4) + "\n")
  630. with open(output_nbest_file, "w") as writer:
  631. writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
  632. if version_2_with_negative:
  633. with open(output_null_log_odds_file, "w") as writer:
  634. writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
  635. return all_predictions