| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Adapted from:
- # Link: https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
- # Link: https://github.com/huggingface/datasets/blob/master/metrics/squad/squad.py
- import re
- import string
- from collections import Counter
- from typing import Any, Callable, Union
- from torch import Tensor, tensor
- from torchmetrics.utilities import rank_zero_warn
# Shape of the SQuAD ``answers`` entry: maps keys such as ``text`` / ``answer_start`` to lists.
_ANSWERS_DICT_TYPE = dict[str, Union[list[str], list[int]]]

# A single prediction maps ``id`` and ``prediction_text`` to strings; callers may pass one or a list.
SINGLE_PRED_TYPE = dict[str, str]
PREDS_TYPE = Union[SINGLE_PRED_TYPE, list[SINGLE_PRED_TYPE]]

# A single target maps ``id`` to a string and ``answers`` to an answers dict; callers may pass one or a list.
SINGLE_TARGET_TYPE = dict[str, Union[str, _ANSWERS_DICT_TYPE]]
TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, list[SINGLE_TARGET_TYPE]]

UPDATE_METHOD_SINGLE_PRED_TYPE = Union[list[dict[str, Union[str, int]]], str, _ANSWERS_DICT_TYPE]
# Reference example of one SQuAD-format target; it is embedded verbatim in input-validation error messages.
SQuAD_FORMAT = dict(
    answers=dict(answer_start=[1], text=["This is a test text"]),
    context="This is a test context.",
    id="1",
    question="Is this a test?",
    title="train test",
)
- def _normalize_text(s: str) -> str:
- """Lower text and remove punctuation, articles and extra whitespace."""
- def remove_articles(text: str) -> str:
- return re.sub(r"\b(a|an|the)\b", " ", text)
- def white_space_fix(text: str) -> str:
- return " ".join(text.split())
- def remove_punc(text: str) -> str:
- exclude = set(string.punctuation)
- return "".join(ch for ch in text if ch not in exclude)
- def lower(text: str) -> str:
- return text.lower()
- return white_space_fix(remove_articles(remove_punc(lower(s))))
def _get_tokens(s: str) -> list[str]:
    """Return the whitespace-separated tokens of ``s`` after ``_normalize_text``; an empty string yields ``[]``."""
    if not s:
        return []
    normalized = _normalize_text(s)
    return normalized.split()
def _compute_f1_score(predicted_answer: str, target_answer: str) -> Tensor:
    """Compute the token-level F1 score between a predicted answer and one target answer.

    Both strings are normalized and tokenized with ``_get_tokens``. Overlap is counted as the
    multiset intersection of the two token lists.

    Args:
        predicted_answer: The predicted answer string.
        target_answer: A single reference answer string.

    Returns:
        A scalar floating-point tensor. If either side has no tokens, the score is ``1.0`` when both
        are empty and ``0.0`` otherwise. Otherwise it is the harmonic mean of token precision and recall.
    """
    target_tokens = _get_tokens(target_answer)
    predicted_tokens = _get_tokens(predicted_answer)

    if not target_tokens or not predicted_tokens:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
        # Cast to float so this branch has the same dtype as the others (it previously returned int64).
        return tensor(float(target_tokens == predicted_tokens))

    # ``&`` on Counters keeps the minimum count per token, i.e. the multiset intersection.
    common = Counter(target_tokens) & Counter(predicted_tokens)
    num_same = tensor(sum(common.values()))
    if num_same == 0:
        return tensor(0.0)

    precision = 1.0 * num_same / tensor(len(predicted_tokens))
    recall = 1.0 * num_same / tensor(len(target_tokens))
    return (2 * precision * recall) / (precision + recall)
def _compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor:
    """Return an integer tensor: ``1`` if both strings are equal after ``_normalize_text``, else ``0``."""
    normalized_prediction = _normalize_text(prediction)
    normalized_truth = _normalize_text(ground_truth)
    return tensor(1 if normalized_prediction == normalized_truth else 0)
- def _metric_max_over_ground_truths(
- metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: list[str]
- ) -> Tensor:
- """Calculate maximum score for a predicted answer with all reference answers."""
- return max(metric_fn(prediction, truth) for truth in ground_truths) # type: ignore[type-var]
def _squad_input_check(
    preds: PREDS_TYPE, targets: TARGETS_TYPE
) -> tuple[dict[str, str], list[dict[str, list[dict[str, list[dict[str, Any]]]]]]]:
    """Validate SQuAD-style inputs and convert them to the layout consumed by ``_squad_update``.

    Args:
        preds: One prediction dict or a list of them. Each must contain ``prediction_text`` and ``id``.
        targets: One target dict or a list of them. Each must contain ``answers`` and ``id``, and
            ``answers`` must contain a ``text`` entry listing the reference answer strings.

    Returns:
        A tuple ``(preds_dict, targets_dict)``:
        - ``preds_dict`` maps each prediction ``id`` to its ``prediction_text``. A later prediction
          with a repeated ``id`` overwrites an earlier one.
        - ``targets_dict`` is a one-element list that wraps every target in the nested layout
          ``[{"paragraphs": [{"qas": [{"answers": [{"text": ...}, ...], "id": ...}, ...]}]}]``.

    Raises:
        KeyError: If a prediction, a target, or a target's ``answers`` entry lacks a required key.
    """
    # Accept a single dict as shorthand for a one-element list.
    if isinstance(preds, dict):
        preds = [preds]
    if isinstance(targets, dict):
        targets = [targets]

    for pred in preds:
        if "prediction_text" not in pred or "id" not in pred:
            raise KeyError(
                "Expected keys in a single prediction are 'prediction_text' and 'id'. "
                "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string."
            )

    for target in targets:
        if "answers" not in target or "id" not in target:
            raise KeyError(
                "Expected keys in a single target are 'answers' and 'id'. "
                "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n"
                "SQuAD Format: "
                f"{SQuAD_FORMAT}"
            )

        answers: dict[str, Union[list[str], list[int]]] = target["answers"]  # type: ignore[assignment]
        if "text" not in answers:
            raise KeyError(
                "Expected key in 'answers' is 'text'. "
                "Please make sure that 'answers' maps to a `SQuAD` format dictionary.\n"
                "SQuAD Format: "
                f"{SQuAD_FORMAT}"
            )

    def _to_qa(target: dict[str, Any]) -> dict[str, Any]:
        """Reshape one target into a ``qas`` entry with one ``{"text": ...}`` dict per reference answer."""
        return {"answers": [{"text": txt} for txt in target["answers"]["text"]], "id": target["id"]}

    preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds}
    targets_dict = [{"paragraphs": [{"qas": [_to_qa(target) for target in targets]}]}]
    return preds_dict, targets_dict
