exact_match.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.classification.stat_scores import (
  19. _multiclass_stat_scores_arg_validation,
  20. _multiclass_stat_scores_format,
  21. _multiclass_stat_scores_tensor_validation,
  22. _multilabel_stat_scores_arg_validation,
  23. _multilabel_stat_scores_format,
  24. _multilabel_stat_scores_tensor_validation,
  25. )
  26. from torchmetrics.utilities.compute import _safe_divide
  27. from torchmetrics.utilities.enums import ClassificationTaskNoBinary
  28. def _exact_match_reduce(
  29. correct: Tensor,
  30. total: Tensor,
  31. ) -> Tensor:
  32. """Reduce exact match."""
  33. return _safe_divide(correct, total)
  34. def _multiclass_exact_match_update(
  35. preds: Tensor,
  36. target: Tensor,
  37. multidim_average: Literal["global", "samplewise"] = "global",
  38. ignore_index: Optional[int] = None,
  39. ) -> tuple[Tensor, Tensor]:
  40. """Compute the statistics."""
  41. if ignore_index is not None:
  42. preds = preds.clone()
  43. preds[target == ignore_index] = ignore_index
  44. correct = (preds == target).sum(1) == preds.shape[1]
  45. correct = correct if multidim_average == "samplewise" else correct.sum()
  46. total = torch.tensor(preds.shape[0] if multidim_average == "global" else 1, device=correct.device)
  47. return correct, total
  48. def multiclass_exact_match(
  49. preds: Tensor,
  50. target: Tensor,
  51. num_classes: int,
  52. multidim_average: Literal["global", "samplewise"] = "global",
  53. ignore_index: Optional[int] = None,
  54. validate_args: bool = True,
  55. ) -> Tensor:
  56. r"""Compute Exact match (also known as subset accuracy) for multiclass tasks.
  57. Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be
  58. correctly classified.
  59. Accepts the following input tensors:
  60. - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
  61. we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
  62. an int tensor.
  63. - ``target`` (int tensor): ``(N, ...)``
  64. Args:
  65. preds: Tensor with predictions
  66. target: Tensor with true labels
  67. num_classes: Integer specifying the number of labels
  68. multidim_average:
  69. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  70. - ``global``: Additional dimensions are flatted along the batch dimension
  71. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  72. The statistics in this case are calculated over the additional dimensions.
  73. ignore_index:
  74. Specifies a target value that is ignored and does not contribute to the metric calculation
  75. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  76. Set to ``False`` for faster computations.
  77. Returns:
  78. The returned shape depends on the ``multidim_average`` argument:
  79. - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
  80. - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``
  81. Example (multidim tensors):
  82. >>> from torch import tensor
  83. >>> from torchmetrics.functional.classification import multiclass_exact_match
  84. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  85. >>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
  86. >>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='global')
  87. tensor(0.5000)
  88. Example (multidim tensors):
  89. >>> from torchmetrics.functional.classification import multiclass_exact_match
  90. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  91. >>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
  92. >>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='samplewise')
  93. tensor([1., 0.])
  94. """
  95. top_k, average = 1, None
  96. if validate_args:
  97. _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
  98. _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
  99. preds, target = _multiclass_stat_scores_format(preds, target, top_k)
  100. correct, total = _multiclass_exact_match_update(preds, target, multidim_average, ignore_index)
  101. return _exact_match_reduce(correct, total)
  102. def _multilabel_exact_match_update(
  103. preds: Tensor,
  104. target: Tensor,
  105. num_labels: int,
  106. multidim_average: Literal["global", "samplewise"] = "global",
  107. ignore_index: Optional[int] = None,
  108. ) -> tuple[Tensor, Tensor]:
  109. """Compute the statistics."""
  110. if ignore_index is not None:
  111. mask = target == -1
  112. target = torch.where(mask, preds.long(), target)
  113. if multidim_average == "global":
  114. preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels)
  115. target = torch.movedim(target, 1, -1).reshape(-1, num_labels)
  116. correct = ((preds == target).sum(1) == num_labels).sum(dim=-1)
  117. total = torch.tensor(preds.shape[0 if multidim_average == "global" else 2], device=correct.device)
  118. return correct, total
  119. def multilabel_exact_match(
  120. preds: Tensor,
  121. target: Tensor,
  122. num_labels: int,
  123. threshold: float = 0.5,
  124. multidim_average: Literal["global", "samplewise"] = "global",
  125. ignore_index: Optional[int] = None,
  126. validate_args: bool = True,
  127. ) -> Tensor:
  128. r"""Compute Exact match (also known as subset accuracy) for multilabel tasks.
  129. Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be
  130. correctly classified.
  131. Accepts the following input tensors:
  132. - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
  133. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  134. we convert to int tensor with thresholding using the value in ``threshold``.
  135. - ``target`` (int tensor): ``(N, C, ...)``
  136. Args:
  137. preds: Tensor with predictions
  138. target: Tensor with true labels
  139. num_labels: Integer specifying the number of labels
  140. threshold: Threshold for transforming probability to binary (0,1) predictions
  141. multidim_average:
  142. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  143. - ``global``: Additional dimensions are flatted along the batch dimension
  144. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  145. The statistics in this case are calculated over the additional dimensions.
  146. ignore_index:
  147. Specifies a target value that is ignored and does not contribute to the metric calculation
  148. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  149. Set to ``False`` for faster computations.
  150. Returns:
  151. The returned shape depends on the ``multidim_average`` argument:
  152. - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
  153. - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``
  154. Example (preds is int tensor):
  155. >>> from torch import tensor
  156. >>> from torchmetrics.functional.classification import multilabel_exact_match
  157. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  158. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  159. >>> multilabel_exact_match(preds, target, num_labels=3)
  160. tensor(0.5000)
  161. Example (preds is float tensor):
  162. >>> from torchmetrics.functional.classification import multilabel_exact_match
  163. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  164. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  165. >>> multilabel_exact_match(preds, target, num_labels=3)
  166. tensor(0.5000)
  167. Example (multidim tensors):
  168. >>> from torchmetrics.functional.classification import multilabel_exact_match
  169. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  170. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  171. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  172. >>> multilabel_exact_match(preds, target, num_labels=3, multidim_average='samplewise')
  173. tensor([0., 0.])
  174. """
  175. average = None
  176. if validate_args:
  177. _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
  178. _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
  179. preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
  180. correct, total = _multilabel_exact_match_update(preds, target, num_labels, multidim_average, ignore_index)
  181. return _exact_match_reduce(correct, total)
  182. def exact_match(
  183. preds: Tensor,
  184. target: Tensor,
  185. task: Literal["multiclass", "multilabel"],
  186. num_classes: Optional[int] = None,
  187. num_labels: Optional[int] = None,
  188. threshold: float = 0.5,
  189. multidim_average: Literal["global", "samplewise"] = "global",
  190. ignore_index: Optional[int] = None,
  191. validate_args: bool = True,
  192. ) -> Tensor:
  193. r"""Compute Exact match (also known as subset accuracy).
  194. Exact Match is a stricter version of accuracy where all classes/labels have to match exactly for the sample to be
  195. correctly classified.
  196. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  197. ``task`` argument to either ``'multiclass'`` or ``'multilabel'``. See the documentation of
  198. :func:`~torchmetrics.functional.classification.multiclass_exact_match` and
  199. :func:`~torchmetrics.functional.classification.multilabel_exact_match` for the specific details of
  200. each argument influence and examples.
  201. Legacy Example:
  202. >>> from torch import tensor
  203. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  204. >>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
  205. >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='global')
  206. tensor(0.5000)
  207. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  208. >>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
  209. >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise')
  210. tensor([1., 0.])
  211. """
  212. task = ClassificationTaskNoBinary.from_str(task)
  213. if task == ClassificationTaskNoBinary.MULTICLASS:
  214. assert num_classes is not None # noqa: S101 # needed for mypy
  215. return multiclass_exact_match(preds, target, num_classes, multidim_average, ignore_index, validate_args)
  216. if task == ClassificationTaskNoBinary.MULTILABEL:
  217. assert num_labels is not None # noqa: S101 # needed for mypy
  218. return multilabel_exact_match(
  219. preds, target, num_labels, threshold, multidim_average, ignore_index, validate_args
  220. )
  221. raise ValueError(f"Not handled value: {task}")