# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.stat_scores import (
    _multiclass_stat_scores_arg_validation,
    _multiclass_stat_scores_format,
    _multiclass_stat_scores_tensor_validation,
    _multilabel_stat_scores_arg_validation,
    _multilabel_stat_scores_format,
    _multilabel_stat_scores_tensor_validation,
)
from torchmetrics.utilities.compute import _safe_divide
from torchmetrics.utilities.enums import ClassificationTaskNoBinary


def _exact_match_reduce(
    correct: Tensor,
    total: Tensor,
) -> Tensor:
    """Reduce the exact-match statistics to a score: ``correct / total``, or 0 where ``total`` is 0."""
    return _safe_divide(correct, total)


def _multiclass_exact_match_update(
    preds: Tensor,
    target: Tensor,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: Optional[int] = None,
) -> tuple[Tensor, Tensor]:
    """Compute the number of exactly matched samples and the total used to normalize it."""
    if ignore_index is not None:
        # Copy preds so the caller's tensor is untouched, then force ignored positions to match the target
        preds = preds.clone()
        preds[target == ignore_index] = ignore_index

    # After formatting, dim 1 holds the flattened additional dimensions; a sample is correct only if all match
    correct = (preds == target).sum(1) == preds.shape[1]
    correct = correct if multidim_average == "samplewise" else correct.sum()
    total = torch.tensor(preds.shape[0] if multidim_average == "global" else 1, device=correct.device)
    return correct, total


def multiclass_exact_match(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute Exact match (also known as subset accuracy) for multiclass tasks.

    Exact Match is a stricter version of accuracy where every predicted label of a sample (one per position along the
    additional dimensions ``...``) has to match the target for the sample to count as correctly classified.

    Accepts the following input tensors:

    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ...)`` (float tensor). If ``preds`` is a floating point
      tensor, we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits
      into an int tensor.
    - ``target`` (int tensor): ``(N, ...)``

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_classes: Integer specifying the number of classes
        multidim_average:
            Defines how additional dimensions ``...`` should be handled. Should be one of the following:

            - ``global``: A sample counts as correct only if all positions along the additional dimensions match;
              the score is the fraction of correct samples over the ``N`` axis.
            - ``samplewise``: The score is computed independently for each sample on the ``N`` axis: 1 if all
              positions along the additional dimensions match, 0 otherwise.

        ignore_index:
            Specifies a target value that is ignored. Positions where ``target`` equals this value are treated as
            correctly predicted, so they can never cause a sample to count as a mismatch.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        The returned shape depends on the ``multidim_average`` argument:

        - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
        - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``

    Example (multidim tensors, global average):
        >>> from torch import tensor
        >>> from torchmetrics.functional.classification import multiclass_exact_match
        >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='global')
        tensor(0.5000)

    Example (multidim tensors, samplewise average):
        >>> from torchmetrics.functional.classification import multiclass_exact_match
        >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='samplewise')
        tensor([1., 0.])
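
    Two further illustrative examples with hand-picked inputs. With float ``preds`` of shape ``(N, C, ...)``, each
    position takes the ``argmax`` over ``C`` before the comparison (here the second sample predicts ``[1, 2]``
    against ``[2, 1]``). With ``ignore_index``, positions whose target equals that value are forced to match, so
    only the remaining positions decide whether a sample is correct (two of the three samples below).

    Example (preds is float tensor):
        >>> target = tensor([[0, 1], [2, 1]])
        >>> preds = tensor([[[0.90, 0.10], [0.05, 0.80], [0.05, 0.10]],
        ...                 [[0.10, 0.20], [0.70, 0.30], [0.20, 0.50]]])
        >>> multiclass_exact_match(preds, target, num_classes=3)
        tensor(0.5000)

    Example (ignore_index):
        >>> target = tensor([[0, 1, -1], [2, 1, 0], [1, 1, 0]])
        >>> preds = tensor([[0, 1, 2], [2, 1, 1], [1, 1, 0]])
        >>> multiclass_exact_match(preds, target, num_classes=3, ignore_index=-1)
        tensor(0.6667)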
- """
- top_k, average = 1, None
- if validate_args:
- _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
- _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
- preds, target = _multiclass_stat_scores_format(preds, target, top_k)
- correct, total = _multiclass_exact_match_update(preds, target, multidim_average, ignore_index)
- return _exact_match_reduce(correct, total)


def _multilabel_exact_match_update(
    preds: Tensor,
    target: Tensor,
    num_labels: int,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: Optional[int] = None,
) -> tuple[Tensor, Tensor]:
    """Compute the number of exactly matched positions and the total used to normalize it."""
    if ignore_index is not None:
        # Ignored target entries arrive marked as -1 (see ``_multilabel_stat_scores_format``); copy the
        # prediction into those entries so they always match
        mask = target == -1
        target = torch.where(mask, preds.long(), target)
    if multidim_average == "global":
        # Treat every position along the additional dimensions as its own sample: (N, C, M) -> (N * M, C)
        preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels)
        target = torch.movedim(target, 1, -1).reshape(-1, num_labels)

    # A position is correct only if all ``num_labels`` labels match
    correct = ((preds == target).sum(1) == num_labels).sum(dim=-1)
    total = torch.tensor(preds.shape[0 if multidim_average == "global" else 2], device=correct.device)
    return correct, total


def multilabel_exact_match(
    preds: Tensor,
    target: Tensor,
    num_labels: int,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute Exact match (also known as subset accuracy) for multilabel tasks.

    Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be
    correctly classified.

    Accepts the following input tensors:

    - ``preds`` (int or float tensor): ``(N, C, ...)``. If ``preds`` is a floating point tensor with values outside
      the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
      Additionally, we convert it to binary predictions by thresholding with the value in ``threshold``.
    - ``target`` (int tensor): ``(N, C, ...)``

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_labels: Integer specifying the number of labels
        threshold: Threshold for transforming probability to binary (0,1) predictions
        multidim_average:
            Defines how additional dimensions ``...`` should be handled. Should be one of the following:

            - ``global``: Additional dimensions are flattened along the batch dimension, so every position counts
              as a separate sample.
            - ``samplewise``: The score is computed independently for each sample on the ``N`` axis, as the
              fraction of positions along the additional dimensions where all ``C`` labels match.

        ignore_index:
            Specifies a target value that is ignored. Target entries equal to this value are treated as correctly
            predicted, so they can never cause a mismatch.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        The returned shape depends on the ``multidim_average`` argument:

        - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
        - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``

    Example (preds is int tensor):
        >>> from torch import tensor
        >>> from torchmetrics.functional.classification import multilabel_exact_match
        >>> target = tensor([[0, 1, 0], [1, 0, 1]])
        >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
        >>> multilabel_exact_match(preds, target, num_labels=3)
        tensor(0.5000)

    Example (preds is float tensor):
        >>> from torchmetrics.functional.classification import multilabel_exact_match
        >>> target = tensor([[0, 1, 0], [1, 0, 1]])
        >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
        >>> multilabel_exact_match(preds, target, num_labels=3)
        tensor(0.5000)

    Example (multidim tensors):
        >>> from torchmetrics.functional.classification import multilabel_exact_match
        >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
        >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
        ...                 [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]])
        >>> multilabel_exact_match(preds, target, num_labels=3, multidim_average='samplewise')
        tensor([0., 0.])
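
    An illustrative example with hand-picked inputs for ``ignore_index``: the ignored entry in the first sample is
    replaced by the prediction and cannot cause a mismatch, the second sample differs in one label, and the third
    matches exactly, so two of the three samples count as exact matches.

    Example (ignore_index):
        >>> target = tensor([[0, 1, -1], [1, 0, 1], [0, 0, 0]])
        >>> preds = tensor([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
        >>> multilabel_exact_match(preds, target, num_labels=3, ignore_index=-1)
        tensor(0.6667)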
- """
- average = None
- if validate_args:
- _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
- _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
- preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
- correct, total = _multilabel_exact_match_update(preds, target, num_labels, multidim_average, ignore_index)
- return _exact_match_reduce(correct, total)


def exact_match(
    preds: Tensor,
    target: Tensor,
    task: Literal["multiclass", "multilabel"],
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute Exact match (also known as subset accuracy).

    Exact Match is a stricter version of accuracy where all classes/labels have to match exactly for the sample to be
    correctly classified.

    This function is a simple wrapper that dispatches to the task-specific version of this metric, selected by
    setting the ``task`` argument to either ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :func:`~torchmetrics.functional.classification.multiclass_exact_match` and
    :func:`~torchmetrics.functional.classification.multilabel_exact_match` for details on how each argument affects
    the result and for examples.

    Legacy Example:
        >>> from torch import tensor
        >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='global')
        tensor(0.5000)

        >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise')
        tensor([1., 0.])
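
    An illustrative multilabel call through the same wrapper, reusing the inputs of the int-tensor example of
    :func:`~torchmetrics.functional.classification.multilabel_exact_match`: only the second sample matches on all
    three labels.

    Example (multilabel):
        >>> target = tensor([[0, 1, 0], [1, 0, 1]])
        >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
        >>> exact_match(preds, target, task="multilabel", num_labels=3)
        tensor(0.5000)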
- """
- task = ClassificationTaskNoBinary.from_str(task)
- if task == ClassificationTaskNoBinary.MULTICLASS:
- assert num_classes is not None # noqa: S101 # needed for mypy
- return multiclass_exact_match(preds, target, num_classes, multidim_average, ignore_index, validate_args)
- if task == ClassificationTaskNoBinary.MULTILABEL:
- assert num_labels is not None # noqa: S101 # needed for mypy
- return multilabel_exact_match(
- preds, target, num_labels, threshold, multidim_average, ignore_index, validate_args
- )
- raise ValueError(f"Not handled value: {task}")