# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Tuple, Type, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.roc import BinaryROC, MulticlassROC, MultilabelROC
from torchmetrics.functional.classification.logauc import (
    _binary_logauc_compute,
    _reduce_logauc,
    _validate_fpr_range,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["BinaryLogAUC.plot", "MulticlassLogAUC.plot", "MultilabelLogAUC.plot"]


class BinaryLogAUC(BinaryROC):
    r"""Compute the `Log AUC`_ score for binary classification tasks.

    The score is computed by first computing the ROC curve, interpolating it to the specified range of false positive
    rates (FPR), taking the logarithm of the FPR, and finally computing the area under the resulting curve (AUC). The
    score is commonly used in applications where the positive and negative classes are imbalanced and a low false
    positive rate is of high importance.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
      each observation. If ``preds`` has values outside the [0, 1] range, the input is considered to be logits and
      sigmoid is automatically applied per element.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, which
      should therefore only contain {0, 1} values (except if `ignore_index` is specified). The value 1 always encodes
      the positive class.

    As output to ``forward`` and ``compute`` the metric returns the following output:

    - ``logauc`` (:class:`~torch.Tensor`): A single scalar with the log AUC score.

    Additional dimensions ``...`` will be flattened into the batch dimension.

    The implementation supports both a non-binned but accurate version and a binned version that is less accurate
    but more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which
    uses memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting the `thresholds` argument to an integer,
    a list or a 1d tensor activates the binned version, which uses memory of size
    :math:`\mathcal{O}(n_{thresholds})` (constant memory).

    Args:
        fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range over which the
            log AUC score is computed.
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.classification import BinaryLogAUC
        >>> preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05])
        >>> target = tensor([1, 0, 0, 0, 0])
        >>> metric = BinaryLogAUC()
        >>> metric(preds, target)
        tensor(1.)

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(
        self,
        fpr_range: Tuple[float, float] = (0.001, 0.1),
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args, **kwargs)
        if validate_args:
            _validate_fpr_range(fpr_range)
        self.fpr_range = fpr_range

    def compute(self) -> Tensor:  # type: ignore[override]
        """Compute the log AUC score."""
        fpr, tpr, _ = super().compute()
        return _binary_logauc_compute(fpr, tpr, fpr_range=self.fpr_range)

    def plot(  # type: ignore[override]
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.classification import BinaryLogAUC
            >>> metric = BinaryLogAUC()
            >>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.classification import BinaryLogAUC
            >>> metric = BinaryLogAUC()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class MulticlassLogAUC(MulticlassROC):
    r"""Compute the `Log AUC`_ score for multiclass classification tasks.

    The score is computed by first computing the ROC curve, interpolating it to the specified range of false positive
    rates (FPR), taking the logarithm of the FPR, and finally computing the area under the resulting curve (AUC). The
    score is commonly used in applications where the positive and negative classes are imbalanced and a low false
    positive rate is of high importance.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
      for each observation. If ``preds`` has values outside the [0, 1] range, the input is considered to be logits and
      softmax is automatically applied per sample.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, which
      should therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).

    As output to ``forward`` and ``compute`` the metric returns the following output:

    - ``logauc`` (:class:`~torch.Tensor`): If `average=None|"none"`, a 1d tensor of shape ``(n_classes, )`` is
      returned with the log AUC score per class. If `average="macro"`, a single scalar is returned.

    Additional dimensions ``...`` will be flattened into the batch dimension.

    The implementation supports both a non-binned but accurate version and a binned version that is less accurate
    but more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which
    uses memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting the `thresholds` argument to an integer,
    a list or a 1d tensor activates the binned version, which uses memory of size
    :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).

    Args:
        num_classes: Integer specifying the number of classes
        fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range over which the
            log AUC score is computed.
        average:
            Defines the reduction that is applied over classes. Should be one of the following:

            - ``"macro"``: Calculate the score for each class and average them
            - ``"none"`` or ``None``: Calculate the score for each class and apply no reduction

        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.classification import MulticlassLogAUC
        >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                 [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                 [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = tensor([0, 1, 3, 2])
        >>> metric = MulticlassLogAUC(num_classes=5, average="macro", thresholds=None)
        >>> metric(preds, target)
        tensor(0.4000)
        >>> metric = MulticlassLogAUC(num_classes=5, average=None, thresholds=None)
        >>> metric(preds, target)
        tensor([1., 1., 0., 0., 0.])

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    plot_legend_name: str = "Class"

    def __init__(
        self,
        num_classes: int,
        fpr_range: Tuple[float, float] = (0.001, 0.1),
        average: Optional[Literal["macro", "none"]] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            num_classes=num_classes,
            thresholds=thresholds,
            average=None,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )
        if validate_args:
            _validate_fpr_range(fpr_range)
        self.fpr_range = fpr_range
        self.average2 = average  # self.average is already used by parent class

    def compute(self) -> Tensor:  # type: ignore[override]
        """Compute the log AUC score."""
        fpr, tpr, _ = super().compute()
        return _reduce_logauc(fpr, tpr, fpr_range=self.fpr_range, average=self.average2)

    def plot(  # type: ignore[override]
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.classification import MulticlassLogAUC
            >>> metric = MulticlassLogAUC(num_classes=3)
            >>> metric.update(torch.randn(20, 3), torch.randint(3, (20,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.classification import MulticlassLogAUC
            >>> metric = MulticlassLogAUC(num_classes=3)
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class MultilabelLogAUC(MultilabelROC):
    r"""Compute the `Log AUC`_ score for multilabel classification tasks.

    The score is computed by first computing the ROC curve, interpolating it to the specified range of false positive
    rates (FPR), taking the logarithm of the FPR, and finally computing the area under the resulting curve (AUC). The
    score is commonly used in applications where the positive and negative classes are imbalanced and a low false
    positive rate is of high importance.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
      for each observation. If ``preds`` has values outside the [0, 1] range, the input is considered to be logits and
      sigmoid is automatically applied per element.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` containing ground truth labels,
      which should therefore only contain {0, 1} values (except if `ignore_index` is specified).

    As output to ``forward`` and ``compute`` the metric returns the following output:

    - ``logauc`` (:class:`~torch.Tensor`): If `average=None|"none"`, a 1d tensor of shape ``(num_labels, )`` is
      returned with the log AUC score per label. If `average="macro"`, a single scalar is returned.

    Additional dimensions ``...`` will be flattened into the batch dimension.

    The implementation supports both a non-binned but accurate version and a binned version that is less accurate
    but more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which
    uses memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting the `thresholds` argument to an integer,
    a list or a 1d tensor activates the binned version, which uses memory of size
    :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).

    Args:
        num_labels: Integer specifying the number of labels
        fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range over which the
            log AUC score is computed.
        average:
            Defines the reduction that is applied over labels. Should be one of the following:

            - ``"macro"``: Calculate the score for each label and average them
            - ``"none"`` or ``None``: Calculate the score for each label and apply no reduction

        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation.
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.classification import MultilabelLogAUC
        >>> preds = tensor([[0.75, 0.05, 0.35],
        ...                 [0.45, 0.75, 0.05],
        ...                 [0.05, 0.55, 0.75],
        ...                 [0.05, 0.65, 0.05]])
        >>> target = tensor([[1, 0, 1],
        ...                  [0, 0, 0],
        ...                  [0, 1, 1],
        ...                  [1, 1, 1]])
        >>> metric = MultilabelLogAUC(num_labels=3, average="macro", thresholds=None)
        >>> metric(preds, target)
        tensor(0.3945)
        >>> metric = MultilabelLogAUC(num_labels=3, average=None, thresholds=None)
        >>> metric(preds, target)
        tensor([0.5000, 0.0000, 0.6835])

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    plot_legend_name: str = "Label"

    def __init__(
        self,
        num_labels: int,
        fpr_range: Tuple[float, float] = (0.001, 0.1),
        average: Optional[Literal["macro", "none"]] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        # initialize the parent first, consistent with the binary and multiclass variants
        super().__init__(
            num_labels=num_labels,
            thresholds=thresholds,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )
        if validate_args:
            _validate_fpr_range(fpr_range)
        self.fpr_range = fpr_range
        self.average2 = average  # self.average is already used by parent class

    def compute(self) -> Tensor:  # type: ignore[override]
        """Compute the log AUC score."""
        fpr, tpr, _ = super().compute()
        return _reduce_logauc(fpr, tpr, fpr_range=self.fpr_range, average=self.average2)

    def plot(  # type: ignore[override]
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.classification import MultilabelLogAUC
            >>> metric = MultilabelLogAUC(num_labels=3)
            >>> metric.update(torch.rand(20, 3), torch.randint(2, (20, 3)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.classification import MultilabelLogAUC
            >>> metric = MultilabelLogAUC(num_labels=3)
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(torch.rand(20, 3), torch.randint(2, (20, 3))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class LogAUC(_ClassificationTaskWrapper):
    r"""Compute the `Log AUC`_ score for classification tasks.

    The score is computed by first computing the ROC curve, interpolating it to the specified range of false positive
    rates (FPR), taking the logarithm of the FPR, and finally computing the area under the resulting curve (AUC). The
    score is commonly used in applications where the positive and negative classes are imbalanced and a low false
    positive rate is of high importance.

    This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :class:`~torchmetrics.classification.BinaryLogAUC`, :class:`~torchmetrics.classification.MulticlassLogAUC` and
    :class:`~torchmetrics.classification.MultilabelLogAUC` for details on how each argument influences the metric and
    for examples.
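
    Example:
        A minimal sketch of the task dispatch; it only inspects the class of the returned metric and does not compute
        a score (it assumes ``LogAUC`` is exported from ``torchmetrics.classification`` like the task-specific classes):

        >>> from torchmetrics.classification import LogAUC
        >>> type(LogAUC(task="binary")).__name__
        'BinaryLogAUC'
        >>> type(LogAUC(task="multiclass", num_classes=3)).__name__
        'MulticlassLogAUC'
        >>> type(LogAUC(task="multilabel", num_labels=3)).__name__
        'MultilabelLogAUC'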

    """

    def __new__(  # type: ignore[misc]
        cls: Type["LogAUC"],
        task: Literal["binary", "multiclass", "multilabel"],
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        fpr_range: Tuple[float, float] = (0.001, 0.1),
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Initialize task metric."""
        task = ClassificationTask.from_str(task)
        kwargs.update({
            "thresholds": thresholds,
            "fpr_range": fpr_range,
            "ignore_index": ignore_index,
            "validate_args": validate_args,
        })
        if task == ClassificationTask.BINARY:
            return BinaryLogAUC(**kwargs)
        if task == ClassificationTask.MULTICLASS:
            if not isinstance(num_classes, int):
                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
            return MulticlassLogAUC(num_classes, **kwargs)
        if task == ClassificationTask.MULTILABEL:
            if not isinstance(num_labels, int):
                raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
            return MultilabelLogAUC(num_labels, **kwargs)
        raise ValueError(f"Task {task} not supported!")