def _squad_update(
    preds: dict[str, str],
    target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]],
) -> tuple[Tensor, Tensor, Tensor]:
    """Accumulate summed F1 and Exact Match scores over every question in ``target``.

    Each question's score is the best score over its reference answers. A question whose ``id`` is
    missing from ``preds`` triggers a warning via ``rank_zero_warn``. It still counts toward the total,
    so it contributes a score of 0.

    Args:
        preds: Mapping from a question ``id`` to the predicted answer text.
        target:
            A list of dictionaries mapping ``paragraphs`` to a list of dictionaries mapping ``qas`` to a list
            of dictionaries that contain the ``id`` and the list of all acceptable ``answers``.

    Return:
        Tuple of (summed F1 score, summed Exact Match score, number of questions seen).

    Raises:
        ValueError: If a question has an empty ``answers`` list (no reference answer to score against).

    Example:
        >>> from torchmetrics.functional.text.squad import _squad_update
        >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
        >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
        >>> preds_dict = {pred["id"]: pred["prediction_text"] for pred in preds}
        >>> targets_dict = [
        ...     dict(paragraphs=[dict(qas=[dict(answers=[
        ...         {"text": txt} for txt in tgt["answers"]["text"]], id=tgt["id"]) for tgt in target
        ...     ])])
        ... ]
        >>> _squad_update(preds_dict, targets_dict)
        (tensor(1.), tensor(1.), tensor(1))
    """
    f1 = tensor(0.0)
    exact_match = tensor(0.0)
    total = tensor(0)

    # Flatten article -> paragraph -> question so each question is handled in a single loop.
    questions = (qa for article in target for paragraph in article["paragraphs"] for qa in paragraph["qas"])
    for qa in questions:
        total += 1
        if qa["id"] in preds:
            references = [answer["text"] for answer in qa["answers"]]
            prediction = preds[qa["id"]]
            exact_match += _metric_max_over_ground_truths(_compute_exact_match_score, prediction, references)
            f1 += _metric_max_over_ground_truths(_compute_f1_score, prediction, references)
        else:
            rank_zero_warn(f"Unanswered question {qa['id']} will receive score 0.")

    return f1, exact_match, total
- def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> dict[str, Tensor]:
- """Aggregate the F1 Score and Exact match for the batch.
- Return:
- Dictionary containing the F1 score, Exact match score for the batch.
- """
- exact_match = 100.0 * exact_match / total
- f1 = 100.0 * f1 / total
- return {"exact_match": exact_match, "f1": f1}
def squad(preds: PREDS_TYPE, target: TARGETS_TYPE) -> dict[str, Tensor]:
    """Calculate `SQuAD Metric`_ .

    Every target question is scored with its best Exact Match and F1 over its reference answers. A
    target with no prediction for its ``id`` scores 0. Predictions whose ``id`` matches no target are
    ignored. Both scores are reported as percentages.

    Args:
        preds: A Dictionary or List of Dictionary-s that map `id` and `prediction_text` to the respective values.

            Example prediction:

            .. code-block:: python

                {"prediction_text": "TorchMetrics is awesome", "id": "123"}

        target: A Dictionary or List of Dictionary-s that contain the `answers` and `id` in the SQuAD Format.
            `answers` must be a dictionary with a `text` entry listing the reference answers.

            Example target:

            .. code-block:: python

                {
                    'answers': {'answer_start': [1], 'text': ['This is a test answer']},
                    'id': '1',
                }

            Reference SQuAD Format:

            .. code-block:: python

                {
                    'answers': {'answer_start': [1], 'text': ['This is a test text']},
                    'context': 'This is a test context.',
                    'id': '1',
                    'question': 'Is this a test?',
                    'title': 'train test'
                }

    Return:
        Dictionary containing the F1 score, Exact match score for the batch.

    Example:
        >>> from torchmetrics.functional.text.squad import squad
        >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
        >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}]
        >>> squad(preds, target)
        {'exact_match': tensor(100.), 'f1': tensor(100.)}

    Raises:
        KeyError:
            If the required keys are missing in either predictions or targets.

    References:
        [1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin
        Lopyrev, Percy Liang `SQuAD Metric`_ .

    """
    preds_dict, target_dict = _squad_input_check(preds, target)
    f1, exact_match, total = _squad_update(preds_dict, target_dict)
    return _squad_compute(f1, exact_match, total)
|