hinge.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from typing_extensions import Literal
  19. from torchmetrics.classification.base import _ClassificationTaskWrapper
  20. from torchmetrics.functional.classification.hinge import (
  21. _binary_confusion_matrix_format,
  22. _binary_hinge_loss_arg_validation,
  23. _binary_hinge_loss_tensor_validation,
  24. _binary_hinge_loss_update,
  25. _hinge_loss_compute,
  26. _multiclass_confusion_matrix_format,
  27. _multiclass_hinge_loss_arg_validation,
  28. _multiclass_hinge_loss_tensor_validation,
  29. _multiclass_hinge_loss_update,
  30. )
  31. from torchmetrics.metric import Metric
  32. from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
  33. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  34. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
if not _MATPLOTLIB_AVAILABLE:
    # The ``plot`` methods below raise ModuleNotFoundError without matplotlib (see their docstrings), so their
    # doctests cannot run. Listing them in ``__doctest_skip__`` tells the doctest collector (presumably
    # pytest-doctestplus, which reads this module attribute) to skip them instead of failing.
    __doctest_skip__ = ["BinaryHingeLoss.plot", "MulticlassHingeLoss.plot"]
  37. class BinaryHingeLoss(Metric):
  38. r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks.
  39. .. math::
  40. \text{Hinge loss} = \max(0, 1 - y \times \hat{y})
  41. Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction.
  42. As input to ``forward`` and ``update`` the metric accepts the following input:
  43. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``. Preds should be a tensor containing
  44. probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input
  45. to be logits and will auto apply sigmoid per element.
  46. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
  47. ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value
  48. 1 always encodes the positive class.
  49. .. tip::
  50. Additional dimension ``...`` will be flattened into the batch dimension.
  51. As output to ``forward`` and ``compute`` the metric returns the following output:
  52. - ``bhl`` (:class:`~torch.Tensor`): A tensor containing the hinge loss.
  53. Args:
  54. squared:
  55. If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
  56. ignore_index:
  57. Specifies a target value that is ignored and does not contribute to the metric calculation
  58. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  59. Set to ``False`` for faster computations.
  60. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  61. Example:
  62. >>> from torchmetrics.classification import BinaryHingeLoss
  63. >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
  64. >>> target = torch.tensor([0, 0, 1, 1, 1])
  65. >>> bhl = BinaryHingeLoss()
  66. >>> bhl(preds, target)
  67. tensor(0.6900)
  68. >>> bhl = BinaryHingeLoss(squared=True)
  69. >>> bhl(preds, target)
  70. tensor(0.6905)
  71. """
  72. is_differentiable: bool = True
  73. higher_is_better: bool = False
  74. full_state_update: bool = False
  75. plot_lower_bound: float = 0.0
  76. plot_upper_bound: float = 1.0
  77. measures: Tensor
  78. total: Tensor
  79. def __init__(
  80. self,
  81. squared: bool = False,
  82. ignore_index: Optional[int] = None,
  83. validate_args: bool = True,
  84. **kwargs: Any,
  85. ) -> None:
  86. super().__init__(**kwargs)
  87. if validate_args:
  88. _binary_hinge_loss_arg_validation(squared, ignore_index)
  89. self.validate_args = validate_args
  90. self.squared = squared
  91. self.ignore_index = ignore_index
  92. self.add_state("measures", default=torch.tensor(0.0), dist_reduce_fx="sum")
  93. self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
  94. def update(self, preds: Tensor, target: Tensor) -> None:
  95. """Update metric state."""
  96. if self.validate_args:
  97. _binary_hinge_loss_tensor_validation(preds, target, self.ignore_index)
  98. preds, target = _binary_confusion_matrix_format(
  99. preds, target, threshold=0.0, ignore_index=self.ignore_index, convert_to_labels=False
  100. )
  101. measures, total = _binary_hinge_loss_update(preds, target, self.squared)
  102. self.measures += measures
  103. self.total += total
  104. def compute(self) -> Tensor:
  105. """Compute metric."""
  106. return _hinge_loss_compute(self.measures, self.total)
  107. def plot(
  108. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  109. ) -> _PLOT_OUT_TYPE:
  110. """Plot a single or multiple values from the metric.
  111. Args:
  112. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  113. If no value is provided, will automatically call `metric.compute` and plot that result.
  114. ax: An matplotlib axis object. If provided will add plot to that axis
  115. Returns:
  116. Figure object and Axes object
  117. Raises:
  118. ModuleNotFoundError:
  119. If `matplotlib` is not installed
  120. .. plot::
  121. :scale: 75
  122. >>> # Example plotting a single value
  123. >>> from torch import rand, randint
  124. >>> from torchmetrics.classification import BinaryHingeLoss
  125. >>> metric = BinaryHingeLoss()
  126. >>> metric.update(rand(10), randint(2,(10,)))
  127. >>> fig_, ax_ = metric.plot()
  128. .. plot::
  129. :scale: 75
  130. >>> # Example plotting multiple values
  131. >>> from torch import rand, randint
  132. >>> from torchmetrics.classification import BinaryHingeLoss
  133. >>> metric = BinaryHingeLoss()
  134. >>> values = [ ]
  135. >>> for _ in range(10):
  136. ... values.append(metric(rand(10), randint(2,(10,))))
  137. >>> fig_, ax_ = metric.plot(values)
  138. """
  139. return self._plot(val, ax)
  140. class MulticlassHingeLoss(Metric):
  141. r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks.
  142. The metric can be computed in two ways. Either, the definition by Crammer and Singer is used:
  143. .. math::
  144. \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)
  145. Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes),
  146. and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the metric can
  147. also be computed in one-vs-all approach, where each class is valued against all other classes in a binary fashion.
  148. As input to ``forward`` and ``update`` the metric accepts the following input:
  149. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
  150. containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
  151. the input to be logits and will auto apply softmax per sample.
  152. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
  153. ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index`
  154. is specified).
  155. .. tip::
  156. Additional dimension ``...`` will be flattened into the batch dimension.
  157. As output to ``forward`` and ``compute`` the metric returns the following output:
  158. - ``mchl`` (:class:`~torch.Tensor`): A tensor containing the multi-class hinge loss.
  159. Args:
  160. num_classes: Integer specifying the number of classes
  161. squared:
  162. If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
  163. multiclass_mode:
  164. Determines how to compute the metric
  165. ignore_index:
  166. Specifies a target value that is ignored and does not contribute to the metric calculation
  167. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  168. Set to ``False`` for faster computations.
  169. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  170. Example:
  171. >>> from torchmetrics.classification import MulticlassHingeLoss
  172. >>> preds = torch.tensor([[0.25, 0.20, 0.55],
  173. ... [0.55, 0.05, 0.40],
  174. ... [0.10, 0.30, 0.60],
  175. ... [0.90, 0.05, 0.05]])
  176. >>> target = torch.tensor([0, 1, 2, 0])
  177. >>> mchl = MulticlassHingeLoss(num_classes=3)
  178. >>> mchl(preds, target)
  179. tensor(0.9125)
  180. >>> mchl = MulticlassHingeLoss(num_classes=3, squared=True)
  181. >>> mchl(preds, target)
  182. tensor(1.1131)
  183. >>> mchl = MulticlassHingeLoss(num_classes=3, multiclass_mode='one-vs-all')
  184. >>> mchl(preds, target)
  185. tensor([0.8750, 1.1250, 1.1000])
  186. """
  187. is_differentiable: bool = True
  188. higher_is_better: bool = False
  189. full_state_update: bool = False
  190. plot_lower_bound: float = 0.0
  191. plot_upper_bound: float = 1.0
  192. plot_legend_name: str = "Class"
  193. measures: Tensor
  194. total: Tensor
  195. def __init__(
  196. self,
  197. num_classes: int,
  198. squared: bool = False,
  199. multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
  200. ignore_index: Optional[int] = None,
  201. validate_args: bool = True,
  202. **kwargs: Any,
  203. ) -> None:
  204. super().__init__(**kwargs)
  205. if validate_args:
  206. _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index)
  207. self.validate_args = validate_args
  208. self.num_classes = num_classes
  209. self.squared = squared
  210. self.multiclass_mode = multiclass_mode
  211. self.ignore_index = ignore_index
  212. self.add_state(
  213. "measures",
  214. default=torch.tensor(0.0)
  215. if self.multiclass_mode == "crammer-singer"
  216. else torch.zeros(
  217. num_classes,
  218. ),
  219. dist_reduce_fx="sum",
  220. )
  221. self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
  222. def update(self, preds: Tensor, target: Tensor) -> None:
  223. """Update metric state."""
  224. if self.validate_args:
  225. _multiclass_hinge_loss_tensor_validation(preds, target, self.num_classes, self.ignore_index)
  226. preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index, convert_to_labels=False)
  227. measures, total = _multiclass_hinge_loss_update(preds, target, self.squared, self.multiclass_mode)
  228. self.measures += measures
  229. self.total += total
  230. def compute(self) -> Tensor:
  231. """Compute metric."""
  232. return _hinge_loss_compute(self.measures, self.total)
  233. def plot(
  234. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  235. ) -> _PLOT_OUT_TYPE:
  236. """Plot a single or multiple values from the metric.
  237. Args:
  238. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  239. If no value is provided, will automatically call `metric.compute` and plot that result.
  240. ax: An matplotlib axis object. If provided will add plot to that axis
  241. Returns:
  242. Figure object and Axes object
  243. Raises:
  244. ModuleNotFoundError:
  245. If `matplotlib` is not installed
  246. .. plot::
  247. :scale: 75
  248. >>> # Example plotting a single value per class
  249. >>> from torch import randint, randn
  250. >>> from torchmetrics.classification import MulticlassHingeLoss
  251. >>> metric = MulticlassHingeLoss(num_classes=3)
  252. >>> metric.update(randn(20, 3), randint(3, (20,)))
  253. >>> fig_, ax_ = metric.plot()
  254. .. plot::
  255. :scale: 75
  256. >>> # Example plotting a multiple values per class
  257. >>> from torch import randint, randn
  258. >>> from torchmetrics.classification import MulticlassHingeLoss
  259. >>> metric = MulticlassHingeLoss(num_classes=3)
  260. >>> values = []
  261. >>> for _ in range(20):
  262. ... values.append(metric(randn(20, 3), randint(3, (20,))))
  263. >>> fig_, ax_ = metric.plot(values)
  264. """
  265. return self._plot(val, ax)
  266. class HingeLoss(_ClassificationTaskWrapper):
  267. r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs).
  268. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  269. ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of
  270. :class:`~torchmetrics.classification.BinaryHingeLoss` and :class:`~torchmetrics.classification.MulticlassHingeLoss`
  271. for the specific details of each argument influence and examples.
  272. Legacy Example:
  273. >>> from torch import tensor
  274. >>> target = tensor([0, 1, 1])
  275. >>> preds = tensor([0.5, 0.7, 0.1])
  276. >>> hinge = HingeLoss(task="binary")
  277. >>> hinge(preds, target)
  278. tensor(0.9000)
  279. >>> target = tensor([0, 1, 2])
  280. >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
  281. >>> hinge = HingeLoss(task="multiclass", num_classes=3)
  282. >>> hinge(preds, target)
  283. tensor(1.5551)
  284. >>> target = tensor([0, 1, 2])
  285. >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
  286. >>> hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
  287. >>> hinge(preds, target)
  288. tensor([1.3743, 1.1945, 1.2359])
  289. """
  290. def __new__( # type: ignore[misc]
  291. cls: type["HingeLoss"],
  292. task: Literal["binary", "multiclass"],
  293. num_classes: Optional[int] = None,
  294. squared: bool = False,
  295. multiclass_mode: Optional[Literal["crammer-singer", "one-vs-all"]] = "crammer-singer",
  296. ignore_index: Optional[int] = None,
  297. validate_args: bool = True,
  298. **kwargs: Any,
  299. ) -> Metric:
  300. """Initialize task metric."""
  301. task = ClassificationTaskNoMultilabel.from_str(task)
  302. kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
  303. if task == ClassificationTaskNoMultilabel.BINARY:
  304. return BinaryHingeLoss(squared, **kwargs)
  305. if task == ClassificationTaskNoMultilabel.MULTICLASS:
  306. if not isinstance(num_classes, int):
  307. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  308. if multiclass_mode not in ("crammer-singer", "one-vs-all"):
  309. raise ValueError(
  310. f"`multiclass_mode` is expected to be one of 'crammer-singer' or 'one-vs-all' but "
  311. f"`{multiclass_mode}` was passed."
  312. )
  313. return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
  314. raise ValueError(f"Unsupported task `{task}`")