| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import List, Optional, Union
- import torch
- from torch import Tensor
- from typing_extensions import Literal
- from torchmetrics.functional.classification.precision_recall_curve import (
- _binary_precision_recall_curve_arg_validation,
- _binary_precision_recall_curve_compute,
- _binary_precision_recall_curve_format,
- _binary_precision_recall_curve_tensor_validation,
- _binary_precision_recall_curve_update,
- _multiclass_precision_recall_curve_arg_validation,
- _multiclass_precision_recall_curve_compute,
- _multiclass_precision_recall_curve_format,
- _multiclass_precision_recall_curve_tensor_validation,
- _multiclass_precision_recall_curve_update,
- _multilabel_precision_recall_curve_arg_validation,
- _multilabel_precision_recall_curve_compute,
- _multilabel_precision_recall_curve_format,
- _multilabel_precision_recall_curve_tensor_validation,
- _multilabel_precision_recall_curve_update,
- )
- from torchmetrics.utilities.compute import _safe_divide
- from torchmetrics.utilities.data import _bincount
- from torchmetrics.utilities.enums import ClassificationTask
- from torchmetrics.utilities.prints import rank_zero_warn
- def _reduce_average_precision(
- precision: Union[Tensor, List[Tensor]],
- recall: Union[Tensor, List[Tensor]],
- average: Optional[Literal["macro", "weighted", "none"]] = "macro",
- weights: Optional[Tensor] = None,
- ) -> Tensor:
- """Reduce multiple average precision score into one number."""
- if isinstance(precision, Tensor) and isinstance(recall, Tensor):
- precision = torch.where(torch.isnan(precision), torch.zeros_like(precision), precision)
- recall = torch.where(torch.isnan(recall), torch.zeros_like(recall), recall)
- res = -torch.sum((recall[:, 1:] - recall[:, :-1]) * precision[:, :-1], 1)
- else:
- res = torch.stack([-torch.sum((r[1:] - r[:-1]) * p[:-1]) for p, r in zip(precision, recall)])
- if average is None or average == "none":
- return res
- if torch.isnan(res).any():
- rank_zero_warn(
- f"Average precision score for one or more classes was `nan`. Ignoring these classes in {average}-average",
- UserWarning,
- )
- idx = ~torch.isnan(res)
- if average == "macro":
- return res[idx].mean()
- if average == "weighted" and weights is not None:
- weights = _safe_divide(weights[idx], weights[idx].sum())
- return (res[idx] * weights).sum()
- raise ValueError("Received an incompatible combinations of inputs to make reduction.")
def _binary_average_precision_compute(
    state: Union[Tensor, tuple[Tensor, Tensor]],
    thresholds: Optional[Tensor],
) -> Tensor:
    """Compute the binary average precision score from an accumulated precision-recall-curve state.

    Args:
        state: Accumulated state handed straight to ``_binary_precision_recall_curve_compute``; either a single
            tensor or a ``(preds, target)`` pair.
        thresholds: The thresholds the state was accumulated with, or ``None`` for the non-binned version.

    Returns:
        A scalar tensor ``-sum((R[1:] - R[:-1]) * P[:-1])`` over the curve, where NaN precision/recall values
        are treated as zero.

    """
    curve_precision, curve_recall, _ = _binary_precision_recall_curve_compute(state, thresholds)
    # Undefined (NaN) curve points contribute nothing instead of poisoning the sum.
    curve_precision = curve_precision.masked_fill(torch.isnan(curve_precision), 0)
    curve_recall = curve_recall.masked_fill(torch.isnan(curve_recall), 0)
    recall_steps = curve_recall[1:] - curve_recall[:-1]
    return -(recall_steps * curve_precision[:-1]).sum()
