sensitivity_specificity.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional, Union
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.classification.precision_recall_curve import (
  19. _binary_precision_recall_curve_arg_validation,
  20. _binary_precision_recall_curve_format,
  21. _binary_precision_recall_curve_tensor_validation,
  22. _binary_precision_recall_curve_update,
  23. _multiclass_precision_recall_curve_arg_validation,
  24. _multiclass_precision_recall_curve_format,
  25. _multiclass_precision_recall_curve_tensor_validation,
  26. _multiclass_precision_recall_curve_update,
  27. _multilabel_precision_recall_curve_arg_validation,
  28. _multilabel_precision_recall_curve_format,
  29. _multilabel_precision_recall_curve_tensor_validation,
  30. _multilabel_precision_recall_curve_update,
  31. )
  32. from torchmetrics.functional.classification.roc import (
  33. _binary_roc_compute,
  34. _multiclass_roc_compute,
  35. _multilabel_roc_compute,
  36. )
  37. from torchmetrics.utilities.enums import ClassificationTask
  38. def _convert_fpr_to_specificity(fpr: Tensor) -> Tensor:
  39. """Convert fprs to specificity."""
  40. return 1 - fpr
  41. def _sensitivity_at_specificity(
  42. sensitivity: Tensor,
  43. specificity: Tensor,
  44. thresholds: Tensor,
  45. min_specificity: float,
  46. ) -> tuple[Tensor, Tensor]:
  47. # get indices where specificity is greater than min_specificity
  48. indices = specificity >= min_specificity
  49. # if no indices are found, max_spec, best_threshold = 0.0, 1e6
  50. if not indices.any():
  51. max_spec = torch.tensor(0.0, device=sensitivity.device, dtype=sensitivity.dtype)
  52. best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)
  53. else:
  54. # redefine sensitivity, specificity and threshold tensor based on indices
  55. sensitivity, specificity, thresholds = sensitivity[indices], specificity[indices], thresholds[indices]
  56. # get argmax
  57. idx = torch.argmax(sensitivity)
  58. # get max_spec and best_threshold
  59. max_spec, best_threshold = sensitivity[idx], thresholds[idx]
  60. return max_spec, best_threshold
  61. def _binary_sensitivity_at_specificity_arg_validation(
  62. min_specificity: float,
  63. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  64. ignore_index: Optional[int] = None,
  65. ) -> None:
  66. _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
  67. if not isinstance(min_specificity, float) and not (0 <= min_specificity <= 1):
  68. raise ValueError(
  69. f"Expected argument `min_specificity` to be an float in the [0,1] range, but got {min_specificity}"
  70. )
  71. def _binary_sensitivity_at_specificity_compute(
  72. state: Union[Tensor, tuple[Tensor, Tensor]],
  73. thresholds: Optional[Tensor],
  74. min_specificity: float,
  75. pos_label: int = 1,
  76. ) -> tuple[Tensor, Tensor]:
  77. fpr, sensitivity, thresholds = _binary_roc_compute(state, thresholds, pos_label)
  78. specificity = _convert_fpr_to_specificity(fpr)
  79. return _sensitivity_at_specificity(sensitivity, specificity, thresholds, min_specificity)
  80. def binary_sensitivity_at_specificity(
  81. preds: Tensor,
  82. target: Tensor,
  83. min_specificity: float,
  84. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  85. ignore_index: Optional[int] = None,
  86. validate_args: bool = True,
  87. ) -> tuple[Tensor, Tensor]:
  88. r"""Compute the highest possible sensitivity value given the minimum specificity levels provided for binary tasks.
  89. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
  90. the find the sensitivity for a given specificity level.
  91. Accepts the following input tensors:
  92. - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
  93. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  94. sigmoid per element.
  95. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  96. only contain {0,1} values (except if `ignore_index` is specified).
  97. Additional dimension ``...`` will be flattened into the batch dimension.
  98. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  99. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  100. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  101. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  102. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  103. Args:
  104. preds: Tensor with predictions
  105. target: Tensor with true labels
  106. min_specificity: float value specifying minimum specificity threshold.
  107. thresholds:
  108. Can be one of:
  109. - ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  110. all the data. It is the most accurate but also the most memory-consuming approach.
  111. - ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  112. 0 to 1 as bins for the calculation.
  113. - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  114. - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as
  115. bins for the calculation.
  116. ignore_index:
  117. Specifies a target value that is ignored and does not contribute to the metric calculation
  118. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  119. Set to ``False`` for faster computations.
  120. Returns:
  121. (tuple): a tuple of 2 tensors containing:
  122. - sensitivity: a scalar tensor with the maximum sensitivity for the given specificity level
  123. - threshold: a scalar tensor with the corresponding threshold level
  124. Example:
  125. >>> from torchmetrics.functional.classification import binary_sensitivity_at_specificity
  126. >>> preds = torch.tensor([0, 0.5, 0.4, 0.1])
  127. >>> target = torch.tensor([0, 1, 1, 1])
  128. >>> binary_sensitivity_at_specificity(preds, target, min_specificity=0.5, thresholds=None)
  129. (tensor(1.), tensor(0.1000))
  130. >>> binary_sensitivity_at_specificity(preds, target, min_specificity=0.5, thresholds=5)
  131. (tensor(0.6667), tensor(0.2500))
  132. """
  133. if validate_args:
  134. _binary_sensitivity_at_specificity_arg_validation(min_specificity, thresholds, ignore_index)
  135. _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
  136. preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index)
  137. state = _binary_precision_recall_curve_update(preds, target, thresholds)
  138. return _binary_sensitivity_at_specificity_compute(state, thresholds, min_specificity)
  139. def _multiclass_sensitivity_at_specificity_arg_validation(
  140. num_classes: int,
  141. min_specificity: float,
  142. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  143. ignore_index: Optional[int] = None,
  144. ) -> None:
  145. _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
  146. if not isinstance(min_specificity, float) and not (0 <= min_specificity <= 1):
  147. raise ValueError(
  148. f"Expected argument `min_specificity` to be an float in the [0,1] range, but got {min_specificity}"
  149. )
  150. def _multiclass_sensitivity_at_specificity_compute(
  151. state: Union[Tensor, tuple[Tensor, Tensor]],
  152. num_classes: int,
  153. thresholds: Optional[Tensor],
  154. min_specificity: float,
  155. ) -> tuple[Tensor, Tensor]:
  156. fpr, sensitivity, thresholds = _multiclass_roc_compute(state, num_classes, thresholds)
  157. specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr]
  158. if isinstance(state, Tensor):
  159. res = [
  160. _sensitivity_at_specificity(sp, sn, thresholds, min_specificity) # type: ignore
  161. for sp, sn in zip(sensitivity, specificity)
  162. ]
  163. else:
  164. res = [
  165. _sensitivity_at_specificity(sp, sn, t, min_specificity)
  166. for sp, sn, t in zip(sensitivity, specificity, thresholds)
  167. ]
  168. sensitivity = torch.stack([r[0] for r in res])
  169. thresholds = torch.stack([r[1] for r in res])
  170. return sensitivity, thresholds
  171. def multiclass_sensitivity_at_specificity(
  172. preds: Tensor,
  173. target: Tensor,
  174. num_classes: int,
  175. min_specificity: float,
  176. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  177. ignore_index: Optional[int] = None,
  178. validate_args: bool = True,
  179. ) -> tuple[Tensor, Tensor]:
  180. r"""Compute the highest possible sensitivity value given minimum specificity level provided for multiclass tasks.
  181. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  182. find the sensitivity for a given specificity level.
  183. Accepts the following input tensors:
  184. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  185. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  186. softmax per sample.
  187. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  188. only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  189. Additional dimension ``...`` will be flattened into the batch dimension.
  190. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  191. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  192. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  193. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  194. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  195. Args:
  196. preds: Tensor with predictions
  197. target: Tensor with true labels
  198. num_classes: Integer specifying the number of classes
  199. min_specificity: float value specifying minimum specificity threshold.
  200. thresholds:
  201. Can be one of:
  202. - ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  203. all the data. It is the most accurate but also the most memory-consuming approach.
  204. - ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  205. 0 to 1 as bins for the calculation.
  206. - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  207. - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as
  208. bins for the calculation.
  209. ignore_index:
  210. Specifies a target value that is ignored and does not contribute to the metric calculation
  211. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  212. Set to ``False`` for faster computations.
  213. Returns:
  214. (tuple): a tuple of either 2 tensors or 2 lists containing
  215. - recall: an 1d tensor of size ``(n_classes, )`` with the maximum recall for the given precision level per class
  216. - thresholds: an 1d tensor of size ``(n_classes, )`` with the corresponding threshold level per class
  217. Example:
  218. >>> from torchmetrics.functional.classification import multiclass_sensitivity_at_specificity
  219. >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  220. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  221. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  222. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  223. >>> target = torch.tensor([0, 1, 3, 2])
  224. >>> multiclass_sensitivity_at_specificity(preds, target, num_classes=5, min_specificity=0.5, thresholds=None)
  225. (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000]))
  226. >>> multiclass_sensitivity_at_specificity(preds, target, num_classes=5, min_specificity=0.5, thresholds=5)
  227. (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000]))
  228. """
  229. if validate_args:
  230. _multiclass_sensitivity_at_specificity_arg_validation(num_classes, min_specificity, thresholds, ignore_index)
  231. _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
  232. preds, target, thresholds = _multiclass_precision_recall_curve_format(
  233. preds, target, num_classes, thresholds, ignore_index
  234. )
  235. state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
  236. return _multiclass_sensitivity_at_specificity_compute(state, num_classes, thresholds, min_specificity)
  237. def _multilabel_sensitivity_at_specificity_arg_validation(
  238. num_labels: int,
  239. min_specificity: float,
  240. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  241. ignore_index: Optional[int] = None,
  242. ) -> None:
  243. _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index)
  244. if not isinstance(min_specificity, float) and not (0 <= min_specificity <= 1):
  245. raise ValueError(
  246. f"Expected argument `min_specificity` to be an float in the [0,1] range, but got {min_specificity}"
  247. )
  248. def _multilabel_sensitivity_at_specificity_compute(
  249. state: Union[Tensor, tuple[Tensor, Tensor]],
  250. num_labels: int,
  251. thresholds: Optional[Tensor],
  252. ignore_index: Optional[int],
  253. min_specificity: float,
  254. ) -> tuple[Tensor, Tensor]:
  255. fpr, sensitivity, thresholds = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index)
  256. specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr]
  257. if isinstance(state, Tensor):
  258. res = [
  259. _sensitivity_at_specificity(sp, sn, thresholds, min_specificity) # type: ignore
  260. for sp, sn in zip(sensitivity, specificity)
  261. ]
  262. else:
  263. res = [
  264. _sensitivity_at_specificity(sp, sn, t, min_specificity)
  265. for sp, sn, t in zip(sensitivity, specificity, thresholds)
  266. ]
  267. sensitivity = torch.stack([r[0] for r in res])
  268. thresholds = torch.stack([r[1] for r in res])
  269. return sensitivity, thresholds
  270. def multilabel_sensitivity_at_specificity(
  271. preds: Tensor,
  272. target: Tensor,
  273. num_labels: int,
  274. min_specificity: float,
  275. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  276. ignore_index: Optional[int] = None,
  277. validate_args: bool = True,
  278. ) -> tuple[Tensor, Tensor]:
  279. r"""Compute the highest possible sensitivity value given minimum specificity level provided for multilabel tasks.
  280. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
  281. the find the sensitivity for a given specificity level.
  282. Accepts the following input tensors:
  283. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  284. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  285. sigmoid per element.
  286. - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
  287. only contain {0,1} values (except if `ignore_index` is specified).
  288. Additional dimension ``...`` will be flattened into the batch dimension.
  289. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  290. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  291. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  292. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  293. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  294. Args:
  295. preds: Tensor with predictions
  296. target: Tensor with true labels
  297. num_labels: Integer specifying the number of labels
  298. min_specificity: float value specifying minimum specificity threshold.
  299. thresholds:
  300. Can be one of:
  301. - ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  302. all the data. It is the most accurate but also the most memory-consuming approach.
  303. - ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  304. 0 to 1 as bins for the calculation.
  305. - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  306. - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as
  307. bins for the calculation.
  308. ignore_index:
  309. Specifies a target value that is ignored and does not contribute to the metric calculation
  310. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  311. Set to ``False`` for faster computations.
  312. Returns:
  313. (tuple): a tuple of either 2 tensors or 2 lists containing
  314. - sensitivity: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision
  315. level per class
  316. - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class
  317. Example:
  318. >>> from torchmetrics.functional.classification import multilabel_sensitivity_at_specificity
  319. >>> preds = torch.tensor([[0.75, 0.05, 0.35],
  320. ... [0.45, 0.75, 0.05],
  321. ... [0.05, 0.55, 0.75],
  322. ... [0.05, 0.65, 0.05]])
  323. >>> target = torch.tensor([[1, 0, 1],
  324. ... [0, 0, 0],
  325. ... [0, 1, 1],
  326. ... [1, 1, 1]])
  327. >>> multilabel_sensitivity_at_specificity(preds, target, num_labels=3, min_specificity=0.5, thresholds=None)
  328. (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5500, 0.3500]))
  329. >>> multilabel_sensitivity_at_specificity(preds, target, num_labels=3, min_specificity=0.5, thresholds=5)
  330. (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5000, 0.2500]))
  331. """
  332. if validate_args:
  333. _multilabel_sensitivity_at_specificity_arg_validation(num_labels, min_specificity, thresholds, ignore_index)
  334. _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index)
  335. preds, target, thresholds = _multilabel_precision_recall_curve_format(
  336. preds, target, num_labels, thresholds, ignore_index
  337. )
  338. state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds)
  339. return _multilabel_sensitivity_at_specificity_compute(state, num_labels, thresholds, ignore_index, min_specificity)
  340. def sensitivity_at_specificity(
  341. preds: Tensor,
  342. target: Tensor,
  343. task: Literal["binary", "multiclass", "multilabel"],
  344. min_specificity: float,
  345. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  346. num_classes: Optional[int] = None,
  347. num_labels: Optional[int] = None,
  348. ignore_index: Optional[int] = None,
  349. validate_args: bool = True,
  350. ) -> Union[Tensor, tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]]:
  351. r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
  352. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
  353. the find the sensitivity for a given specificity level.
  354. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  355. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  356. :func:`~torchmetrics.functional.classification.binary_sensitivity_at_specificity`,
  357. :func:`~torchmetrics.functional.classification.multiclass_sensitivity_at_specificity` and
  358. :func:`~torchmetrics.functional.classification.multilabel_sensitivity_at_specificity` for the specific details of
  359. each argument influence and examples.
  360. """
  361. task = ClassificationTask.from_str(task)
  362. if task == ClassificationTask.BINARY:
  363. return binary_sensitivity_at_specificity( # type: ignore
  364. preds, target, min_specificity, thresholds, ignore_index, validate_args
  365. )
  366. if task == ClassificationTask.MULTICLASS:
  367. if not isinstance(num_classes, int):
  368. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  369. return multiclass_sensitivity_at_specificity( # type: ignore
  370. preds, target, num_classes, min_specificity, thresholds, ignore_index, validate_args
  371. )
  372. if task == ClassificationTask.MULTILABEL:
  373. if not isinstance(num_labels, int):
  374. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  375. return multilabel_sensitivity_at_specificity( # type: ignore
  376. preds, target, num_labels, min_specificity, thresholds, ignore_index, validate_args
  377. )
  378. raise ValueError(f"Not handled value: {task}")