specificity_sensitivity.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. from typing import List, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from typing_extensions import Literal
  19. from torchmetrics.functional.classification.precision_recall_curve import (
  20. _binary_precision_recall_curve_arg_validation,
  21. _binary_precision_recall_curve_format,
  22. _binary_precision_recall_curve_tensor_validation,
  23. _binary_precision_recall_curve_update,
  24. _multiclass_precision_recall_curve_arg_validation,
  25. _multiclass_precision_recall_curve_format,
  26. _multiclass_precision_recall_curve_tensor_validation,
  27. _multiclass_precision_recall_curve_update,
  28. _multilabel_precision_recall_curve_arg_validation,
  29. _multilabel_precision_recall_curve_format,
  30. _multilabel_precision_recall_curve_tensor_validation,
  31. _multilabel_precision_recall_curve_update,
  32. )
  33. from torchmetrics.functional.classification.roc import (
  34. _binary_roc_compute,
  35. _multiclass_roc_compute,
  36. _multilabel_roc_compute,
  37. )
  38. from torchmetrics.utilities.enums import ClassificationTask
  39. def _convert_fpr_to_specificity(fpr: Tensor) -> Tensor:
  40. """Convert fprs to specificity."""
  41. return 1 - fpr
  42. def _specificity_at_sensitivity(
  43. specificity: Tensor,
  44. sensitivity: Tensor,
  45. thresholds: Tensor,
  46. min_sensitivity: float,
  47. ) -> tuple[Tensor, Tensor]:
  48. # get indices where sensitivity is greater than min_sensitivity
  49. indices = sensitivity >= min_sensitivity
  50. # if no indices are found, max_spec, best_threshold = 0.0, 1e6
  51. if not indices.any():
  52. max_spec = torch.tensor(0.0, device=specificity.device, dtype=specificity.dtype)
  53. best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)
  54. else:
  55. # redefine specificity, sensitivity and threshold tensor based on indices
  56. specificity, sensitivity, thresholds = specificity[indices], sensitivity[indices], thresholds[indices]
  57. # get argmax
  58. idx = torch.argmax(specificity)
  59. # get max_spec and best_threshold
  60. max_spec, best_threshold = specificity[idx], thresholds[idx]
  61. return max_spec, best_threshold
  62. def _binary_specificity_at_sensitivity_arg_validation(
  63. min_sensitivity: float,
  64. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  65. ignore_index: Optional[int] = None,
  66. ) -> None:
  67. _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
  68. if not isinstance(min_sensitivity, float) and not (0 <= min_sensitivity <= 1):
  69. raise ValueError(
  70. f"Expected argument `min_sensitivity` to be an float in the [0,1] range, but got {min_sensitivity}"
  71. )
  72. def _binary_specificity_at_sensitivity_compute(
  73. state: Union[Tensor, tuple[Tensor, Tensor]],
  74. thresholds: Optional[Tensor],
  75. min_sensitivity: float,
  76. pos_label: int = 1,
  77. ) -> tuple[Tensor, Tensor]:
  78. fpr, sensitivity, thresholds = _binary_roc_compute(state, thresholds, pos_label)
  79. specificity = _convert_fpr_to_specificity(fpr)
  80. return _specificity_at_sensitivity(specificity, sensitivity, thresholds, min_sensitivity)
  81. def binary_specificity_at_sensitivity(
  82. preds: Tensor,
  83. target: Tensor,
  84. min_sensitivity: float,
  85. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  86. ignore_index: Optional[int] = None,
  87. validate_args: bool = True,
  88. ) -> tuple[Tensor, Tensor]:
  89. r"""Compute the highest possible specificity value given the minimum sensitivity levels provided for binary tasks.
  90. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
  91. the find the specificity for a given sensitivity level.
  92. Accepts the following input tensors:
  93. - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
  94. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  95. sigmoid per element.
  96. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  97. only contain {0,1} values (except if `ignore_index` is specified).
  98. Additional dimension ``...`` will be flattened into the batch dimension.
  99. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  100. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  101. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  102. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  103. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  104. Args:
  105. preds: Tensor with predictions
  106. target: Tensor with true labels
  107. min_sensitivity: float value specifying minimum sensitivity threshold.
  108. thresholds:
  109. Can be one of:
  110. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  111. all the data. Most accurate but also most memory consuming approach.
  112. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  113. 0 to 1 as bins for the calculation.
  114. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  115. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  116. bins for the calculation.
  117. ignore_index:
  118. Specifies a target value that is ignored and does not contribute to the metric calculation
  119. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  120. Set to ``False`` for faster computations.
  121. Returns:
  122. (tuple): a tuple of 2 tensors containing:
  123. - specificity: a scalar tensor with the maximum specificity for the given sensitivity level
  124. - threshold: a scalar tensor with the corresponding threshold level
  125. Example:
  126. >>> from torchmetrics.functional.classification import binary_specificity_at_sensitivity
  127. >>> preds = torch.tensor([0, 0.5, 0.4, 0.1])
  128. >>> target = torch.tensor([0, 1, 1, 1])
  129. >>> binary_specificity_at_sensitivity(preds, target, min_sensitivity=0.5, thresholds=None)
  130. (tensor(1.), tensor(0.4000))
  131. >>> binary_specificity_at_sensitivity(preds, target, min_sensitivity=0.5, thresholds=5)
  132. (tensor(1.), tensor(0.2500))
  133. """
  134. if validate_args:
  135. _binary_specificity_at_sensitivity_arg_validation(min_sensitivity, thresholds, ignore_index)
  136. _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
  137. preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index)
  138. state = _binary_precision_recall_curve_update(preds, target, thresholds)
  139. return _binary_specificity_at_sensitivity_compute(state, thresholds, min_sensitivity)
  140. def _multiclass_specificity_at_sensitivity_arg_validation(
  141. num_classes: int,
  142. min_sensitivity: float,
  143. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  144. ignore_index: Optional[int] = None,
  145. ) -> None:
  146. _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
  147. if not isinstance(min_sensitivity, float) and not (0 <= min_sensitivity <= 1):
  148. raise ValueError(
  149. f"Expected argument `min_sensitivity` to be an float in the [0,1] range, but got {min_sensitivity}"
  150. )
  151. def _multiclass_specificity_at_sensitivity_compute(
  152. state: Union[Tensor, tuple[Tensor, Tensor]],
  153. num_classes: int,
  154. thresholds: Optional[Tensor],
  155. min_sensitivity: float,
  156. ) -> tuple[Tensor, Tensor]:
  157. fpr, sensitivity, thresholds = _multiclass_roc_compute(state, num_classes, thresholds)
  158. specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr]
  159. if isinstance(state, Tensor):
  160. res = [
  161. _specificity_at_sensitivity(sp, sn, thresholds, min_sensitivity) # type: ignore
  162. for sp, sn in zip(specificity, sensitivity)
  163. ]
  164. else:
  165. res = [
  166. _specificity_at_sensitivity(sp, sn, t, min_sensitivity)
  167. for sp, sn, t in zip(specificity, sensitivity, thresholds)
  168. ]
  169. specificity = torch.stack([r[0] for r in res])
  170. thresholds = torch.stack([r[1] for r in res])
  171. return specificity, thresholds
  172. def multiclass_specificity_at_sensitivity(
  173. preds: Tensor,
  174. target: Tensor,
  175. num_classes: int,
  176. min_sensitivity: float,
  177. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  178. ignore_index: Optional[int] = None,
  179. validate_args: bool = True,
  180. ) -> tuple[Tensor, Tensor]:
  181. r"""Compute the highest possible specificity value given minimum sensitivity level provided for multiclass tasks.
  182. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  183. find the specificity for a given sensitivity level.
  184. Accepts the following input tensors:
  185. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  186. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  187. softmax per sample.
  188. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  189. only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  190. Additional dimension ``...`` will be flattened into the batch dimension.
  191. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  192. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  193. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  194. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  195. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  196. Args:
  197. preds: Tensor with predictions
  198. target: Tensor with true labels
  199. num_classes: Integer specifying the number of classes
  200. min_sensitivity: float value specifying minimum sensitivity threshold.
  201. thresholds:
  202. Can be one of:
  203. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  204. all the data. Most accurate but also most memory consuming approach.
  205. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  206. 0 to 1 as bins for the calculation.
  207. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  208. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  209. bins for the calculation.
  210. ignore_index:
  211. Specifies a target value that is ignored and does not contribute to the metric calculation
  212. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  213. Set to ``False`` for faster computations.
  214. Returns:
  215. (tuple): a tuple of either 2 tensors or 2 lists containing
  216. - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class
  217. - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class
  218. Example:
  219. >>> from torchmetrics.functional.classification import multiclass_specificity_at_sensitivity
  220. >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  221. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  222. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  223. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  224. >>> target = torch.tensor([0, 1, 3, 2])
  225. >>> multiclass_specificity_at_sensitivity(preds, target, num_classes=5, min_sensitivity=0.5, thresholds=None)
  226. (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
  227. >>> multiclass_specificity_at_sensitivity(preds, target, num_classes=5, min_sensitivity=0.5, thresholds=5)
  228. (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
  229. """
  230. if validate_args:
  231. _multiclass_specificity_at_sensitivity_arg_validation(num_classes, min_sensitivity, thresholds, ignore_index)
  232. _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
  233. preds, target, thresholds = _multiclass_precision_recall_curve_format(
  234. preds, target, num_classes, thresholds, ignore_index
  235. )
  236. state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
  237. return _multiclass_specificity_at_sensitivity_compute(state, num_classes, thresholds, min_sensitivity)
  238. def _multilabel_specificity_at_sensitivity_arg_validation(
  239. num_labels: int,
  240. min_sensitivity: float,
  241. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  242. ignore_index: Optional[int] = None,
  243. ) -> None:
  244. _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index)
  245. if not isinstance(min_sensitivity, float) and not (0 <= min_sensitivity <= 1):
  246. raise ValueError(
  247. f"Expected argument `min_sensitivity` to be an float in the [0,1] range, but got {min_sensitivity}"
  248. )
  249. def _multilabel_specificity_at_sensitivity_compute(
  250. state: Union[Tensor, tuple[Tensor, Tensor]],
  251. num_labels: int,
  252. thresholds: Optional[Tensor],
  253. ignore_index: Optional[int],
  254. min_sensitivity: float,
  255. ) -> tuple[Tensor, Tensor]:
  256. fpr, sensitivity, thresholds = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index)
  257. specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr]
  258. if isinstance(state, Tensor):
  259. res = [
  260. _specificity_at_sensitivity(sp, sn, thresholds, min_sensitivity) # type: ignore
  261. for sp, sn in zip(specificity, sensitivity)
  262. ]
  263. else:
  264. res = [
  265. _specificity_at_sensitivity(sp, sn, t, min_sensitivity)
  266. for sp, sn, t in zip(specificity, sensitivity, thresholds)
  267. ]
  268. specificity = torch.stack([r[0] for r in res])
  269. thresholds = torch.stack([r[1] for r in res])
  270. return specificity, thresholds
  271. def multilabel_specificity_at_sensitivity(
  272. preds: Tensor,
  273. target: Tensor,
  274. num_labels: int,
  275. min_sensitivity: float,
  276. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  277. ignore_index: Optional[int] = None,
  278. validate_args: bool = True,
  279. ) -> tuple[Tensor, Tensor]:
  280. r"""Compute the highest possible specificity value given minimum sensitivity level provided for multilabel tasks.
  281. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
  282. the find the specificity for a given sensitivity level.
  283. Accepts the following input tensors:
  284. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  285. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  286. sigmoid per element.
  287. - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
  288. only contain {0,1} values (except if `ignore_index` is specified).
  289. Additional dimension ``...`` will be flattened into the batch dimension.
  290. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  291. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  292. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  293. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  294. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  295. Args:
  296. preds: Tensor with predictions
  297. target: Tensor with true labels
  298. num_labels: Integer specifying the number of labels
  299. min_sensitivity: float value specifying minimum sensitivity threshold.
  300. thresholds:
  301. Can be one of:
  302. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  303. all the data. Most accurate but also most memory consuming approach.
  304. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  305. 0 to 1 as bins for the calculation.
  306. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  307. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  308. bins for the calculation.
  309. ignore_index:
  310. Specifies a target value that is ignored and does not contribute to the metric calculation
  311. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  312. Set to ``False`` for faster computations.
  313. Returns:
  314. (tuple): a tuple of either 2 tensors or 2 lists containing
  315. - specificity: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision
  316. level per class
  317. - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class
  318. Example:
  319. >>> from torchmetrics.functional.classification import multilabel_specificity_at_sensitivity
  320. >>> preds = torch.tensor([[0.75, 0.05, 0.35],
  321. ... [0.45, 0.75, 0.05],
  322. ... [0.05, 0.55, 0.75],
  323. ... [0.05, 0.65, 0.05]])
  324. >>> target = torch.tensor([[1, 0, 1],
  325. ... [0, 0, 0],
  326. ... [0, 1, 1],
  327. ... [1, 1, 1]])
  328. >>> multilabel_specificity_at_sensitivity(preds, target, num_labels=3, min_sensitivity=0.5, thresholds=None)
  329. (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
  330. >>> multilabel_specificity_at_sensitivity(preds, target, num_labels=3, min_sensitivity=0.5, thresholds=5)
  331. (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
  332. """
  333. if validate_args:
  334. _multilabel_specificity_at_sensitivity_arg_validation(num_labels, min_sensitivity, thresholds, ignore_index)
  335. _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index)
  336. preds, target, thresholds = _multilabel_precision_recall_curve_format(
  337. preds, target, num_labels, thresholds, ignore_index
  338. )
  339. state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds)
  340. return _multilabel_specificity_at_sensitivity_compute(state, num_labels, thresholds, ignore_index, min_sensitivity)
  341. def specicity_at_sensitivity(
  342. preds: Tensor,
  343. target: Tensor,
  344. task: Literal["binary", "multiclass", "multilabel"],
  345. min_sensitivity: float,
  346. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  347. num_classes: Optional[int] = None,
  348. num_labels: Optional[int] = None,
  349. ignore_index: Optional[int] = None,
  350. validate_args: bool = True,
  351. ) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]:
  352. r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.
  353. .. warning::
  354. This function was deprecated in v1.3.0 of Torchmetrics and will be removed in v2.0.0.
  355. Use `specificity_at_sensitivity` instead.
  356. """
  357. warnings.warn(
  358. "This method has will be removed in 2.0.0. Use `specificity_at_sensitivity` instead.",
  359. DeprecationWarning,
  360. stacklevel=1,
  361. )
  362. return specificity_at_sensitivity(
  363. preds=preds,
  364. target=target,
  365. task=task,
  366. min_sensitivity=min_sensitivity,
  367. thresholds=thresholds,
  368. num_classes=num_classes,
  369. num_labels=num_labels,
  370. ignore_index=ignore_index,
  371. validate_args=validate_args,
  372. )
  373. def specificity_at_sensitivity(
  374. preds: Tensor,
  375. target: Tensor,
  376. task: Literal["binary", "multiclass", "multilabel"],
  377. min_sensitivity: float,
  378. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  379. num_classes: Optional[int] = None,
  380. num_labels: Optional[int] = None,
  381. ignore_index: Optional[int] = None,
  382. validate_args: bool = True,
  383. ) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]:
  384. r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.
  385. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
  386. the find the specificity for a given sensitivity level.
  387. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  388. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  389. :func:`~torchmetrics.functional.classification.binary_specificity_at_sensitivity`,
  390. :func:`~torchmetrics.functional.classification.multiclass_specificity_at_sensitivity` and
  391. :func:`~torchmetrics.functional.classification.multilabel_specificity_at_sensitivity` for the specific details of
  392. each argument influence and examples.
  393. """
  394. task = ClassificationTask.from_str(task)
  395. if task == ClassificationTask.BINARY:
  396. return binary_specificity_at_sensitivity( # type: ignore
  397. preds, target, min_sensitivity, thresholds, ignore_index, validate_args
  398. )
  399. if task == ClassificationTask.MULTICLASS:
  400. if not isinstance(num_classes, int):
  401. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  402. return multiclass_specificity_at_sensitivity( # type: ignore
  403. preds, target, num_classes, min_sensitivity, thresholds, ignore_index, validate_args
  404. )
  405. if task == ClassificationTask.MULTILABEL:
  406. if not isinstance(num_labels, int):
  407. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  408. return multilabel_specificity_at_sensitivity( # type: ignore
  409. preds, target, num_labels, min_sensitivity, thresholds, ignore_index, validate_args
  410. )
  411. raise ValueError(f"Not handled value: {task}")