logauc.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional, Tuple, Union
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc
  19. from torchmetrics.utilities import rank_zero_warn
  20. from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide
  21. from torchmetrics.utilities.data import interp
  22. from torchmetrics.utilities.enums import ClassificationTask
  23. def _validate_fpr_range(fpr_range: Tuple[float, float]) -> None:
  24. """Validate the `fpr_range` argument for the logauc metric."""
  25. if not isinstance(fpr_range, tuple) and not len(fpr_range) == 2:
  26. raise ValueError(f"The `fpr_range` should be a tuple of two floats, but got {type(fpr_range)}.")
  27. if not (0 <= fpr_range[0] < fpr_range[1] <= 1):
  28. raise ValueError(f"The `fpr_range` should be a tuple of two floats in the range [0, 1], but got {fpr_range}.")
  29. def _binary_logauc_compute(
  30. fpr: Tensor,
  31. tpr: Tensor,
  32. fpr_range: Tuple[float, float] = (0.001, 0.1),
  33. ) -> Tensor:
  34. """Compute the logauc score for binary classification tasks."""
  35. fpr_range = torch.tensor(fpr_range).to(fpr.device)
  36. if fpr.numel() < 2 or tpr.numel() < 2:
  37. rank_zero_warn(
  38. "At least two values on for the fpr and tpr are required to compute the log AUC. Returns 0 score."
  39. )
  40. return torch.tensor(0.0, device=fpr.device)
  41. tpr = torch.cat([tpr, interp(fpr_range, fpr, tpr)]).sort().values
  42. fpr = torch.cat([fpr, fpr_range]).sort().values
  43. log_fpr = torch.log10(fpr)
  44. bounds = torch.log10(fpr_range.detach().clone())
  45. lower_bound_idx = torch.where(log_fpr == bounds[0])[0][-1]
  46. upper_bound_idx = torch.where(log_fpr == bounds[1])[0][-1]
  47. trimmed_log_fpr = log_fpr[lower_bound_idx : upper_bound_idx + 1]
  48. trimmed_tpr = tpr[lower_bound_idx : upper_bound_idx + 1]
  49. # compute area and rescale it to the range of fpr
  50. return _auc_compute_without_check(trimmed_log_fpr, trimmed_tpr, 1.0) / (bounds[1] - bounds[0])
  51. def _reduce_logauc(
  52. fpr: Union[Tensor, List[Tensor]],
  53. tpr: Union[Tensor, List[Tensor]],
  54. fpr_range: Tuple[float, float] = (0.001, 0.1),
  55. average: Optional[Literal["macro", "weighted", "none"]] = "macro",
  56. weights: Optional[Tensor] = None,
  57. ) -> Tensor:
  58. """Reduce the logauc score to a single value for multiclass and multilabel classification tasks."""
  59. scores = []
  60. for fpr_i, tpr_i in zip(fpr, tpr):
  61. scores.append(_binary_logauc_compute(fpr_i, tpr_i, fpr_range))
  62. scores = torch.stack(scores)
  63. if torch.isnan(scores).any():
  64. rank_zero_warn(
  65. "LogAUC score for one or more classes/labels was `nan`. Ignoring these classes in {average}-average."
  66. )
  67. idx = ~torch.isnan(scores)
  68. if average is None or average == "none":
  69. return scores
  70. if average == "macro":
  71. return scores[idx].mean()
  72. if average == "weighted" and weights is not None:
  73. weights = _safe_divide(weights[idx], weights[idx].sum())
  74. return (scores[idx] * weights).sum()
  75. raise ValueError(f"Got unknown average parameter: {average}. Please choose one of ['macro', 'weighted', 'none'].")
  76. def binary_logauc(
  77. preds: Tensor,
  78. target: Tensor,
  79. fpr_range: Tuple[float, float] = (0.001, 0.1),
  80. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  81. ignore_index: Optional[int] = None,
  82. validate_args: bool = True,
  83. ) -> Tensor:
  84. r"""Compute the `Log AUC`_ score for binary classification tasks.
  85. The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false
  86. positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The
  87. score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate
  88. is of high importance.
  89. Accepts the following input tensors:
  90. - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
  91. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  92. sigmoid per element.
  93. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  94. only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class.
  95. Additional dimension ``...`` will be flattened into the batch dimension.
  96. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  97. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  98. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  99. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  100. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  101. Args:
  102. preds: Tensor with predictions
  103. target: Tensor with ground truth labels
  104. fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log
  105. AUC score.
  106. thresholds:
  107. Can be one of:
  108. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  109. all the data. Most accurate but also most memory consuming approach.
  110. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  111. 0 to 1 as bins for the calculation.
  112. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  113. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  114. bins for the calculation.
  115. ignore_index:
  116. Specifies a target value that is ignored and does not contribute to the metric calculation
  117. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  118. Set to ``False`` for faster computations.
  119. Returns:
  120. A single scalar with the log auc score
  121. Example:
  122. >>> from torchmetrics.functional.classification import binary_logauc
  123. >>> from torch import tensor
  124. >>> preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05])
  125. >>> target = tensor([1, 0, 0, 0, 0])
  126. >>> binary_logauc(preds, target)
  127. tensor(1.)
  128. """
  129. _validate_fpr_range(fpr_range)
  130. fpr, tpr, _ = binary_roc(preds, target, thresholds, ignore_index, validate_args)
  131. return _binary_logauc_compute(fpr, tpr, fpr_range)
  132. def multiclass_logauc(
  133. preds: Tensor,
  134. target: Tensor,
  135. num_classes: int,
  136. fpr_range: Tuple[float, float] = (0.001, 0.1),
  137. average: Optional[Literal["macro", "none"]] = "macro",
  138. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  139. ignore_index: Optional[int] = None,
  140. validate_args: bool = True,
  141. ) -> Tensor:
  142. r"""Compute the `Log AUC`_ score for multiclass classification tasks.
  143. The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false
  144. positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The
  145. score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate
  146. is of high importance.
  147. Accepts the following input tensors:
  148. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  149. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  150. softmax per sample.
  151. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  152. only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  153. Additional dimension ``...`` will be flattened into the batch dimension.
  154. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  155. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  156. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  157. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  158. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  159. Args:
  160. preds: Tensor with predictions
  161. target: Tensor with true labels
  162. num_classes: Integer specifying the number of classes
  163. fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log
  164. AUC score.
  165. thresholds:
  166. Can be one of:
  167. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  168. all the data. Most accurate but also most memory consuming approach.
  169. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  170. 0 to 1 as bins for the calculation.
  171. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  172. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  173. bins for the calculation.
  174. average:
  175. Defines the reduction that is applied over classes. Should be one of the following:
  176. - ``macro``: Calculate score for each class and average them
  177. - ``"none"`` or ``None``: calculates score for each class and applies no reduction
  178. ignore_index:
  179. Specifies a target value that is ignored and does not contribute to the metric calculation
  180. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  181. Set to ``False`` for faster computations.
  182. Example:
  183. >>> from torchmetrics.functional.classification import multiclass_logauc
  184. >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  185. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  186. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  187. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  188. >>> target = torch.tensor([0, 1, 3, 2])
  189. >>> multiclass_logauc(preds, target, num_classes=5, average="macro", thresholds=None)
  190. tensor(0.4000)
  191. >>> multiclass_logauc(preds, target, num_classes=5, average=None, thresholds=None)
  192. tensor([1., 1., 0., 0., 0.])
  193. """
  194. if validate_args:
  195. _validate_fpr_range(fpr_range)
  196. fpr, tpr, _ = multiclass_roc(
  197. preds, target, num_classes, thresholds, average=None, ignore_index=ignore_index, validate_args=validate_args
  198. )
  199. return _reduce_logauc(fpr, tpr, fpr_range, average)
  200. def multilabel_logauc(
  201. preds: Tensor,
  202. target: Tensor,
  203. num_labels: int,
  204. fpr_range: Tuple[float, float] = (0.001, 0.1),
  205. average: Optional[Literal["macro", "none"]] = "macro",
  206. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  207. ignore_index: Optional[int] = None,
  208. validate_args: bool = True,
  209. ) -> Tensor:
  210. r"""Compute the `Log AUC`_ score for multilabel classification tasks.
  211. The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false
  212. positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The
  213. score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate
  214. is of high importance.
  215. Accepts the following input tensors:
  216. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  217. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  218. sigmoid per element.
  219. - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
  220. only contain {0,1} values (except if `ignore_index` is specified).
  221. Additional dimension ``...`` will be flattened into the batch dimension.
  222. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  223. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  224. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  225. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  226. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  227. Args:
  228. preds: Tensor with predictions
  229. target: Tensor with true labels
  230. num_labels: Integer specifying the number of labels
  231. fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log
  232. AUC score.
  233. average:
  234. Defines the reduction that is applied over labels. Should be one of the following:
  235. - ``macro``: Calculate score for each label and average them
  236. - ``"none"`` or ``None``: calculates score for each label and applies no reduction
  237. thresholds:
  238. Can be one of:
  239. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  240. all the data. Most accurate but also most memory consuming approach.
  241. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  242. 0 to 1 as bins for the calculation.
  243. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  244. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  245. bins for the calculation.
  246. ignore_index:
  247. Specifies a target value that is ignored and does not contribute to the metric calculation
  248. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  249. Set to ``False`` for faster computations.
  250. Example:
  251. >>> from torchmetrics.functional.classification import multilabel_logauc
  252. >>> preds = torch.tensor([[0.75, 0.05, 0.35],
  253. ... [0.45, 0.75, 0.05],
  254. ... [0.05, 0.55, 0.75],
  255. ... [0.05, 0.65, 0.05]])
  256. >>> target = torch.tensor([[1, 0, 1],
  257. ... [0, 0, 0],
  258. ... [0, 1, 1],
  259. ... [1, 1, 1]])
  260. >>> multilabel_logauc(preds, target, num_labels=3, average="macro", thresholds=None)
  261. tensor(0.3945)
  262. >>> multilabel_logauc(preds, target, num_labels=3, average=None, thresholds=None)
  263. tensor([0.5000, 0.0000, 0.6835])
  264. """
  265. fpr, tpr, _ = multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args)
  266. return _reduce_logauc(fpr, tpr, fpr_range, average=average)
  267. def logauc(
  268. preds: Tensor,
  269. target: Tensor,
  270. task: Literal["binary", "multiclass", "multilabel"],
  271. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  272. num_classes: Optional[int] = None,
  273. num_labels: Optional[int] = None,
  274. fpr_range: Tuple[float, float] = (0.001, 0.1),
  275. average: Optional[Literal["macro", "none"]] = None,
  276. ignore_index: Optional[int] = None,
  277. validate_args: bool = True,
  278. ) -> Optional[Tensor]:
  279. r"""Compute the `Log AUC`_ score for classification tasks.
  280. The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false
  281. positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The
  282. score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate
  283. is of high importance.
  284. """
  285. task = ClassificationTask.from_str(task)
  286. if task == ClassificationTask.BINARY:
  287. return binary_logauc(preds, target, fpr_range, thresholds, ignore_index, validate_args)
  288. if task == ClassificationTask.MULTICLASS:
  289. if not isinstance(num_classes, int):
  290. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  291. return multiclass_logauc(
  292. preds, target, num_classes, fpr_range, average, thresholds, ignore_index, validate_args
  293. )
  294. if task == ClassificationTask.MULTILABEL:
  295. if not isinstance(num_labels, int):
  296. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  297. return multilabel_logauc(preds, target, num_labels, fpr_range, average, thresholds, ignore_index, validate_args)
  298. return None