sensitivity_specificity.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Optional, Union
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.classification.base import _ClassificationTaskWrapper
  18. from torchmetrics.classification.precision_recall_curve import (
  19. BinaryPrecisionRecallCurve,
  20. MulticlassPrecisionRecallCurve,
  21. MultilabelPrecisionRecallCurve,
  22. )
  23. from torchmetrics.functional.classification.sensitivity_specificity import (
  24. _binary_sensitivity_at_specificity_arg_validation,
  25. _binary_sensitivity_at_specificity_compute,
  26. _multiclass_sensitivity_at_specificity_arg_validation,
  27. _multiclass_sensitivity_at_specificity_compute,
  28. _multilabel_sensitivity_at_specificity_arg_validation,
  29. _multilabel_sensitivity_at_specificity_compute,
  30. )
  31. from torchmetrics.metric import Metric
  32. from torchmetrics.utilities.data import dim_zero_cat as _cat
  33. from torchmetrics.utilities.enums import ClassificationTask
  34. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
# When matplotlib is not installed, the ``plot`` doctests of the three task-specific classes are
# listed in ``__doctest_skip__`` -- presumably so the doctest collector skips them (TODO confirm
# against the project's doctest configuration).
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = [
        "BinarySensitivityAtSpecificity.plot",
        "MulticlassSensitivityAtSpecificity.plot",
        "MultilabelSensitivityAtSpecificity.plot",
    ]
  41. class BinarySensitivityAtSpecificity(BinaryPrecisionRecallCurve):
  42. r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
  43. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  44. find the sensitivity for a given specificity level.
  45. Accepts the following input tensors:
  46. - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
  47. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  48. sigmoid per element.
  49. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  50. only contain {0,1} values (except if `ignore_index` is specified).
  51. Additional dimension ``...`` will be flattened into the batch dimension.
  52. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  53. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  54. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  55. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  56. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  57. Args:
  58. min_specificity: float value specifying minimum specificity threshold.
  59. thresholds:
  60. Can be one of:
  61. - ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  62. all the data. It is the most accurate but also the most memory-consuming approach.
  63. - ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  64. 0 to 1 as bins for the calculation.
  65. - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  66. - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as
  67. bins for the calculation.
  68. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  69. Set to ``False`` for faster computations.
  70. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  71. Returns:
  72. (tuple): a tuple of 2 tensors containing:
  73. - sensitivity: an scalar tensor with the maximum sensitivity for the given specificity level
  74. - threshold: an scalar tensor with the corresponding threshold level
  75. Example:
  76. >>> from torchmetrics.classification import BinarySensitivityAtSpecificity
  77. >>> from torch import tensor
  78. >>> preds = tensor([0, 0.5, 0.4, 0.1])
  79. >>> target = tensor([0, 1, 1, 1])
  80. >>> metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=None)
  81. >>> metric(preds, target)
  82. (tensor(1.), tensor(0.1000))
  83. >>> metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=5)
  84. >>> metric(preds, target)
  85. (tensor(0.6667), tensor(0.2500))
  86. """
  87. is_differentiable: bool = False
  88. higher_is_better: Optional[bool] = None
  89. full_state_update: bool = False
  90. plot_lower_bound: float = 0.0
  91. plot_upper_bound: float = 1.0
  92. def __init__(
  93. self,
  94. min_specificity: float,
  95. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  96. ignore_index: Optional[int] = None,
  97. validate_args: bool = True,
  98. **kwargs: Any,
  99. ) -> None:
  100. super().__init__(thresholds, ignore_index, validate_args=False, **kwargs)
  101. if validate_args:
  102. _binary_sensitivity_at_specificity_arg_validation(min_specificity, thresholds, ignore_index)
  103. self.validate_args = validate_args
  104. self.min_specificity = min_specificity
  105. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  106. """Compute metric."""
  107. state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
  108. return _binary_sensitivity_at_specificity_compute(state, self.thresholds, self.min_specificity)
  109. class MulticlassSensitivityAtSpecificity(MulticlassPrecisionRecallCurve):
  110. r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
  111. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  112. find the sensitivity for a given specificity level.
  113. For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
  114. classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by
  115. this metric.
  116. Accepts the following input tensors:
  117. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  118. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  119. softmax per sample.
  120. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  121. only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  122. Additional dimension ``...`` will be flattened into the batch dimension.
  123. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  124. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  125. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  126. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  127. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  128. Args:
  129. num_classes: Integer specifying the number of classes
  130. min_specificity: float value specifying minimum specificity threshold.
  131. thresholds:
  132. Can be one of:
  133. - ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  134. all the data. It is the most accurate but also the most memory-consuming approach.
  135. - ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  136. 0 to 1 as bins for the calculation.
  137. - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  138. - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as
  139. bins for the calculation.
  140. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  141. Set to ``False`` for faster computations.
  142. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  143. Returns:
  144. (tuple): a tuple of either 2 tensors or 2 lists containing
  145. - sensitivity: an 1d tensor of size (n_classes, ) with the maximum sensitivity for the given
  146. specificity level per class
  147. - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class
  148. Example:
  149. >>> from torchmetrics.classification import MulticlassSensitivityAtSpecificity
  150. >>> from torch import tensor
  151. >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  152. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  153. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  154. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  155. >>> target = tensor([0, 1, 3, 2])
  156. >>> metric = MulticlassSensitivityAtSpecificity(num_classes=5, min_specificity=0.5, thresholds=None)
  157. >>> metric(preds, target)
  158. (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000]))
  159. >>> metric = MulticlassSensitivityAtSpecificity(num_classes=5, min_specificity=0.5, thresholds=5)
  160. >>> metric(preds, target)
  161. (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000]))
  162. """
  163. is_differentiable: bool = False
  164. higher_is_better: Optional[bool] = None
  165. full_state_update: bool = False
  166. plot_lower_bound: float = 0.0
  167. plot_upper_bound: float = 1.0
  168. plot_legend_name: str = "Class"
  169. def __init__(
  170. self,
  171. num_classes: int,
  172. min_specificity: float,
  173. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  174. ignore_index: Optional[int] = None,
  175. validate_args: bool = True,
  176. **kwargs: Any,
  177. ) -> None:
  178. super().__init__(
  179. num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  180. )
  181. if validate_args:
  182. _multiclass_sensitivity_at_specificity_arg_validation(
  183. num_classes, min_specificity, thresholds, ignore_index
  184. )
  185. self.validate_args = validate_args
  186. self.min_specificity = min_specificity
  187. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  188. """Compute metric."""
  189. state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
  190. return _multiclass_sensitivity_at_specificity_compute(
  191. state, self.num_classes, self.thresholds, self.min_specificity
  192. )
  193. class MultilabelSensitivityAtSpecificity(MultilabelPrecisionRecallCurve):
  194. r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
  195. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  196. find the sensitivity for a given specificity level.
  197. Accepts the following input tensors:
  198. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  199. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  200. sigmoid per element.
  201. - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
  202. only contain {0,1} values (except if `ignore_index` is specified).
  203. Additional dimension ``...`` will be flattened into the batch dimension.
  204. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  205. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  206. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  207. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  208. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  209. Args:
  210. num_labels: Integer specifying the number of labels
  211. min_specificity: float value specifying minimum specificity threshold.
  212. thresholds:
  213. Can be one of:
  214. - ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  215. all the data. It is the most accurate but also the most memory-consuming approach.
  216. - ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  217. 0 to 1 as bins for the calculation.
  218. - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  219. - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as
  220. bins for the calculation.
  221. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  222. Set to ``False`` for faster computations.
  223. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  224. Returns:
  225. (tuple): a tuple of either 2 tensors or 2 lists containing
  226. - sensitivity: an 1d tensor of size ``(n_classes, )`` with the maximum sensitivity for the given
  227. specificity level per class
  228. - thresholds: an 1d tensor of size ``(n_classes, )`` with the corresponding threshold level per class
  229. Example:
  230. >>> from torchmetrics.classification import MultilabelSensitivityAtSpecificity
  231. >>> from torch import tensor
  232. >>> preds = tensor([[0.75, 0.05, 0.35],
  233. ... [0.45, 0.75, 0.05],
  234. ... [0.05, 0.55, 0.75],
  235. ... [0.05, 0.65, 0.05]])
  236. >>> target = tensor([[1, 0, 1],
  237. ... [0, 0, 0],
  238. ... [0, 1, 1],
  239. ... [1, 1, 1]])
  240. >>> metric = MultilabelSensitivityAtSpecificity(num_labels=3, min_specificity=0.5, thresholds=None)
  241. >>> metric(preds, target)
  242. (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5500, 0.3500]))
  243. >>> metric = MultilabelSensitivityAtSpecificity(num_labels=3, min_specificity=0.5, thresholds=5)
  244. >>> metric(preds, target)
  245. (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5000, 0.2500]))
  246. """
  247. is_differentiable: bool = False
  248. higher_is_better: Optional[bool] = None
  249. full_state_update: bool = False
  250. plot_lower_bound: float = 0.0
  251. plot_upper_bound: float = 1.0
  252. plot_legend_name: str = "Label"
  253. def __init__(
  254. self,
  255. num_labels: int,
  256. min_specificity: float,
  257. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  258. ignore_index: Optional[int] = None,
  259. validate_args: bool = True,
  260. **kwargs: Any,
  261. ) -> None:
  262. super().__init__(
  263. num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  264. )
  265. if validate_args:
  266. _multilabel_sensitivity_at_specificity_arg_validation(num_labels, min_specificity, thresholds, ignore_index)
  267. self.validate_args = validate_args
  268. self.min_specificity = min_specificity
  269. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  270. """Compute metric."""
  271. state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
  272. return _multilabel_sensitivity_at_specificity_compute(
  273. state, self.num_labels, self.thresholds, self.ignore_index, self.min_specificity
  274. )
  275. class SensitivityAtSpecificity(_ClassificationTaskWrapper):
  276. r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
  277. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  278. find the sensitivity for a given specificity level.
  279. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  280. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  281. :class:`~torchmetrics.classification.BinarySensitivityAtSpecificity`,
  282. :class:`~torchmetrics.classification.MulticlassSensitivityAtSpecificity` and
  283. :class:`~torchmetrics.classification.MultilabelSensitivityAtSpecificity` for the specific details of each argument
  284. influence and examples.
  285. """
  286. def __new__( # type: ignore[misc]
  287. cls: type["SensitivityAtSpecificity"],
  288. task: Literal["binary", "multiclass", "multilabel"],
  289. min_specificity: float,
  290. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  291. num_classes: Optional[int] = None,
  292. num_labels: Optional[int] = None,
  293. ignore_index: Optional[int] = None,
  294. validate_args: bool = True,
  295. **kwargs: Any,
  296. ) -> Metric:
  297. """Initialize task metric."""
  298. task = ClassificationTask.from_str(task)
  299. if task == ClassificationTask.BINARY:
  300. return BinarySensitivityAtSpecificity(min_specificity, thresholds, ignore_index, validate_args, **kwargs)
  301. if task == ClassificationTask.MULTICLASS:
  302. if not isinstance(num_classes, int):
  303. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  304. return MulticlassSensitivityAtSpecificity(
  305. num_classes, min_specificity, thresholds, ignore_index, validate_args, **kwargs
  306. )
  307. if task == ClassificationTask.MULTILABEL:
  308. if not isinstance(num_labels, int):
  309. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  310. return MultilabelSensitivityAtSpecificity(
  311. num_labels, min_specificity, thresholds, ignore_index, validate_args, **kwargs
  312. )
  313. raise ValueError(f"Task {task} not supported!")