hinge.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor, tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.classification.confusion_matrix import (
  19. _binary_confusion_matrix_format,
  20. _binary_confusion_matrix_tensor_validation,
  21. _multiclass_confusion_matrix_format,
  22. _multiclass_confusion_matrix_tensor_validation,
  23. )
  24. from torchmetrics.utilities.compute import normalize_logits_if_needed
  25. from torchmetrics.utilities.data import to_onehot
  26. from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
  27. def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor:
  28. return measure / total
  29. def _binary_hinge_loss_arg_validation(squared: bool, ignore_index: Optional[int] = None) -> None:
  30. if not isinstance(squared, bool):
  31. raise ValueError(f"Expected argument `squared` to be an bool but got {squared}")
  32. if ignore_index is not None and not isinstance(ignore_index, int):
  33. raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
  34. def _binary_hinge_loss_tensor_validation(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> None:
  35. _binary_confusion_matrix_tensor_validation(preds, target, ignore_index)
  36. if not preds.is_floating_point():
  37. raise ValueError(
  38. "Expected argument `preds` to be floating tensor with probabilities/logits"
  39. f" but got tensor with dtype {preds.dtype}"
  40. )
  41. def _binary_hinge_loss_update(
  42. preds: Tensor,
  43. target: Tensor,
  44. squared: bool,
  45. ) -> tuple[Tensor, Tensor]:
  46. target = target.bool()
  47. margin = torch.zeros_like(preds)
  48. margin[target] = preds[target]
  49. margin[~target] = -preds[~target]
  50. measures = 1 - margin
  51. measures = torch.clamp(measures, 0)
  52. if squared:
  53. measures = measures.pow(2)
  54. total = tensor(target.shape[0], device=target.device)
  55. return measures.sum(dim=0), total
  56. def binary_hinge_loss(
  57. preds: Tensor,
  58. target: Tensor,
  59. squared: bool = False,
  60. ignore_index: Optional[int] = None,
  61. validate_args: bool = False,
  62. ) -> Tensor:
  63. r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks.
  64. .. math::
  65. \text{Hinge loss} = \max(0, 1 - y \times \hat{y})
  66. Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction.
  67. Accepts the following input tensors:
  68. - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
  69. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  70. sigmoid per element.
  71. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  72. only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class.
  73. Additional dimension ``...`` will be flattened into the batch dimension.
  74. Args:
  75. preds: Tensor with predictions
  76. target: Tensor with true labels
  77. squared:
  78. If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
  79. ignore_index:
  80. Specifies a target value that is ignored and does not contribute to the metric calculation
  81. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  82. Set to ``False`` for faster computations.
  83. Example:
  84. >>> from torch import tensor
  85. >>> from torchmetrics.functional.classification import binary_hinge_loss
  86. >>> preds = tensor([0.25, 0.25, 0.55, 0.75, 0.75])
  87. >>> target = tensor([0, 0, 1, 1, 1])
  88. >>> binary_hinge_loss(preds, target)
  89. tensor(0.6900)
  90. >>> binary_hinge_loss(preds, target, squared=True)
  91. tensor(0.6905)
  92. """
  93. if validate_args:
  94. _binary_hinge_loss_arg_validation(squared, ignore_index)
  95. _binary_hinge_loss_tensor_validation(preds, target, ignore_index)
  96. preds, target = _binary_confusion_matrix_format(
  97. preds, target, threshold=0.0, ignore_index=ignore_index, convert_to_labels=False
  98. )
  99. measures, total = _binary_hinge_loss_update(preds, target, squared)
  100. return _hinge_loss_compute(measures, total)
  101. def _multiclass_hinge_loss_arg_validation(
  102. num_classes: int,
  103. squared: bool = False,
  104. multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
  105. ignore_index: Optional[int] = None,
  106. ) -> None:
  107. _binary_hinge_loss_arg_validation(squared, ignore_index)
  108. if not isinstance(num_classes, int) or num_classes < 2:
  109. raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
  110. allowed_mm = ("crammer-singer", "one-vs-all")
  111. if multiclass_mode not in allowed_mm:
  112. raise ValueError(f"Expected argument `multiclass_mode` to be one of {allowed_mm}, but got {multiclass_mode}.")
  113. def _multiclass_hinge_loss_tensor_validation(
  114. preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None
  115. ) -> None:
  116. _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
  117. if not preds.is_floating_point():
  118. raise ValueError(
  119. "Expected argument `preds` to be floating tensor with probabilities/logits"
  120. f" but got tensor with dtype {preds.dtype}"
  121. )
  122. def _multiclass_hinge_loss_update(
  123. preds: Tensor,
  124. target: Tensor,
  125. squared: bool,
  126. multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
  127. ) -> tuple[Tensor, Tensor]:
  128. preds = normalize_logits_if_needed(preds, "softmax")
  129. target = to_onehot(target, max(2, preds.shape[1])).bool()
  130. if multiclass_mode == "crammer-singer":
  131. margin = preds[target]
  132. margin -= torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0]
  133. else:
  134. target = target.bool()
  135. margin = torch.zeros_like(preds)
  136. margin[target] = preds[target]
  137. margin[~target] = -preds[~target]
  138. measures = 1 - margin
  139. measures = torch.clamp(measures, 0)
  140. if squared:
  141. measures = measures.pow(2)
  142. total = tensor(target.shape[0], device=target.device)
  143. return measures.sum(dim=0), total
  144. def multiclass_hinge_loss(
  145. preds: Tensor,
  146. target: Tensor,
  147. num_classes: int,
  148. squared: bool = False,
  149. multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
  150. ignore_index: Optional[int] = None,
  151. validate_args: bool = False,
  152. ) -> Tensor:
  153. r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks.
  154. The metric can be computed in two ways. Either, the definition by Crammer and Singer is used:
  155. .. math::
  156. \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)
  157. Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes),
  158. and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the metric can
  159. also be computed in one-vs-all approach, where each class is valued against all other classes in a binary fashion.
  160. Accepts the following input tensors:
  161. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  162. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  163. softmax per sample.
  164. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  165. only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  166. Additional dimension ``...`` will be flattened into the batch dimension.
  167. Args:
  168. preds: Tensor with predictions
  169. target: Tensor with true labels
  170. num_classes: Integer specifying the number of classes
  171. squared:
  172. If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
  173. multiclass_mode:
  174. Determines how to compute the metric
  175. ignore_index:
  176. Specifies a target value that is ignored and does not contribute to the metric calculation
  177. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  178. Set to ``False`` for faster computations.
  179. Example:
  180. >>> from torch import tensor
  181. >>> from torchmetrics.functional.classification import multiclass_hinge_loss
  182. >>> preds = tensor([[0.25, 0.20, 0.55],
  183. ... [0.55, 0.05, 0.40],
  184. ... [0.10, 0.30, 0.60],
  185. ... [0.90, 0.05, 0.05]])
  186. >>> target = tensor([0, 1, 2, 0])
  187. >>> multiclass_hinge_loss(preds, target, num_classes=3)
  188. tensor(0.9125)
  189. >>> multiclass_hinge_loss(preds, target, num_classes=3, squared=True)
  190. tensor(1.1131)
  191. >>> multiclass_hinge_loss(preds, target, num_classes=3, multiclass_mode='one-vs-all')
  192. tensor([0.8750, 1.1250, 1.1000])
  193. """
  194. if validate_args:
  195. _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index)
  196. _multiclass_hinge_loss_tensor_validation(preds, target, num_classes, ignore_index)
  197. preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels=False)
  198. measures, total = _multiclass_hinge_loss_update(preds, target, squared, multiclass_mode)
  199. return _hinge_loss_compute(measures, total)
  200. def hinge_loss(
  201. preds: Tensor,
  202. target: Tensor,
  203. task: Literal["binary", "multiclass"],
  204. num_classes: Optional[int] = None,
  205. squared: bool = False,
  206. multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
  207. ignore_index: Optional[int] = None,
  208. validate_args: bool = True,
  209. ) -> Tensor:
  210. r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs).
  211. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  212. ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of
  213. :func:`~torchmetrics.functional.classification.binary_hinge_loss` and
  214. :func:`~torchmetrics.functional.classification.multiclass_hinge_loss` for the specific details of
  215. each argument influence and examples.
  216. Legacy Example:
  217. >>> from torch import tensor
  218. >>> target = tensor([0, 1, 1])
  219. >>> preds = tensor([0.5, 0.7, 0.1])
  220. >>> hinge_loss(preds, target, task="binary")
  221. tensor(0.9000)
  222. >>> target = tensor([0, 1, 2])
  223. >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
  224. >>> hinge_loss(preds, target, task="multiclass", num_classes=3)
  225. tensor(1.5551)
  226. >>> target = tensor([0, 1, 2])
  227. >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
  228. >>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
  229. tensor([1.3743, 1.1945, 1.2359])
  230. """
  231. task = ClassificationTaskNoMultilabel.from_str(task)
  232. if task == ClassificationTaskNoMultilabel.BINARY:
  233. return binary_hinge_loss(preds, target, squared, ignore_index, validate_args)
  234. if task == ClassificationTaskNoMultilabel.MULTICLASS:
  235. if not isinstance(num_classes, int):
  236. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  237. return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args)
  238. raise ValueError(f"Not handled value: {task}")