logauc.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, List, Optional, Sequence, Tuple, Type, Union
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.classification.base import _ClassificationTaskWrapper
  18. from torchmetrics.classification.roc import BinaryROC, MulticlassROC, MultilabelROC
  19. from torchmetrics.functional.classification.logauc import (
  20. _binary_logauc_compute,
  21. _reduce_logauc,
  22. _validate_fpr_range,
  23. )
  24. from torchmetrics.metric import Metric
  25. from torchmetrics.utilities.enums import ClassificationTask
  26. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  27. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  28. if not _MATPLOTLIB_AVAILABLE:
  29. __doctest_skip__ = ["BinaryLogAUC.plot", "MulticlassLogAUC.plot", "MultilabelLogAUC.plot"]
  30. class BinaryLogAUC(BinaryROC):
  31. r"""Compute the `Log AUC`_ score for binary classification tasks.
  32. The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false
  33. positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The
  34. score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate
  35. is of high importance.
  36. As input to ``forward`` and ``update`` the metric accepts the following input:
  37. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
  38. each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  39. sigmoid per element.
  40. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
  41. therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the
  42. positive class.
  43. As output to ``forward`` and ``compute`` the metric returns the following output:
  44. - ``logauc`` (:class:`~torch.Tensor`): A single scalar with the logauc score.
  45. Additional dimension ``...`` will be flattened into the batch dimension.
  46. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  47. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  48. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  49. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  50. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  51. Args:
  52. fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log
  53. AUC score.
  54. thresholds:
  55. Can be one of:
  56. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  57. all the data. Most accurate but also most memory consuming approach.
  58. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  59. 0 to 1 as bins for the calculation.
  60. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  61. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  62. bins for the calculation.
  63. ignore_index:
  64. Specifies a target value that is ignored and does not contribute to the metric calculation
  65. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  66. Set to ``False`` for faster computations.
  67. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  68. Example:
  69. >>> from torch import tensor
  70. >>> from torchmetrics.classification import BinaryLogAUC
  71. >>> preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05])
  72. >>> target = tensor([1, 0, 0, 0, 0])
  73. >>> metric = BinaryLogAUC()
  74. >>> metric(preds, target)
  75. tensor(1.)
  76. """
  77. is_differentiable: bool = False
  78. higher_is_better: bool = True
  79. full_state_update: bool = False
  80. plot_lower_bound: float = 0.0
  81. plot_upper_bound: float = 1.0
  82. def __init__(
  83. self,
  84. fpr_range: Tuple[float, float] = (0.001, 0.1),
  85. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  86. ignore_index: Optional[int] = None,
  87. validate_args: bool = False,
  88. **kwargs: Any,
  89. ) -> None:
  90. super().__init__(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args, **kwargs)
  91. if validate_args:
  92. _validate_fpr_range(fpr_range)
  93. self.fpr_range = fpr_range
  94. def compute(self) -> Tensor: # type: ignore[override]
  95. """Computes the log AUC score."""
  96. fpr, tpr, _ = super().compute()
  97. return _binary_logauc_compute(fpr, tpr, fpr_range=self.fpr_range)
  98. def plot( # type: ignore[override]
  99. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  100. ) -> _PLOT_OUT_TYPE:
  101. """Plot a single or multiple values from the metric.
  102. Args:
  103. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  104. If no value is provided, will automatically call `metric.compute` and plot that result.
  105. ax: An matplotlib axis object. If provided will add plot to that axis
  106. Returns:
  107. Figure and Axes object
  108. Raises:
  109. ModuleNotFoundError:
  110. If `matplotlib` is not installed
  111. .. plot::
  112. :scale: 75
  113. >>> # Example plotting a single
  114. >>> import torch
  115. >>> from torchmetrics.classification import BinaryLogAUC
  116. >>> metric = BinaryLogAUC()
  117. >>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
  118. >>> fig_, ax_ = metric.plot()
  119. .. plot::
  120. :scale: 75
  121. >>> # Example plotting multiple values
  122. >>> import torch
  123. >>> from torchmetrics.classification import BinaryLogAUC
  124. >>> metric = BinaryLogAUC()
  125. >>> values = [ ]
  126. >>> for _ in range(10):
  127. ... values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
  128. >>> fig_, ax_ = metric.plot(values)
  129. """
  130. return self._plot(val, ax)
  131. class MulticlassLogAUC(MulticlassROC):
  132. r"""Compute the `Log AUC`_ score for multiclass classification tasks.
  133. The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false
  134. positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The
  135. score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate
  136. is of high importance.
  137. As input to ``forward`` and ``update`` the metric accepts the following input:
  138. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
  139. for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto
  140. apply softmax per sample.
  141. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
  142. therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  143. As output to ``forward`` and ``compute`` the metric returns the following output:
  144. - ``logauc`` (:class:`~torch.Tensor`): If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will
  145. be returned with logauc score per class. If `average="macro"` then a single scalar is returned.
  146. Additional dimension ``...`` will be flattened into the batch dimension.
  147. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  148. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  149. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  150. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  151. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  152. Args:
  153. num_classes: Integer specifying the number of classes
  154. fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log
  155. AUC score.
  156. average:
  157. Defines the reduction that is applied over classes. Should be one of the following:
  158. - ``"macro"``: Calculate score for each class and average them
  159. - ``"weighted"``: calculates score for each class and computes weighted average using their support
  160. - ``"none"`` or ``None``: calculates score for each class and applies no reduction
  161. thresholds:
  162. Can be one of:
  163. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  164. all the data. Most accurate but also most memory consuming approach.
  165. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  166. 0 to 1 as bins for the calculation.
  167. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  168. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  169. bins for the calculation.
  170. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  171. Set to ``False`` for faster computations.
  172. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  173. Example:
  174. >>> from torch import tensor
  175. >>> from torchmetrics.classification import MulticlassLogAUC
  176. >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  177. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  178. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  179. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  180. >>> target = tensor([0, 1, 3, 2])
  181. >>> metric = MulticlassLogAUC(num_classes=5, average="macro", thresholds=None)
  182. >>> metric(preds, target)
  183. tensor(0.4000)
  184. >>> metric = MulticlassLogAUC(num_classes=5, average=None, thresholds=None)
  185. >>> metric(preds, target)
  186. tensor([1., 1., 0., 0., 0.])
  187. """
  188. is_differentiable: bool = False
  189. higher_is_better: bool = True
  190. full_state_update: bool = False
  191. plot_lower_bound: float = 0.0
  192. plot_upper_bound: float = 1.0
  193. plot_legend_name: str = "Class"
  194. def __init__(
  195. self,
  196. num_classes: int,
  197. fpr_range: Tuple[float, float] = (0.001, 0.1),
  198. average: Optional[Literal["macro", "none"]] = None,
  199. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  200. ignore_index: Optional[int] = None,
  201. validate_args: bool = True,
  202. **kwargs: Any,
  203. ) -> None:
  204. super().__init__(
  205. num_classes=num_classes,
  206. thresholds=thresholds,
  207. average=None,
  208. ignore_index=ignore_index,
  209. validate_args=validate_args,
  210. **kwargs,
  211. )
  212. if validate_args:
  213. _validate_fpr_range(fpr_range)
  214. self.fpr_range = fpr_range
  215. self.average2 = average # self.average is already used by parent class
  216. def compute(self) -> Tensor: # type: ignore[override]
  217. """Computes the log AUC score."""
  218. fpr, tpr, _ = super().compute()
  219. return _reduce_logauc(fpr, tpr, fpr_range=self.fpr_range, average=self.average2)
  220. def plot( # type: ignore[override]
  221. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  222. ) -> _PLOT_OUT_TYPE:
  223. """Plot a single or multiple values from the metric.
  224. Args:
  225. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  226. If no value is provided, will automatically call `metric.compute` and plot that result.
  227. ax: An matplotlib axis object. If provided will add plot to that axis
  228. Returns:
  229. Figure and Axes object
  230. Raises:
  231. ModuleNotFoundError:
  232. If `matplotlib` is not installed
  233. .. plot::
  234. :scale: 75
  235. >>> # Example plotting a single
  236. >>> import torch
  237. >>> from torchmetrics.classification import MulticlassLogAUC
  238. >>> metric = MulticlassLogAUC(num_classes=3)
  239. >>> metric.update(torch.randn(20, 3), torch.randint(3,(20,)))
  240. >>> fig_, ax_ = metric.plot()
  241. .. plot::
  242. :scale: 75
  243. >>> # Example plotting multiple values
  244. >>> import torch
  245. >>> from torchmetrics.classification import MulticlassLogAUC
  246. >>> metric = MulticlassLogAUC(num_classes=3)
  247. >>> values = [ ]
  248. >>> for _ in range(10):
  249. ... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
  250. >>> fig_, ax_ = metric.plot(values)
  251. """
  252. return self._plot(val, ax)
  253. class MultilabelLogAUC(MultilabelROC):
  254. r"""Compute the `Log AUC`_ score for multiclass classification tasks.
  255. The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false
  256. positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The
  257. score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate
  258. is of high importance.
  259. As input to ``forward`` and ``update`` the metric accepts the following input:
  260. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
  261. for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto
  262. apply sigmoid per element.
  263. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` containing ground truth labels, and
  264. therefore only contain {0,1} values (except if `ignore_index` is specified).
  265. As output to ``forward`` and ``compute`` the metric returns the following output:
  266. - ``logauc`` (:class:`~torch.Tensor`): If `average=None|"none"` then a 1d tensor of shape (num_labels, ) will
  267. be returned with logauc score per class. If `average="macro"` then a single scalar is returned.
  268. Additional dimension ``...`` will be flattened into the batch dimension.
  269. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  270. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  271. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  272. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  273. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  274. Args:
  275. num_labels: Integer specifying the number of labels
  276. fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log
  277. AUC score.
  278. average:
  279. Defines the reduction that is applied over labels. Should be one of the following:
  280. - ``"macro"``: Calculate the score for each label and average them
  281. - ``"none"`` or ``None``: calculates score for each label and applies no reduction
  282. thresholds:
  283. Can be one of:
  284. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  285. all the data. Most accurate but also most memory consuming approach.
  286. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  287. 0 to 1 as bins for the calculation.
  288. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  289. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  290. bins for the calculation.
  291. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  292. Set to ``False`` for faster computations.
  293. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  294. Example:
  295. >>> from torch import tensor
  296. >>> from torchmetrics.classification import MultilabelLogAUC
  297. >>> preds = tensor([[0.75, 0.05, 0.35],
  298. ... [0.45, 0.75, 0.05],
  299. ... [0.05, 0.55, 0.75],
  300. ... [0.05, 0.65, 0.05]])
  301. >>> target = tensor([[1, 0, 1],
  302. ... [0, 0, 0],
  303. ... [0, 1, 1],
  304. ... [1, 1, 1]])
  305. >>> metric = MultilabelLogAUC(num_labels=3, average="macro", thresholds=None)
  306. >>> metric(preds, target)
  307. tensor(0.3945)
  308. >>> metric = MultilabelLogAUC(num_labels=3, average=None, thresholds=None)
  309. >>> metric(preds, target)
  310. tensor([0.5000, 0.0000, 0.6835])
  311. """
  312. is_differentiable: bool = False
  313. higher_is_better: bool = True
  314. full_state_update: bool = False
  315. plot_lower_bound: float = 0.0
  316. plot_upper_bound: float = 1.0
  317. plot_legend_name: str = "Label"
  318. def __init__(
  319. self,
  320. num_labels: int,
  321. fpr_range: Tuple[float, float] = (0.001, 0.1),
  322. average: Optional[Literal["macro", "none"]] = None,
  323. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  324. ignore_index: Optional[int] = None,
  325. validate_args: bool = True,
  326. **kwargs: Any,
  327. ) -> None:
  328. if validate_args:
  329. _validate_fpr_range(fpr_range)
  330. self.fpr_range = fpr_range
  331. self.average2 = average # self.average is already used by parent class
  332. super().__init__(
  333. num_labels=num_labels,
  334. thresholds=thresholds,
  335. ignore_index=ignore_index,
  336. validate_args=validate_args,
  337. **kwargs,
  338. )
  339. def compute(self) -> Tensor: # type: ignore[override]
  340. """Computes the log AUC score."""
  341. fpr, tpr, _ = super().compute()
  342. return _reduce_logauc(fpr, tpr, fpr_range=self.fpr_range, average=self.average2)
  343. def plot( # type: ignore[override]
  344. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  345. ) -> _PLOT_OUT_TYPE:
  346. """Plot a single or multiple values from the metric.
  347. Args:
  348. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  349. If no value is provided, will automatically call `metric.compute` and plot that result.
  350. ax: An matplotlib axis object. If provided will add plot to that axis
  351. Returns:
  352. Figure and Axes object
  353. Raises:
  354. ModuleNotFoundError:
  355. If `matplotlib` is not installed
  356. .. plot::
  357. :scale: 75
  358. >>> # Example plotting a single
  359. >>> import torch
  360. >>> from torchmetrics.classification import MultilabelLogAUC
  361. >>> metric = MultilabelLogAUC(num_labels=3)
  362. >>> metric.update(torch.rand(20,3), torch.randint(2, (20,3)))
  363. >>> fig_, ax_ = metric.plot()
  364. .. plot::
  365. :scale: 75
  366. >>> # Example plotting multiple values
  367. >>> import torch
  368. >>> from torchmetrics.classification import MultilabelLogAUC
  369. >>> metric = MultilabelLogAUC(num_labels=3)
  370. >>> values = [ ]
  371. >>> for _ in range(10):
  372. ... values.append(metric(torch.rand(20,3), torch.randint(2, (20,3))))
  373. >>> fig_, ax_ = metric.plot(values)
  374. """
  375. return self._plot(val, ax)
  376. class LogAUC(_ClassificationTaskWrapper):
  377. r"""Compute the `Log AUC`_ score for multiclass classification tasks.
  378. The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false
  379. positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The
  380. score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate
  381. is of high importance.
  382. This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  383. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  384. :class:`~torchmetrics.classification.BinaryLogAUC`, :class:`~torchmetrics.classification.MulticlassLogAUC` and
  385. :class:`~torchmetrics.classification.MultilabelLogAUC` for the specific details of each argument influence and
  386. examples.
  387. """
  388. def __new__( # type: ignore[misc]
  389. cls: Type["LogAUC"],
  390. task: Literal["binary", "multiclass", "multilabel"],
  391. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  392. fpr_range: Optional[Tuple[float, float]] = (0.001, 0.1),
  393. num_classes: Optional[int] = None,
  394. num_labels: Optional[int] = None,
  395. ignore_index: Optional[int] = None,
  396. validate_args: bool = True,
  397. **kwargs: Any,
  398. ) -> Metric:
  399. """Initialize task metric."""
  400. task = ClassificationTask.from_str(task)
  401. kwargs.update({
  402. "thresholds": thresholds,
  403. "fpr_range": fpr_range,
  404. "ignore_index": ignore_index,
  405. "validate_args": validate_args,
  406. })
  407. if task == ClassificationTask.BINARY:
  408. return BinaryLogAUC(**kwargs)
  409. if task == ClassificationTask.MULTICLASS:
  410. if not isinstance(num_classes, int):
  411. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  412. return MulticlassLogAUC(num_classes, **kwargs)
  413. if task == ClassificationTask.MULTILABEL:
  414. if not isinstance(num_labels, int):
  415. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  416. return MultilabelLogAUC(num_labels, **kwargs)
  417. raise ValueError(f"Task {task} not supported!")