ranking.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from torchmetrics.functional.classification.ranking import (
  19. _multilabel_confusion_matrix_arg_validation,
  20. _multilabel_confusion_matrix_format,
  21. _multilabel_coverage_error_update,
  22. _multilabel_ranking_average_precision_update,
  23. _multilabel_ranking_loss_update,
  24. _multilabel_ranking_tensor_validation,
  25. _ranking_reduce,
  26. )
  27. from torchmetrics.metric import Metric
  28. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  29. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  30. if not _MATPLOTLIB_AVAILABLE:
  31. __doctest_skip__ = [
  32. "MultilabelCoverageError.plot",
  33. "MultilabelRankingAveragePrecision.plot",
  34. "MultilabelRankingLoss.plot",
  35. ]
  36. class MultilabelCoverageError(Metric):
  37. """Compute `Multilabel coverage error`_.
  38. The score measure how far we need to go through the ranked scores to cover all true labels. The best value is equal
  39. to the average number of labels in the target tensor per sample.
  40. As input to ``forward`` and ``update`` the metric accepts the following input:
  41. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
  42. containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
  43. the input to be logits and will auto apply sigmoid per element.
  44. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor
  45. containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified).
  46. .. tip::
  47. Additional dimension ``...`` will be flattened into the batch dimension.
  48. As output to ``forward`` and ``compute`` the metric returns the following output:
  49. - ``mlce`` (:class:`~torch.Tensor`): A tensor containing the multilabel coverage error.
  50. Args:
  51. num_labels: Integer specifying the number of labels
  52. ignore_index:
  53. Specifies a target value that is ignored and does not contribute to the metric calculation
  54. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  55. Set to ``False`` for faster computations.
  56. Example:
  57. >>> from torch import rand, randint
  58. >>> from torchmetrics.classification import MultilabelCoverageError
  59. >>> preds = rand(10, 5)
  60. >>> target = randint(2, (10, 5))
  61. >>> mlce = MultilabelCoverageError(num_labels=5)
  62. >>> mlce(preds, target)
  63. tensor(3.9000)
  64. """
  65. higher_is_better: bool = False
  66. is_differentiable: bool = False
  67. full_state_update: bool = False
  68. plot_lower_bound: float = 0.0
  69. plot_upper_bound: float = 1.0
  70. plot_legend_name: str = "Label"
  71. def __init__(
  72. self,
  73. num_labels: int,
  74. ignore_index: Optional[int] = None,
  75. validate_args: bool = True,
  76. **kwargs: Any,
  77. ) -> None:
  78. super().__init__(**kwargs)
  79. if validate_args:
  80. _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index)
  81. self.validate_args = validate_args
  82. self.num_labels = num_labels
  83. self.ignore_index = ignore_index
  84. self.add_state("measure", torch.tensor(0.0), dist_reduce_fx="sum")
  85. self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")
  86. def update(self, preds: Tensor, target: Tensor) -> None:
  87. """Update metric states."""
  88. if self.validate_args:
  89. _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index)
  90. preds, target = _multilabel_confusion_matrix_format(
  91. preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False
  92. )
  93. measure, num_elements = _multilabel_coverage_error_update(preds, target)
  94. if not isinstance(self.measure, Tensor):
  95. raise TypeError(f"Expected 'self.measure' to be of type Tensor, but got {type(self.measure)}.")
  96. if not isinstance(self.total, Tensor):
  97. raise TypeError(f"Expected 'self.total' to be of type Tensor, but got {type(self.total)}.")
  98. self.measure += measure
  99. self.total += num_elements
  100. def compute(self) -> Tensor:
  101. """Compute metric."""
  102. if not isinstance(self.measure, Tensor):
  103. raise TypeError(f"Expected 'self.measure' to be of type Tensor, but got {type(self.measure)}.")
  104. if not isinstance(self.total, Tensor):
  105. raise TypeError(f"Expected 'self.total' to be of type Tensor, but got {type(self.total)}.")
  106. return _ranking_reduce(self.measure, int(self.total.item()))
  107. def plot(
  108. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  109. ) -> _PLOT_OUT_TYPE:
  110. """Plot a single or multiple values from the metric.
  111. Args:
  112. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  113. If no value is provided, will automatically call `metric.compute` and plot that result.
  114. ax: An matplotlib axis object. If provided will add plot to that axis
  115. Returns:
  116. Figure object and Axes object
  117. Raises:
  118. ModuleNotFoundError:
  119. If `matplotlib` is not installed
  120. .. plot::
  121. :scale: 75
  122. >>> from torch import rand, randint
  123. >>> # Example plotting a single value
  124. >>> from torchmetrics.classification import MultilabelCoverageError
  125. >>> metric = MultilabelCoverageError(num_labels=3)
  126. >>> metric.update(rand(20, 3), randint(2, (20, 3)))
  127. >>> fig_, ax_ = metric.plot()
  128. .. plot::
  129. :scale: 75
  130. >>> from torch import rand, randint
  131. >>> # Example plotting multiple values
  132. >>> from torchmetrics.classification import MultilabelCoverageError
  133. >>> metric = MultilabelCoverageError(num_labels=3)
  134. >>> values = [ ]
  135. >>> for _ in range(10):
  136. ... values.append(metric(rand(20, 3), randint(2, (20, 3))))
  137. >>> fig_, ax_ = metric.plot(values)
  138. """
  139. return self._plot(val, ax)
  140. class MultilabelRankingAveragePrecision(Metric):
  141. """Compute label ranking average precision score for multilabel data [1].
  142. The score is the average over each ground truth label assigned to each sample of the ratio of true vs. total labels
  143. with lower score. Best score is 1.
  144. As input to ``forward`` and ``update`` the metric accepts the following input:
  145. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
  146. containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
  147. the input to be logits and will auto apply sigmoid per element.
  148. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor
  149. containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified).
  150. .. tip::
  151. Additional dimension ``...`` will be flattened into the batch dimension.
  152. As output to ``forward`` and ``compute`` the metric returns the following output:
  153. - ``mlrap`` (:class:`~torch.Tensor`): A tensor containing the multilabel ranking average precision.
  154. Args:
  155. num_labels: Integer specifying the number of labels
  156. ignore_index:
  157. Specifies a target value that is ignored and does not contribute to the metric calculation
  158. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  159. Set to ``False`` for faster computations.
  160. Example:
  161. >>> from torch import rand, randint
  162. >>> from torchmetrics.classification import MultilabelRankingAveragePrecision
  163. >>> preds = rand(10, 5)
  164. >>> target = randint(2, (10, 5))
  165. >>> mlrap = MultilabelRankingAveragePrecision(num_labels=5)
  166. >>> mlrap(preds, target)
  167. tensor(0.7744)
  168. """
  169. higher_is_better: bool = True
  170. is_differentiable: bool = False
  171. full_state_update: bool = False
  172. plot_lower_bound: float = 0.0
  173. plot_upper_bound: float = 1.0
  174. plot_legend_name: str = "Label"
  175. def __init__(
  176. self,
  177. num_labels: int,
  178. ignore_index: Optional[int] = None,
  179. validate_args: bool = True,
  180. **kwargs: Any,
  181. ) -> None:
  182. super().__init__(**kwargs)
  183. if validate_args:
  184. _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index)
  185. self.validate_args = validate_args
  186. self.num_labels = num_labels
  187. self.ignore_index = ignore_index
  188. self.add_state("measure", torch.tensor(0.0), dist_reduce_fx="sum")
  189. self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")
  190. def update(self, preds: Tensor, target: Tensor) -> None:
  191. """Update metric states."""
  192. if self.validate_args:
  193. _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index)
  194. preds, target = _multilabel_confusion_matrix_format(
  195. preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False
  196. )
  197. if not isinstance(self.measure, Tensor):
  198. raise TypeError(f"Expected 'self.measure' to be of type Tensor, but got {type(self.measure)}.")
  199. if not isinstance(self.total, Tensor):
  200. raise TypeError(f"Expected 'self.total' to be of type Tensor, but got {type(self.total)}.")
  201. measure, num_elements = _multilabel_ranking_average_precision_update(preds, target)
  202. self.measure += measure
  203. self.total += num_elements
  204. def compute(self) -> Tensor:
  205. """Compute metric."""
  206. if not isinstance(self.measure, Tensor):
  207. raise TypeError(f"Expected 'self.measure' to be of type Tensor, but got {type(self.measure)}.")
  208. if not isinstance(self.total, Tensor):
  209. raise TypeError(f"Expected 'self.total' to be of type Tensor, but got {type(self.total)}.")
  210. return _ranking_reduce(self.measure, int(self.total.item()))
  211. def plot(
  212. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  213. ) -> _PLOT_OUT_TYPE:
  214. """Plot a single or multiple values from the metric.
  215. Args:
  216. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  217. If no value is provided, will automatically call `metric.compute` and plot that result.
  218. ax: An matplotlib axis object. If provided will add plot to that axis
  219. Returns:
  220. Figure object and Axes object
  221. Raises:
  222. ModuleNotFoundError:
  223. If `matplotlib` is not installed
  224. .. plot::
  225. :scale: 75
  226. >>> from torch import rand, randint
  227. >>> # Example plotting a single value
  228. >>> from torchmetrics.classification import MultilabelRankingAveragePrecision
  229. >>> metric = MultilabelRankingAveragePrecision(num_labels=3)
  230. >>> metric.update(rand(20, 3), randint(2, (20, 3)))
  231. >>> fig_, ax_ = metric.plot()
  232. .. plot::
  233. :scale: 75
  234. >>> from torch import rand, randint
  235. >>> # Example plotting multiple values
  236. >>> from torchmetrics.classification import MultilabelRankingAveragePrecision
  237. >>> metric = MultilabelRankingAveragePrecision(num_labels=3)
  238. >>> values = [ ]
  239. >>> for _ in range(10):
  240. ... values.append(metric(rand(20, 3), randint(2, (20, 3))))
  241. >>> fig_, ax_ = metric.plot(values)
  242. """
  243. return self._plot(val, ax)
  244. class MultilabelRankingLoss(Metric):
  245. """Compute the label ranking loss for multilabel data [1].
  246. The score is corresponds to the average number of label pairs that are incorrectly ordered given some predictions
  247. weighted by the size of the label set and the number of labels not in the label set. The best score is 0.
  248. As input to ``forward`` and ``update`` the metric accepts the following input:
  249. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
  250. containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
  251. the input to be logits and will auto apply sigmoid per element.
  252. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor
  253. containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified).
  254. .. tip::
  255. Additional dimension ``...`` will be flattened into the batch dimension.
  256. As output to ``forward`` and ``compute`` the metric returns the following output:
  257. - ``mlrl`` (:class:`~torch.Tensor`): A tensor containing the multilabel ranking loss.
  258. Args:
  259. preds: Tensor with predictions
  260. target: Tensor with true labels
  261. num_labels: Integer specifying the number of labels
  262. ignore_index:
  263. Specifies a target value that is ignored and does not contribute to the metric calculation
  264. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  265. Set to ``False`` for faster computations.
  266. Example:
  267. >>> from torch import rand, randint
  268. >>> from torchmetrics.classification import MultilabelRankingLoss
  269. >>> preds = rand(10, 5)
  270. >>> target = randint(2, (10, 5))
  271. >>> mlrl = MultilabelRankingLoss(num_labels=5)
  272. >>> mlrl(preds, target)
  273. tensor(0.4167)
  274. """
  275. higher_is_better: bool = False
  276. is_differentiable: bool = False
  277. full_state_update: bool = False
  278. plot_lower_bound: float = 0.0
  279. plot_upper_bound: float = 1.0
  280. plot_legend_name: str = "Label"
  281. def __init__(
  282. self,
  283. num_labels: int,
  284. ignore_index: Optional[int] = None,
  285. validate_args: bool = True,
  286. **kwargs: Any,
  287. ) -> None:
  288. super().__init__(**kwargs)
  289. if validate_args:
  290. _multilabel_confusion_matrix_arg_validation(num_labels, threshold=0.0, ignore_index=ignore_index)
  291. self.validate_args = validate_args
  292. self.num_labels = num_labels
  293. self.ignore_index = ignore_index
  294. self.add_state("measure", torch.tensor(0.0), dist_reduce_fx="sum")
  295. self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")
  296. def update(self, preds: Tensor, target: Tensor) -> None:
  297. """Update metric states."""
  298. if self.validate_args:
  299. _multilabel_ranking_tensor_validation(preds, target, self.num_labels, self.ignore_index)
  300. preds, target = _multilabel_confusion_matrix_format(
  301. preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False
  302. )
  303. if not isinstance(self.measure, Tensor):
  304. raise TypeError(f"Expected 'self.measure' to be of type Tensor, but got {type(self.measure)}.")
  305. if not isinstance(self.total, Tensor):
  306. raise TypeError(f"Expected 'self.total' to be of type Tensor, but got {type(self.total)}.")
  307. measure, num_elements = _multilabel_ranking_loss_update(preds, target)
  308. self.measure += measure
  309. self.total += num_elements
  310. def compute(self) -> Tensor:
  311. """Compute metric."""
  312. if not isinstance(self.measure, Tensor):
  313. raise TypeError(f"Expected 'self.measure' to be of type Tensor, but got {type(self.measure)}.")
  314. if not isinstance(self.total, Tensor):
  315. raise TypeError(f"Expected 'self.total' to be of type Tensor, but got {type(self.total)}.")
  316. return _ranking_reduce(self.measure, int(self.total.item()))
  317. def plot(
  318. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  319. ) -> _PLOT_OUT_TYPE:
  320. """Plot a single or multiple values from the metric.
  321. Args:
  322. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  323. If no value is provided, will automatically call `metric.compute` and plot that result.
  324. ax: An matplotlib axis object. If provided will add plot to that axis
  325. Returns:
  326. Figure object and Axes object
  327. Raises:
  328. ModuleNotFoundError:
  329. If `matplotlib` is not installed
  330. .. plot::
  331. :scale: 75
  332. >>> from torch import rand, randint
  333. >>> # Example plotting a single value
  334. >>> from torchmetrics.classification import MultilabelRankingLoss
  335. >>> metric = MultilabelRankingLoss(num_labels=3)
  336. >>> metric.update(rand(20, 3), randint(2, (20, 3)))
  337. >>> fig_, ax_ = metric.plot()
  338. .. plot::
  339. :scale: 75
  340. >>> from torch import rand, randint
  341. >>> # Example plotting multiple values
  342. >>> from torchmetrics.classification import MultilabelRankingLoss
  343. >>> metric = MultilabelRankingLoss(num_labels=3)
  344. >>> values = [ ]
  345. >>> for _ in range(10):
  346. ... values.append(metric(rand(20, 3), randint(2, (20, 3))))
  347. >>> fig_, ax_ = metric.plot(values)
  348. """
  349. return self._plot(val, ax)