squad.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from torchmetrics import Metric
  19. from torchmetrics.functional.text.squad import (
  20. PREDS_TYPE,
  21. TARGETS_TYPE,
  22. _squad_compute,
  23. _squad_input_check,
  24. _squad_update,
  25. )
  26. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  27. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
# The ``SQuAD.plot`` docstring examples call ``metric.plot()``, which needs matplotlib.
# When matplotlib is unavailable, list them in ``__doctest_skip__`` so they are not run.
# NOTE(review): ``__doctest_skip__`` is presumably consumed by the doctest plugin used
# by the test suite (pytest-doctestplus convention) -- confirm.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["SQuAD.plot"]
  30. class SQuAD(Metric):
  31. """Calculate `SQuAD Metric`_ which is a metric for evaluating question answering models.
  32. This metric corresponds to the scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).
  33. As input to ``forward`` and ``update`` the metric accepts the following input:
  34. - ``preds`` (:class:`~Dict`): A Dictionary or List of Dictionary-s that map ``id`` and ``prediction_text`` to
  35. the respective values
  36. Example ``prediction``:
  37. .. code-block:: python
  38. {"prediction_text": "TorchMetrics is awesome", "id": "123"}
  39. - ``target`` (:class:`~Dict`): A Dictionary or List of Dictionary-s that contain the ``answers`` and ``id`` in
  40. the SQuAD Format.
  41. Example ``target``:
  42. .. code-block:: python
  43. {
  44. 'answers': [{'answer_start': [1], 'text': ['This is a test answer']}],
  45. 'id': '1',
  46. }
  47. Reference SQuAD Format:
  48. .. code-block:: python
  49. {
  50. 'answers': {'answer_start': [1], 'text': ['This is a test text']},
  51. 'context': 'This is a test context.',
  52. 'id': '1',
  53. 'question': 'Is this a test?',
  54. 'title': 'train test'
  55. }
  56. As output of ``forward`` and ``compute`` the metric returns the following output:
  57. - ``squad`` (:class:`~Dict`): A dictionary containing the F1 score (key: "f1"),
  58. and Exact match score (key: "exact_match") for the batch.
  59. Args:
  60. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  61. Example:
  62. >>> from torchmetrics.text import SQuAD
  63. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
  64. >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
  65. >>> squad = SQuAD()
  66. >>> squad(preds, target)
  67. {'exact_match': tensor(100.), 'f1': tensor(100.)}
  68. """
  69. is_differentiable: bool = False
  70. higher_is_better: bool = True
  71. full_state_update: bool = False
  72. plot_lower_bound: float = 0.0
  73. plot_upper_bound: float = 100.0
  74. f1_score: Tensor
  75. exact_match: Tensor
  76. total: Tensor
  77. def __init__(
  78. self,
  79. **kwargs: Any,
  80. ) -> None:
  81. super().__init__(**kwargs)
  82. self.add_state(name="f1_score", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
  83. self.add_state(name="exact_match", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
  84. self.add_state(name="total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum")
  85. def update(self, preds: PREDS_TYPE, target: TARGETS_TYPE) -> None:
  86. """Update state with predictions and targets."""
  87. preds_dict, target_dict = _squad_input_check(preds, target)
  88. f1_score, exact_match, total = _squad_update(preds_dict, target_dict)
  89. self.f1_score += f1_score
  90. self.exact_match += exact_match
  91. self.total += total
  92. def compute(self) -> dict[str, Tensor]:
  93. """Aggregate the F1 Score and Exact match for the batch."""
  94. return _squad_compute(self.f1_score, self.exact_match, self.total)
  95. def plot(
  96. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  97. ) -> _PLOT_OUT_TYPE:
  98. """Plot a single or multiple values from the metric.
  99. Args:
  100. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  101. If no value is provided, will automatically call `metric.compute` and plot that result.
  102. ax: An matplotlib axis object. If provided will add plot to that axis
  103. Returns:
  104. Figure and Axes object
  105. Raises:
  106. ModuleNotFoundError:
  107. If `matplotlib` is not installed
  108. .. plot::
  109. :scale: 75
  110. >>> # Example plotting a single value
  111. >>> from torchmetrics.text import SQuAD
  112. >>> metric = SQuAD()
  113. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
  114. >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
  115. >>> metric.update(preds, target)
  116. >>> fig_, ax_ = metric.plot()
  117. .. plot::
  118. :scale: 75
  119. >>> # Example plotting multiple values
  120. >>> from torchmetrics.text import SQuAD
  121. >>> metric = SQuAD()
  122. >>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
  123. >>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
  124. >>> values = [ ]
  125. >>> for _ in range(10):
  126. ... values.append(metric(preds, target))
  127. >>> fig_, ax_ = metric.plot(values)
  128. """
  129. return self._plot(val, ax)