squad.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Adapted from:
  15. # Link: https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
  16. # Link: https://github.com/huggingface/datasets/blob/master/metrics/squad/squad.py
  17. import re
  18. import string
  19. from collections import Counter
  20. from typing import Any, Callable, Union
  21. from torch import Tensor, tensor
  22. from torchmetrics.utilities import rank_zero_warn
  23. SINGLE_PRED_TYPE = dict[str, str]
  24. PREDS_TYPE = Union[SINGLE_PRED_TYPE, list[SINGLE_PRED_TYPE]]
  25. SINGLE_TARGET_TYPE = dict[str, Union[str, dict[str, Union[list[str], list[int]]]]]
  26. TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, list[SINGLE_TARGET_TYPE]]
  27. UPDATE_METHOD_SINGLE_PRED_TYPE = Union[list[dict[str, Union[str, int]]], str, dict[str, Union[list[str], list[int]]]]
  28. SQuAD_FORMAT = {
  29. "answers": {"answer_start": [1], "text": ["This is a test text"]},
  30. "context": "This is a test context.",
  31. "id": "1",
  32. "question": "Is this a test?",
  33. "title": "train test",
  34. }
  35. def _normalize_text(s: str) -> str:
  36. """Lower text and remove punctuation, articles and extra whitespace."""
  37. def remove_articles(text: str) -> str:
  38. return re.sub(r"\b(a|an|the)\b", " ", text)
  39. def white_space_fix(text: str) -> str:
  40. return " ".join(text.split())
  41. def remove_punc(text: str) -> str:
  42. exclude = set(string.punctuation)
  43. return "".join(ch for ch in text if ch not in exclude)
  44. def lower(text: str) -> str:
  45. return text.lower()
  46. return white_space_fix(remove_articles(remove_punc(lower(s))))
  47. def _get_tokens(s: str) -> list[str]:
  48. """Split a sentence into separate tokens."""
  49. return [] if not s else _normalize_text(s).split()
  50. def _compute_f1_score(predicted_answer: str, target_answer: str) -> Tensor:
  51. """Compute F1 Score for two sentences."""
  52. target_tokens = _get_tokens(target_answer)
  53. predicted_tokens = _get_tokens(predicted_answer)
  54. common = Counter(target_tokens) & Counter(predicted_tokens)
  55. num_same = tensor(sum(common.values()))
  56. if len(target_tokens) == 0 or len(predicted_tokens) == 0:
  57. # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
  58. return tensor(int(target_tokens == predicted_tokens))
  59. if num_same == 0:
  60. return tensor(0.0)
  61. precision = 1.0 * num_same / tensor(len(predicted_tokens))
  62. recall = 1.0 * num_same / tensor(len(target_tokens))
  63. return (2 * precision * recall) / (precision + recall)
  64. def _compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor:
  65. """Compute Exact Match for two sentences."""
  66. return tensor(int(_normalize_text(prediction) == _normalize_text(ground_truth)))
  67. def _metric_max_over_ground_truths(
  68. metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: list[str]
  69. ) -> Tensor:
  70. """Calculate maximum score for a predicted answer with all reference answers."""
  71. return max(metric_fn(prediction, truth) for truth in ground_truths) # type: ignore[type-var]
  72. def _squad_input_check(
  73. preds: PREDS_TYPE, targets: TARGETS_TYPE
  74. ) -> tuple[dict[str, str], list[dict[str, list[dict[str, list[dict[str, Any]]]]]]]:
  75. """Check for types and convert the input to necessary format to compute the input."""
  76. if isinstance(preds, dict):
  77. preds = [preds]
  78. if isinstance(targets, dict):
  79. targets = [targets]
  80. for pred in preds:
  81. pred_keys = pred.keys()
  82. if "prediction_text" not in pred_keys or "id" not in pred_keys:
  83. raise KeyError(
  84. "Expected keys in a single prediction are 'prediction_text' and 'id'."
  85. "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string."
  86. )
  87. for target in targets:
  88. target_keys = target.keys()
  89. if "answers" not in target_keys or "id" not in target_keys:
  90. raise KeyError(
  91. "Expected keys in a single target are 'answers' and 'id'."
  92. "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n"
  93. "SQuAD Format: "
  94. f"{SQuAD_FORMAT}"
  95. )
  96. answers: dict[str, Union[list[str], list[int]]] = target["answers"] # type: ignore[assignment]
  97. if "text" not in answers:
  98. raise KeyError(
  99. "Expected keys in a 'answers' are 'text'."
  100. "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n"
  101. "SQuAD Format: "
  102. f"{SQuAD_FORMAT}"
  103. )
  104. preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds}
  105. _fn_answer = lambda tgt: {"answers": [{"text": txt} for txt in tgt["answers"]["text"]], "id": tgt["id"]}
  106. targets_dict = [{"paragraphs": [{"qas": [_fn_answer(target) for target in targets]}]}]
  107. return preds_dict, targets_dict
  108. def _squad_update(
  109. preds: dict[str, str],
  110. target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]],
  111. ) -> tuple[Tensor, Tensor, Tensor]:
  112. """Compute F1 Score and Exact Match for a collection of predictions and references.
  113. Args:
  114. preds: A dictionary mapping an `id` to the predicted `answer`.
  115. target:
  116. A list of dictionary mapping `paragraphs` to list of dictionary mapping `qas` to a list of dictionary
  117. containing `id` and list of all possible `answers`.
  118. Return:
  119. Tuple containing F1 score, Exact match score and total number of examples.
  120. Example:
  121. >>> from torchmetrics.functional.text.squad import _squad_update
  122. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
  123. >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
  124. >>> preds_dict = {pred["id"]: pred["prediction_text"] for pred in preds}
  125. >>> targets_dict = [
  126. ... dict(paragraphs=[dict(qas=[dict(answers=[
  127. ... {"text": txt} for txt in tgt["answers"]["text"]], id=tgt["id"]) for tgt in target
  128. ... ])])
  129. ... ]
  130. >>> _squad_update(preds_dict, targets_dict)
  131. (tensor(1.), tensor(1.), tensor(1))
  132. """
  133. f1 = tensor(0.0)
  134. exact_match = tensor(0.0)
  135. total = tensor(0)
  136. for article in target:
  137. for paragraph in article["paragraphs"]:
  138. for qa in paragraph["qas"]:
  139. total += 1
  140. if qa["id"] not in preds:
  141. rank_zero_warn(f"Unanswered question {qa['id']} will receive score 0.")
  142. continue
  143. ground_truths = [x["text"] for x in qa["answers"]]
  144. pred = preds[qa["id"]]
  145. exact_match += _metric_max_over_ground_truths(_compute_exact_match_score, pred, ground_truths)
  146. f1 += _metric_max_over_ground_truths(_compute_f1_score, pred, ground_truths)
  147. return f1, exact_match, total
  148. def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> dict[str, Tensor]:
  149. """Aggregate the F1 Score and Exact match for the batch.
  150. Return:
  151. Dictionary containing the F1 score, Exact match score for the batch.
  152. """
  153. exact_match = 100.0 * exact_match / total
  154. f1 = 100.0 * f1 / total
  155. return {"exact_match": exact_match, "f1": f1}
  156. def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> dict[str, Tensor]:
  157. """Calculate `SQuAD Metric`_ .
  158. Args:
  159. preds: A Dictionary or List of Dictionary-s that map `id` and `prediction_text` to the respective values.
  160. Example prediction:
  161. .. code-block:: python
  162. {"prediction_text": "TorchMetrics is awesome", "id": "123"}
  163. target: A Dictionary or List of Dictionary-s that contain the `answers` and `id` in the SQuAD Format.
  164. Example target:
  165. .. code-block:: python
  166. {
  167. 'answers': [{'answer_start': [1], 'text': ['This is a test answer']}],
  168. 'id': '1',
  169. }
  170. Reference SQuAD Format:
  171. .. code-block:: python
  172. {
  173. 'answers': {'answer_start': [1], 'text': ['This is a test text']},
  174. 'context': 'This is a test context.',
  175. 'id': '1',
  176. 'question': 'Is this a test?',
  177. 'title': 'train test'
  178. }
  179. Return:
  180. Dictionary containing the F1 score, Exact match score for the batch.
  181. Example:
  182. >>> from torchmetrics.functional.text.squad import squad
  183. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
  184. >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}]
  185. >>> squad(preds, target)
  186. {'exact_match': tensor(100.), 'f1': tensor(100.)}
  187. Raises:
  188. KeyError:
  189. If the required keys are missing in either predictions or targets.
  190. References:
  191. [1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin
  192. Lopyrev, Percy Liang `SQuAD Metric`_ .
  193. """
  194. preds_dict, target_dict = _squad_input_check(preds, target)
  195. f1, exact_match, total = _squad_update(preds_dict, target_dict)
  196. return _squad_compute(f1, exact_match, total)