def binary_average_precision(
    preds: Tensor,
    target: Tensor,
    thresholds: Optional[Union[int, list[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the average precision (AP) score for binary tasks.

    The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the
    difference in recall from the previous threshold as weight:

    .. math::
        AP = \sum_{n} (R_n - R_{n-1}) P_n

    where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is
    equivalent to the area under the precision-recall curve (AUPRC).

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      sigmoid per element.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive
      class.

    Additional dimensions ``...`` will be flattened into the batch dimension.

    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        A single scalar with the average precision score

    Example:
        >>> from torchmetrics.functional.classification import binary_average_precision
        >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> binary_average_precision(preds, target, thresholds=None)
        tensor(0.5833)
        >>> binary_average_precision(preds, target, thresholds=5)
        tensor(0.6667)

    """
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
    # Formatting also converts `thresholds` into the form expected by the update/compute steps.
    preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index)
    state = _binary_precision_recall_curve_update(preds, target, thresholds)
    return _binary_average_precision_compute(state, thresholds)
def _multiclass_average_precision_arg_validation(
    num_classes: int,
    average: Optional[Literal["macro", "weighted", "none"]] = "macro",
    thresholds: Optional[Union[int, list[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> None:
    """Validate the non-tensor arguments of the multiclass average precision metric.

    ``num_classes``, ``thresholds`` and ``ignore_index`` are checked by the multiclass precision-recall curve
    validator; ``average`` is checked here.

    Raises:
        ValueError: If ``average`` is not one of ``"macro"``, ``"weighted"``, ``"none"`` or ``None``.

    """
    _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
    allowed_average = ("macro", "weighted", "none", None)
    if average in allowed_average:
        return
    raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}")
def _multiclass_average_precision_compute(
    state: Union[Tensor, tuple[Tensor, Tensor]],
    num_classes: int,
    average: Optional[Literal["macro", "weighted", "none"]] = "macro",
    thresholds: Optional[Tensor] = None,
) -> Tensor:
    """Compute multiclass average precision from an accumulated precision-recall-curve state.

    Args:
        state: The binned state tensor when ``thresholds`` is given; otherwise a pair whose second element holds
            the target class indices.
        num_classes: Number of classes.
        average: Reduction forwarded to ``_reduce_average_precision``.
        thresholds: The thresholds the state was accumulated with, or ``None`` for the non-binned version.

    Returns:
        Per-class scores or their reduction, as produced by ``_reduce_average_precision``.

    """
    precision, recall, _ = _multiclass_precision_recall_curve_compute(state, num_classes, thresholds)
    if thresholds is None:
        # Non-binned: class support is the count of each class index among the targets.
        support = _bincount(state[1], minlength=num_classes).float()
    else:
        # Binned: sum of row 1 of each class's matrix at the first threshold.
        # NOTE(review): assumes that row holds each class's positive count - confirm against the curve update.
        support = state[0][:, 1, :].sum(-1)
    return _reduce_average_precision(precision, recall, average, weights=support)
def multiclass_average_precision(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    average: Optional[Literal["macro", "weighted", "none"]] = "macro",
    thresholds: Optional[Union[int, list[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the average precision (AP) score for multiclass tasks.

    The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the
    difference in recall from the previous threshold as weight:

    .. math::
        AP = \sum_{n} (R_n - R_{n-1}) P_n

    where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is
    equivalent to the area under the precision-recall curve (AUPRC).

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      softmax per sample.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).

    Additional dimensions ``...`` will be flattened into the batch dimension.

    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
    size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_classes: Integer specifying the number of classes
        average:
            Defines the reduction that is applied over classes. Should be one of the following:

            - ``macro``: Calculate score for each class and average them
            - ``weighted``: calculates score for each class and computes weighted average using their support
            - ``"none"`` or ``None``: calculates score for each class and applies no reduction

            Classes whose score is ``nan`` (e.g. a class absent from ``target`` when ``thresholds=None``, see the
            example below) are left out of the ``macro``/``weighted`` average and a warning is raised.
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class.
        If `average="macro"|"weighted"` then a single scalar is returned.

    Example:
        >>> from torchmetrics.functional.classification import multiclass_average_precision
        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=None)
        tensor(0.6250)
        >>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=None)
        tensor([1.0000, 1.0000, 0.2500, 0.2500,    nan])
        >>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=5)
        tensor(0.5000)
        >>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=5)
        tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000])

    """
    if validate_args:
        _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index)
        _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
    # Formatting also converts `thresholds` into the form expected by the update/compute steps.
    preds, target, thresholds = _multiclass_precision_recall_curve_format(
        preds, target, num_classes, thresholds, ignore_index
    )
    state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
    return _multiclass_average_precision_compute(state, num_classes, average, thresholds)
def _multilabel_average_precision_arg_validation(
    num_labels: int,
    average: Optional[Literal["micro", "macro", "weighted", "none"]],
    thresholds: Optional[Union[int, list[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> None:
    """Validate the non-tensor arguments of the multilabel average precision metric.

    ``num_labels``, ``thresholds`` and ``ignore_index`` are checked by the multilabel precision-recall curve
    validator; ``average`` is checked here.

    Raises:
        ValueError: If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``.

    """
    _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index)
    allowed_average = ("micro", "macro", "weighted", "none", None)
    if average in allowed_average:
        return
    raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}")
def _multilabel_average_precision_compute(
    state: Union[Tensor, tuple[Tensor, Tensor]],
    num_labels: int,
    average: Optional[Literal["micro", "macro", "weighted", "none"]],
    thresholds: Optional[Tensor],
    ignore_index: Optional[int] = None,
) -> Tensor:
    """Compute multilabel average precision from an accumulated precision-recall-curve state.

    Args:
        state: The binned state tensor when ``thresholds`` is given, otherwise a ``(preds, target)`` pair.
        num_labels: Number of labels.
        average: ``"micro"`` treats all labels as one binary problem; any other value is forwarded to
            ``_reduce_average_precision``.
        thresholds: The thresholds the state was accumulated with, or ``None`` for the non-binned version.
        ignore_index: Target value whose entries are dropped (non-binned ``"micro"`` path) or forwarded to the
            curve computation.

    Returns:
        A scalar for ``"micro"``; otherwise per-label scores or their reduction.

    """
    if average != "micro":
        precision, recall, _ = _multilabel_precision_recall_curve_compute(state, num_labels, thresholds, ignore_index)
        if thresholds is None:
            # Non-binned: support is the number of positive (== 1) targets per label.
            support = (state[1] == 1).sum(dim=0).float()
        else:
            # NOTE(review): assumes row 1 of each label's matrix at the first threshold holds its positive count.
            support = state[0][:, 1, :].sum(-1)
        return _reduce_average_precision(precision, recall, average, weights=support)

    # Micro averaging: collapse all labels into a single binary problem.
    if isinstance(state, Tensor) and thresholds is not None:
        micro_state = state.sum(1)
    else:
        flat_preds, flat_target = state[0].flatten(), state[1].flatten()
        if ignore_index is not None:
            keep = flat_target != ignore_index
            flat_preds, flat_target = flat_preds[keep], flat_target[keep]
        micro_state = (flat_preds, flat_target)
    return _binary_average_precision_compute(micro_state, thresholds)
def multilabel_average_precision(
    preds: Tensor,
    target: Tensor,
    num_labels: int,
    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
    thresholds: Optional[Union[int, list[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the average precision (AP) score for multilabel tasks.

    The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the
    difference in recall from the previous threshold as weight:

    .. math::
        AP = \sum_{n} (R_n - R_{n-1}) P_n

    where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is
    equivalent to the area under the precision-recall curve (AUPRC).

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      sigmoid per element.
    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and
      therefore only contain {0,1} values (except if `ignore_index` is specified).

    Additional dimensions ``...`` will be flattened into the batch dimension.

    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
    size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_labels: Integer specifying the number of labels
        average:
            Defines the reduction that is applied over labels. Should be one of the following:

            - ``micro``: Treat all labels as one flattened binary problem and compute a single score for it
            - ``macro``: Calculate score for each label and average them
            - ``weighted``: calculates score for each label and computes weighted average using their support
            - ``"none"`` or ``None``: calculates score for each label and applies no reduction

        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        If `average=None|"none"` then a 1d tensor of shape (n_labels, ) will be returned with AP score per label.
        If `average="micro"|"macro"|"weighted"` then a single scalar is returned.

    Example:
        >>> from torchmetrics.functional.classification import multilabel_average_precision
        >>> preds = torch.tensor([[0.75, 0.05, 0.35],
        ...                       [0.45, 0.75, 0.05],
        ...                       [0.05, 0.55, 0.75],
        ...                       [0.05, 0.65, 0.05]])
        >>> target = torch.tensor([[1, 0, 1],
        ...                        [0, 0, 0],
        ...                        [0, 1, 1],
        ...                        [1, 1, 1]])
        >>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=None)
        tensor(0.7500)
        >>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=None)
        tensor([0.7500, 0.5833, 0.9167])
        >>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=5)
        tensor(0.7778)
        >>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=5)
        tensor([0.7500, 0.6667, 0.9167])

    """
    if validate_args:
        _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index)
        _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index)
    # Formatting also converts `thresholds` into the form expected by the update/compute steps.
    preds, target, thresholds = _multilabel_precision_recall_curve_format(
        preds, target, num_labels, thresholds, ignore_index
    )
    state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds)
    return _multilabel_average_precision_compute(state, num_labels, average, thresholds, ignore_index)
def average_precision(
    preds: Tensor,
    target: Tensor,
    task: Literal["binary", "multiclass", "multilabel"],
    thresholds: Optional[Union[int, list[float], Tensor]] = None,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[Literal["macro", "weighted", "none"]] = "macro",
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Optional[Tensor]:
    r"""Compute the average precision (AP) score.

    The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the
    difference in recall from the previous threshold as weight:

    .. math::
        AP = \sum_{n} (R_n - R_{n-1}) P_n

    where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is
    equivalent to the area under the precision-recall curve (AUPRC).

    This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :func:`~torchmetrics.functional.classification.binary_average_precision`,
    :func:`~torchmetrics.functional.classification.multiclass_average_precision` and
    :func:`~torchmetrics.functional.classification.multilabel_average_precision`
    for the specific details of each argument influence and examples.

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        task: Which task specific implementation to dispatch to: ``'binary'``, ``'multiclass'`` or ``'multilabel'``.
        thresholds: Binning strategy, forwarded unchanged to the task specific function.
        num_classes: Number of classes; must be an ``int`` when ``task='multiclass'``.
        num_labels: Number of labels; must be an ``int`` when ``task='multilabel'``.
        average: Reduction forwarded to the multiclass/multilabel functions (see them for accepted values);
            not used when ``task='binary'``.
        ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.

    Returns:
        The result of the selected task specific function, or ``None`` if the parsed ``task`` matches none of the
        three handled tasks.

    Raises:
        ValueError: If ``task='multiclass'`` and ``num_classes`` is not an ``int``, or if ``task='multilabel'`` and
            ``num_labels`` is not an ``int``.

    Legacy Example:
        >>> from torchmetrics.functional.classification import average_precision
        >>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> average_precision(pred, target, task="binary")
        tensor(1.)

        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> average_precision(pred, target, task="multiclass", num_classes=5, average=None)
        tensor([1.0000, 1.0000, 0.2500, 0.2500,    nan])

    """
    task = ClassificationTask.from_str(task)
    if task == ClassificationTask.BINARY:
        return binary_average_precision(preds, target, thresholds, ignore_index, validate_args)
    if task == ClassificationTask.MULTICLASS:
        if not isinstance(num_classes, int):
            raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
        return multiclass_average_precision(
            preds, target, num_classes, average, thresholds, ignore_index, validate_args
        )
    if task == ClassificationTask.MULTILABEL:
        if not isinstance(num_labels, int):
            raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
        return multilabel_average_precision(preds, target, num_labels, average, thresholds, ignore_index, validate_args)
    return None
|