auroc.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional, Union
  15. import torch
  16. from torch import Tensor, tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.classification.precision_recall_curve import (
  19. _binary_precision_recall_curve_arg_validation,
  20. _binary_precision_recall_curve_format,
  21. _binary_precision_recall_curve_tensor_validation,
  22. _binary_precision_recall_curve_update,
  23. _multiclass_precision_recall_curve_arg_validation,
  24. _multiclass_precision_recall_curve_format,
  25. _multiclass_precision_recall_curve_tensor_validation,
  26. _multiclass_precision_recall_curve_update,
  27. _multilabel_precision_recall_curve_arg_validation,
  28. _multilabel_precision_recall_curve_format,
  29. _multilabel_precision_recall_curve_tensor_validation,
  30. _multilabel_precision_recall_curve_update,
  31. )
  32. from torchmetrics.functional.classification.roc import (
  33. _binary_roc_compute,
  34. _multiclass_roc_compute,
  35. _multilabel_roc_compute,
  36. )
  37. from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide
  38. from torchmetrics.utilities.data import _bincount
  39. from torchmetrics.utilities.enums import ClassificationTask
  40. from torchmetrics.utilities.prints import rank_zero_warn
  41. def _reduce_auroc(
  42. fpr: Union[Tensor, List[Tensor]],
  43. tpr: Union[Tensor, List[Tensor]],
  44. average: Optional[Literal["macro", "weighted", "none"]] = "macro",
  45. weights: Optional[Tensor] = None,
  46. direction: float = 1.0,
  47. ) -> Tensor:
  48. """Reduce multiple average precision score into one number."""
  49. if isinstance(fpr, Tensor) and isinstance(tpr, Tensor):
  50. res = _auc_compute_without_check(fpr, tpr, direction=direction, axis=1)
  51. else:
  52. res = torch.stack([_auc_compute_without_check(x, y, direction=direction) for x, y in zip(fpr, tpr)])
  53. if average is None or average == "none":
  54. return res
  55. if torch.isnan(res).any():
  56. rank_zero_warn(
  57. f"Average precision score for one or more classes was `nan`. Ignoring these classes in {average}-average",
  58. UserWarning,
  59. )
  60. idx = ~torch.isnan(res)
  61. if average == "macro":
  62. return res[idx].mean()
  63. if average == "weighted" and weights is not None:
  64. weights = _safe_divide(weights[idx], weights[idx].sum())
  65. return (res[idx] * weights).sum()
  66. raise ValueError("Received an incompatible combinations of inputs to make reduction.")
  67. def _binary_auroc_arg_validation(
  68. max_fpr: Optional[float] = None,
  69. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  70. ignore_index: Optional[int] = None,
  71. ) -> None:
  72. _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
  73. if max_fpr is not None and not isinstance(max_fpr, float) and 0 < max_fpr <= 1:
  74. raise ValueError(f"Arguments `max_fpr` should be a float in range (0, 1], but got: {max_fpr}")
  75. def _binary_auroc_compute(
  76. state: Union[Tensor, tuple[Tensor, Tensor]],
  77. thresholds: Optional[Tensor],
  78. max_fpr: Optional[float] = None,
  79. pos_label: int = 1,
  80. ) -> Tensor:
  81. fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label)
  82. if max_fpr is None or max_fpr == 1 or fpr.sum() == 0 or tpr.sum() == 0:
  83. return _auc_compute_without_check(fpr, tpr, 1.0)
  84. _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device
  85. max_area: Tensor = tensor(max_fpr, device=_device)
  86. # Add a single point at max_fpr and interpolate its tpr value
  87. stop = torch.bucketize(max_area, fpr, out_int32=True, right=True)
  88. weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
  89. interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight)
  90. tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
  91. fpr = torch.cat([fpr[:stop], max_area.view(1)])
  92. # Compute partial AUC
  93. partial_auc = _auc_compute_without_check(fpr, tpr, 1.0)
  94. # McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal
  95. min_area: Tensor = 0.5 * max_area**2
  96. return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
  97. def binary_auroc(
  98. preds: Tensor,
  99. target: Tensor,
  100. max_fpr: Optional[float] = None,
  101. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  102. ignore_index: Optional[int] = None,
  103. validate_args: bool = True,
  104. ) -> Tensor:
  105. r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks.
  106. The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for
  107. multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5
  108. corresponds to random guessing.
  109. Accepts the following input tensors:
  110. - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
  111. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  112. sigmoid per element.
  113. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  114. only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class.
  115. Additional dimension ``...`` will be flattened into the batch dimension.
  116. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  117. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  118. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  119. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  120. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  121. Args:
  122. preds: Tensor with predictions
  123. target: Tensor with true labels
  124. max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``.
  125. thresholds:
  126. Can be one of:
  127. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  128. all the data. Most accurate but also most memory consuming approach.
  129. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  130. 0 to 1 as bins for the calculation.
  131. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  132. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  133. bins for the calculation.
  134. ignore_index:
  135. Specifies a target value that is ignored and does not contribute to the metric calculation
  136. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  137. Set to ``False`` for faster computations.
  138. Returns:
  139. A single scalar with the auroc score
  140. Example:
  141. >>> from torchmetrics.functional.classification import binary_auroc
  142. >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
  143. >>> target = torch.tensor([0, 1, 1, 0])
  144. >>> binary_auroc(preds, target, thresholds=None)
  145. tensor(0.5000)
  146. >>> binary_auroc(preds, target, thresholds=5)
  147. tensor(0.5000)
  148. """
  149. if validate_args:
  150. _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index)
  151. _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
  152. preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index)
  153. state = _binary_precision_recall_curve_update(preds, target, thresholds)
  154. return _binary_auroc_compute(state, thresholds, max_fpr)
  155. def _multiclass_auroc_arg_validation(
  156. num_classes: int,
  157. average: Optional[Literal["macro", "weighted", "none"]] = "macro",
  158. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  159. ignore_index: Optional[int] = None,
  160. ) -> None:
  161. _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
  162. allowed_average = ("macro", "weighted", "none", None)
  163. if average not in allowed_average:
  164. raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}")
  165. def _multiclass_auroc_compute(
  166. state: Union[Tensor, tuple[Tensor, Tensor]],
  167. num_classes: int,
  168. average: Optional[Literal["macro", "weighted", "none"]] = "macro",
  169. thresholds: Optional[Tensor] = None,
  170. ) -> Tensor:
  171. fpr, tpr, _ = _multiclass_roc_compute(state, num_classes, thresholds)
  172. return _reduce_auroc(
  173. fpr,
  174. tpr,
  175. average,
  176. weights=_bincount(state[1], minlength=num_classes).float() if thresholds is None else state[0][:, 1, :].sum(-1),
  177. )
  178. def multiclass_auroc(
  179. preds: Tensor,
  180. target: Tensor,
  181. num_classes: int,
  182. average: Optional[Literal["macro", "weighted", "none"]] = "macro",
  183. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  184. ignore_index: Optional[int] = None,
  185. validate_args: bool = True,
  186. ) -> Tensor:
  187. r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks.
  188. The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for
  189. multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5
  190. corresponds to random guessing.
  191. Accepts the following input tensors:
  192. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  193. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  194. softmax per sample.
  195. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  196. only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  197. Additional dimension ``...`` will be flattened into the batch dimension.
  198. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  199. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  200. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  201. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  202. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  203. Args:
  204. preds: Tensor with predictions
  205. target: Tensor with true labels
  206. num_classes: Integer specifying the number of classes
  207. average:
  208. Defines the reduction that is applied over classes. Should be one of the following:
  209. - ``macro``: Calculate score for each class and average them
  210. - ``weighted``: calculates score for each class and computes weighted average using their support
  211. - ``"none"`` or ``None``: calculates score for each class and applies no reduction
  212. thresholds:
  213. Can be one of:
  214. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  215. all the data. Most accurate but also most memory consuming approach.
  216. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  217. 0 to 1 as bins for the calculation.
  218. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  219. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  220. bins for the calculation.
  221. ignore_index:
  222. Specifies a target value that is ignored and does not contribute to the metric calculation
  223. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  224. Set to ``False`` for faster computations.
  225. Returns:
  226. If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class.
  227. If `average="macro"|"weighted"` then a single scalar is returned.
  228. Example:
  229. >>> from torchmetrics.functional.classification import multiclass_auroc
  230. >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  231. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  232. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  233. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  234. >>> target = torch.tensor([0, 1, 3, 2])
  235. >>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=None)
  236. tensor(0.5333)
  237. >>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=None)
  238. tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
  239. >>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=5)
  240. tensor(0.5333)
  241. >>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=5)
  242. tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
  243. """
  244. if validate_args:
  245. _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index)
  246. _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
  247. preds, target, thresholds = _multiclass_precision_recall_curve_format(
  248. preds, target, num_classes, thresholds, ignore_index
  249. )
  250. state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
  251. return _multiclass_auroc_compute(state, num_classes, average, thresholds)
  252. def _multilabel_auroc_arg_validation(
  253. num_labels: int,
  254. average: Optional[Literal["micro", "macro", "weighted", "none"]],
  255. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  256. ignore_index: Optional[int] = None,
  257. ) -> None:
  258. _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index)
  259. allowed_average = ("micro", "macro", "weighted", "none", None)
  260. if average not in allowed_average:
  261. raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}")
  262. def _multilabel_auroc_compute(
  263. state: Union[Tensor, tuple[Tensor, Tensor]],
  264. num_labels: int,
  265. average: Optional[Literal["micro", "macro", "weighted", "none"]],
  266. thresholds: Optional[Tensor],
  267. ignore_index: Optional[int] = None,
  268. ) -> Tensor:
  269. if average == "micro":
  270. if isinstance(state, Tensor) and thresholds is not None:
  271. return _binary_auroc_compute(state.sum(1), thresholds, max_fpr=None)
  272. preds = state[0].flatten()
  273. target = state[1].flatten()
  274. if ignore_index is not None:
  275. idx = target == ignore_index
  276. preds = preds[~idx]
  277. target = target[~idx]
  278. return _binary_auroc_compute((preds, target), thresholds, max_fpr=None)
  279. fpr, tpr, _ = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index)
  280. return _reduce_auroc(
  281. fpr,
  282. tpr,
  283. average,
  284. weights=(state[1] == 1).sum(dim=0).float() if thresholds is None else state[0][:, 1, :].sum(-1),
  285. )
  286. def multilabel_auroc(
  287. preds: Tensor,
  288. target: Tensor,
  289. num_labels: int,
  290. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  291. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  292. ignore_index: Optional[int] = None,
  293. validate_args: bool = True,
  294. ) -> Tensor:
  295. r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks.
  296. The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for
  297. multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5
  298. corresponds to random guessing.
  299. Accepts the following input tensors:
  300. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  301. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  302. sigmoid per element.
  303. - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
  304. only contain {0,1} values (except if `ignore_index` is specified).
  305. Additional dimension ``...`` will be flattened into the batch dimension.
  306. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  307. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  308. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  309. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  310. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  311. Args:
  312. preds: Tensor with predictions
  313. target: Tensor with true labels
  314. num_labels: Integer specifying the number of labels
  315. average:
  316. Defines the reduction that is applied over labels. Should be one of the following:
  317. - ``micro``: Sum score over all labels
  318. - ``macro``: Calculate score for each label and average them
  319. - ``weighted``: calculates score for each label and computes weighted average using their support
  320. - ``"none"`` or ``None``: calculates score for each label and applies no reduction
  321. thresholds:
  322. Can be one of:
  323. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  324. all the data. Most accurate but also most memory consuming approach.
  325. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  326. 0 to 1 as bins for the calculation.
  327. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  328. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  329. bins for the calculation.
  330. ignore_index:
  331. Specifies a target value that is ignored and does not contribute to the metric calculation
  332. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  333. Set to ``False`` for faster computations.
  334. Returns:
  335. If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class.
  336. If `average="micro|macro"|"weighted"` then a single scalar is returned.
  337. Example:
  338. >>> from torchmetrics.functional.classification import multilabel_auroc
  339. >>> preds = torch.tensor([[0.75, 0.05, 0.35],
  340. ... [0.45, 0.75, 0.05],
  341. ... [0.05, 0.55, 0.75],
  342. ... [0.05, 0.65, 0.05]])
  343. >>> target = torch.tensor([[1, 0, 1],
  344. ... [0, 0, 0],
  345. ... [0, 1, 1],
  346. ... [1, 1, 1]])
  347. >>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=None)
  348. tensor(0.6528)
  349. >>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=None)
  350. tensor([0.6250, 0.5000, 0.8333])
  351. >>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=5)
  352. tensor(0.6528)
  353. >>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=5)
  354. tensor([0.6250, 0.5000, 0.8333])
  355. """
  356. if validate_args:
  357. _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index)
  358. _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index)
  359. preds, target, thresholds = _multilabel_precision_recall_curve_format(
  360. preds, target, num_labels, thresholds, ignore_index
  361. )
  362. state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds)
  363. return _multilabel_auroc_compute(state, num_labels, average, thresholds, ignore_index)
  364. def auroc(
  365. preds: Tensor,
  366. target: Tensor,
  367. task: Literal["binary", "multiclass", "multilabel"],
  368. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  369. num_classes: Optional[int] = None,
  370. num_labels: Optional[int] = None,
  371. average: Optional[Literal["macro", "weighted", "none"]] = "macro",
  372. max_fpr: Optional[float] = None,
  373. ignore_index: Optional[int] = None,
  374. validate_args: bool = True,
  375. ) -> Optional[Tensor]:
  376. r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_).
  377. The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for
  378. multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5
  379. corresponds to random guessing.
  380. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  381. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  382. :func:`~torchmetrics.functional.classification.binary_auroc`,
  383. :func:`~torchmetrics.functional.classification.multiclass_auroc` and
  384. :func:`~torchmetrics.functional.classification.multilabel_auroc` for the specific details of
  385. each argument influence and examples.
  386. Legacy Example:
  387. >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
  388. >>> target = torch.tensor([0, 0, 1, 1, 1])
  389. >>> auroc(preds, target, task='binary')
  390. tensor(0.5000)
  391. >>> preds = torch.tensor([[0.90, 0.05, 0.05],
  392. ... [0.05, 0.90, 0.05],
  393. ... [0.05, 0.05, 0.90],
  394. ... [0.85, 0.05, 0.10],
  395. ... [0.10, 0.10, 0.80]])
  396. >>> target = torch.tensor([0, 1, 1, 2, 2])
  397. >>> auroc(preds, target, task='multiclass', num_classes=3)
  398. tensor(0.7778)
  399. """
  400. task = ClassificationTask.from_str(task)
  401. if task == ClassificationTask.BINARY:
  402. return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args)
  403. if task == ClassificationTask.MULTICLASS:
  404. if not isinstance(num_classes, int):
  405. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  406. return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args)
  407. if task == ClassificationTask.MULTILABEL:
  408. if not isinstance(num_labels, int):
  409. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  410. return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args)
  411. return None