| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Sequence
- from typing import Any, Optional, Union
- import torch
- from torch import Tensor
- from typing_extensions import Literal
- from torchmetrics.classification.base import _ClassificationTaskWrapper
- from torchmetrics.functional.classification.hinge import (
- _binary_confusion_matrix_format,
- _binary_hinge_loss_arg_validation,
- _binary_hinge_loss_tensor_validation,
- _binary_hinge_loss_update,
- _hinge_loss_compute,
- _multiclass_confusion_matrix_format,
- _multiclass_hinge_loss_arg_validation,
- _multiclass_hinge_loss_tensor_validation,
- _multiclass_hinge_loss_update,
- )
- from torchmetrics.metric import Metric
- from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
- from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
- from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
if not _MATPLOTLIB_AVAILABLE:
    # The ``plot`` doctests render figures; matplotlib is optional, so skip them when it is missing.
    __doctest_skip__ = [
        "BinaryHingeLoss.plot",
        "MulticlassHingeLoss.plot",
    ]
class BinaryHingeLoss(Metric):
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks.

    .. math::
        \text{Hinge loss} = \max(0, 1 - y \times \hat{y})

    Where :math:`y \in \{-1, 1\}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction. The ``{0, 1}``
    labels passed as ``target`` are interpreted so that label ``0`` plays the role of :math:`y = -1` and label ``1``
    the role of :math:`y = 1`. The reported value is the mean of this quantity over all non-ignored elements.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``. Preds should be a tensor containing
      probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input
      to be logits and will auto apply sigmoid per element.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
      ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value
      1 always encodes the positive class.

    .. tip::
       Additional dimension ``...`` will be flattened into the batch dimension.

    As output to ``forward`` and ``compute`` the metric returns the following output:

    - ``bhl`` (:class:`~torch.Tensor`): A scalar tensor containing the mean hinge loss.

    Args:
        squared:
            If True, this will compute the squared hinge loss, i.e. each per-element hinge term is squared before
            averaging. Otherwise, computes the regular hinge loss.
        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torchmetrics.classification import BinaryHingeLoss
        >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
        >>> target = torch.tensor([0, 0, 1, 1, 1])
        >>> bhl = BinaryHingeLoss()
        >>> bhl(preds, target)
        tensor(0.6900)
        >>> bhl = BinaryHingeLoss(squared=True)
        >>> bhl(preds, target)
        tensor(0.6905)

    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    measures: Tensor
    total: Tensor

    def __init__(
        self,
        squared: bool = False,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if validate_args:
            _binary_hinge_loss_arg_validation(squared, ignore_index)
        self.validate_args = validate_args
        self.squared = squared
        self.ignore_index = ignore_index

        # Both states are accumulated with ``+=`` in ``update`` and summed across processes on sync
        # (``dist_reduce_fx="sum"``); ``compute`` turns them into the final mean.
        self.add_state("measures", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update metric state with a new batch of predictions and targets."""
        if self.validate_args:
            _binary_hinge_loss_tensor_validation(preds, target, self.ignore_index)
        # ``convert_to_labels=False`` keeps the continuous scores: the hinge term needs the raw prediction value,
        # not a thresholded 0/1 label.
        preds, target = _binary_confusion_matrix_format(
            preds, target, threshold=0.0, ignore_index=self.ignore_index, convert_to_labels=False
        )
        measures, total = _binary_hinge_loss_update(preds, target, self.squared)
        self.measures += measures
        self.total += total

    def compute(self) -> Tensor:
        """Aggregate the accumulated ``measures`` and ``total`` states into the final hinge loss."""
        return _hinge_loss_compute(self.measures, self.total)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torch import rand, randint
            >>> from torchmetrics.classification import BinaryHingeLoss
            >>> metric = BinaryHingeLoss()
            >>> metric.update(rand(10), randint(2,(10,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torch import rand, randint
            >>> from torchmetrics.classification import BinaryHingeLoss
            >>> metric = BinaryHingeLoss()
            >>> values = [ ]
            >>> for _ in range(10):
            ...     values.append(metric(rand(10), randint(2,(10,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
class MulticlassHingeLoss(Metric):
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks.

    The metric can be computed in two ways. Either, the definition by Crammer and Singer is used:

    .. math::
        \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)

    Where :math:`y \in \{0, \ldots, \mathrm{C} - 1\}` is the target class (where :math:`\mathrm{C}` is the number of
    classes), and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the
    metric can also be computed in a one-vs-all approach, where each class is valued against all other classes in a
    binary fashion, which yields one loss value per class.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
      containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
      the input to be logits and will auto apply softmax per sample.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
      ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index`
      is specified).

    .. tip::
       Additional dimension ``...`` will be flattened into the batch dimension.

    As output to ``forward`` and ``compute`` the metric returns the following output:

    - ``mchl`` (:class:`~torch.Tensor`): A tensor containing the multi-class hinge loss. It is a scalar for
      ``multiclass_mode="crammer-singer"`` and a tensor of shape ``(C,)`` (one value per class) for
      ``multiclass_mode="one-vs-all"``.

    Args:
        num_classes: Integer specifying the number of classes
        squared:
            If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
        multiclass_mode:
            Determines how to compute the metric. ``"crammer-singer"`` (default) uses the Crammer-Singer definition
            above and returns a single value; ``"one-vs-all"`` treats every class as a separate binary problem and
            returns one value per class.
        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torchmetrics.classification import MulticlassHingeLoss
        >>> preds = torch.tensor([[0.25, 0.20, 0.55],
        ...                       [0.55, 0.05, 0.40],
        ...                       [0.10, 0.30, 0.60],
        ...                       [0.90, 0.05, 0.05]])
        >>> target = torch.tensor([0, 1, 2, 0])
        >>> mchl = MulticlassHingeLoss(num_classes=3)
        >>> mchl(preds, target)
        tensor(0.9125)
        >>> mchl = MulticlassHingeLoss(num_classes=3, squared=True)
        >>> mchl(preds, target)
        tensor(1.1131)
        >>> mchl = MulticlassHingeLoss(num_classes=3, multiclass_mode='one-vs-all')
        >>> mchl(preds, target)
        tensor([0.8750, 1.1250, 1.1000])

    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    plot_legend_name: str = "Class"

    measures: Tensor
    total: Tensor

    def __init__(
        self,
        num_classes: int,
        squared: bool = False,
        multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if validate_args:
            _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index)
        self.validate_args = validate_args
        self.num_classes = num_classes
        self.squared = squared
        self.multiclass_mode = multiclass_mode
        self.ignore_index = ignore_index

        # Crammer-Singer accumulates a single scalar; every other mode (one-vs-all) keeps one accumulator per class.
        default_measures = (
            torch.tensor(0.0) if self.multiclass_mode == "crammer-singer" else torch.zeros(num_classes)
        )
        # Both states are accumulated with ``+=`` in ``update`` and summed across processes on sync.
        self.add_state("measures", default=default_measures, dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update metric state with a new batch of predictions and targets."""
        if self.validate_args:
            _multiclass_hinge_loss_tensor_validation(preds, target, self.num_classes, self.ignore_index)
        # ``convert_to_labels=False`` keeps the per-class scores, which the hinge term needs (not argmax labels).
        preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index, convert_to_labels=False)
        measures, total = _multiclass_hinge_loss_update(preds, target, self.squared, self.multiclass_mode)
        self.measures += measures
        self.total += total

    def compute(self) -> Tensor:
        """Aggregate the accumulated ``measures`` and ``total`` states into the final hinge loss."""
        return _hinge_loss_compute(self.measures, self.total)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value per class
            >>> from torch import randint, randn
            >>> from torchmetrics.classification import MulticlassHingeLoss
            >>> metric = MulticlassHingeLoss(num_classes=3)
            >>> metric.update(randn(20, 3), randint(3, (20,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values per class
            >>> from torch import randint, randn
            >>> from torchmetrics.classification import MulticlassHingeLoss
            >>> metric = MulticlassHingeLoss(num_classes=3)
            >>> values = []
            >>> for _ in range(20):
            ...     values.append(metric(randn(20, 3), randint(3, (20,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
class HingeLoss(_ClassificationTaskWrapper):
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs).

    This class is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of
    :class:`~torchmetrics.classification.BinaryHingeLoss` and :class:`~torchmetrics.classification.MulticlassHingeLoss`
    for the specific details of each argument influence and examples.

    Legacy Example:
        >>> from torch import tensor
        >>> target = tensor([0, 1, 1])
        >>> preds = tensor([0.5, 0.7, 0.1])
        >>> hinge = HingeLoss(task="binary")
        >>> hinge(preds, target)
        tensor(0.9000)

        >>> target = tensor([0, 1, 2])
        >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
        >>> hinge = HingeLoss(task="multiclass", num_classes=3)
        >>> hinge(preds, target)
        tensor(1.5551)

        >>> target = tensor([0, 1, 2])
        >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
        >>> hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
        >>> hinge(preds, target)
        tensor([1.3743, 1.1945, 1.2359])

    """

    def __new__(  # type: ignore[misc]
        cls: type["HingeLoss"],
        task: Literal["binary", "multiclass"],
        num_classes: Optional[int] = None,
        squared: bool = False,
        multiclass_mode: Optional[Literal["crammer-singer", "one-vs-all"]] = "crammer-singer",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Initialize task metric.

        Args:
            task: Either ``"binary"`` or ``"multiclass"``; selects which concrete metric is returned.
            num_classes: Number of classes. Required (as an ``int``) for ``task="multiclass"``; unused for binary.
            squared: Forwarded to the task-specific metric; computes the squared hinge loss when True.
            multiclass_mode: ``"crammer-singer"`` or ``"one-vs-all"``. Only used for ``task="multiclass"``.
            ignore_index: Forwarded to the task-specific metric.
            validate_args: Forwarded to the task-specific metric.
            kwargs: Additional keyword arguments forwarded to the task-specific metric.

        Returns:
            A :class:`BinaryHingeLoss` or :class:`MulticlassHingeLoss` instance.

        Raises:
            ValueError:
                If ``task="multiclass"`` and ``num_classes`` is not an ``int``.
            ValueError:
                If ``task="multiclass"`` and ``multiclass_mode`` is not ``"crammer-singer"`` or ``"one-vs-all"``.
            ValueError:
                If the resolved task is neither binary nor multiclass (``from_str`` may already reject
                unknown task strings itself).

        """
        task = ClassificationTaskNoMultilabel.from_str(task)
        # Arguments shared by both task-specific metrics.
        kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
        if task == ClassificationTaskNoMultilabel.BINARY:
            return BinaryHingeLoss(squared, **kwargs)
        if task == ClassificationTaskNoMultilabel.MULTICLASS:
            # Checked here as well because ``MulticlassHingeLoss`` only validates when ``validate_args`` is True.
            if not isinstance(num_classes, int):
                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
            if multiclass_mode not in ("crammer-singer", "one-vs-all"):
                raise ValueError(
                    f"`multiclass_mode` is expected to be one of 'crammer-singer' or 'one-vs-all' but "
                    f"`{multiclass_mode}` was passed."
                )
            return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
        raise ValueError(f"Unsupported task `{task}`")
|