specificity_sensitivity.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Optional, Union
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.classification.base import _ClassificationTaskWrapper
  18. from torchmetrics.classification.precision_recall_curve import (
  19. BinaryPrecisionRecallCurve,
  20. MulticlassPrecisionRecallCurve,
  21. MultilabelPrecisionRecallCurve,
  22. )
  23. from torchmetrics.functional.classification.specificity_sensitivity import (
  24. _binary_specificity_at_sensitivity_arg_validation,
  25. _binary_specificity_at_sensitivity_compute,
  26. _multiclass_specificity_at_sensitivity_arg_validation,
  27. _multiclass_specificity_at_sensitivity_compute,
  28. _multilabel_specificity_at_sensitivity_arg_validation,
  29. _multilabel_specificity_at_sensitivity_compute,
  30. )
  31. from torchmetrics.metric import Metric
  32. from torchmetrics.utilities.data import dim_zero_cat as _cat
  33. from torchmetrics.utilities.enums import ClassificationTask
  34. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  35. if not _MATPLOTLIB_AVAILABLE:
  36. __doctest_skip__ = [
  37. "BinarySpecificityAtSensitivity.plot",
  38. "MulticlassSpecificityAtSensitivity.plot",
  39. "MultilabelSpecificityAtSensitivity.plot",
  40. ]
  41. class BinarySpecificityAtSensitivity(BinaryPrecisionRecallCurve):
  42. r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.
  43. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  44. find the specificity for a given sensitivity level.
  45. Accepts the following input tensors:
  46. - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
  47. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  48. sigmoid per element.
  49. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  50. only contain {0,1} values (except if `ignore_index` is specified).
  51. Additional dimension ``...`` will be flattened into the batch dimension.
  52. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  53. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  54. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  55. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  56. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  57. Args:
  58. min_sensitivity: float value specifying minimum sensitivity threshold.
  59. thresholds:
  60. Can be one of:
  61. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  62. all the data. Most accurate but also most memory consuming approach.
  63. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  64. 0 to 1 as bins for the calculation.
  65. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  66. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  67. bins for the calculation.
  68. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  69. Set to ``False`` for faster computations.
  70. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  71. Returns:
  72. (tuple): a tuple of 2 tensors containing:
  73. - specificity: an scalar tensor with the maximum specificity for the given sensitivity level
  74. - threshold: an scalar tensor with the corresponding threshold level
  75. Example:
  76. >>> from torchmetrics.classification import BinarySpecificityAtSensitivity
  77. >>> from torch import tensor
  78. >>> preds = tensor([0, 0.5, 0.4, 0.1])
  79. >>> target = tensor([0, 1, 1, 1])
  80. >>> metric = BinarySpecificityAtSensitivity(min_sensitivity=0.5, thresholds=None)
  81. >>> metric(preds, target)
  82. (tensor(1.), tensor(0.4000))
  83. >>> metric = BinarySpecificityAtSensitivity(min_sensitivity=0.5, thresholds=5)
  84. >>> metric(preds, target)
  85. (tensor(1.), tensor(0.2500))
  86. """
  87. is_differentiable: bool = False
  88. higher_is_better: Optional[bool] = None
  89. full_state_update: bool = False
  90. plot_lower_bound: float = 0.0
  91. plot_upper_bound: float = 1.0
  92. def __init__(
  93. self,
  94. min_sensitivity: float,
  95. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  96. ignore_index: Optional[int] = None,
  97. validate_args: bool = True,
  98. **kwargs: Any,
  99. ) -> None:
  100. super().__init__(thresholds, ignore_index, validate_args=False, **kwargs)
  101. if validate_args:
  102. _binary_specificity_at_sensitivity_arg_validation(min_sensitivity, thresholds, ignore_index)
  103. self.validate_args = validate_args
  104. self.min_sensitivity = min_sensitivity
  105. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  106. """Compute metric."""
  107. state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
  108. return _binary_specificity_at_sensitivity_compute(state, self.thresholds, self.min_sensitivity)
  109. class MulticlassSpecificityAtSensitivity(MulticlassPrecisionRecallCurve):
  110. r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.
  111. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  112. find the specificity for a given sensitivity level.
  113. For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
  114. classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by
  115. this metric.
  116. Accepts the following input tensors:
  117. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  118. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  119. softmax per sample.
  120. - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
  121. only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  122. Additional dimension ``...`` will be flattened into the batch dimension.
  123. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  124. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  125. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  126. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  127. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  128. Args:
  129. num_classes: Integer specifying the number of classes
  130. min_sensitivity: float value specifying minimum sensitivity threshold.
  131. thresholds:
  132. Can be one of:
  133. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  134. all the data. Most accurate but also most memory consuming approach.
  135. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  136. 0 to 1 as bins for the calculation.
  137. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  138. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  139. bins for the calculation.
  140. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  141. Set to ``False`` for faster computations.
  142. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  143. Returns:
  144. (tuple): a tuple of either 2 tensors or 2 lists containing
  145. - specificity: an 1d tensor of size (n_classes, ) with the maximum specificity for the given
  146. sensitivity level per class
  147. - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class
  148. Example:
  149. >>> from torchmetrics.classification import MulticlassSpecificityAtSensitivity
  150. >>> from torch import tensor
  151. >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  152. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  153. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  154. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  155. >>> target = tensor([0, 1, 3, 2])
  156. >>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=None)
  157. >>> metric(preds, target)
  158. (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
  159. >>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=5)
  160. >>> metric(preds, target)
  161. (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
  162. """
  163. is_differentiable: bool = False
  164. higher_is_better: Optional[bool] = None
  165. full_state_update: bool = False
  166. plot_lower_bound: float = 0.0
  167. plot_upper_bound: float = 1.0
  168. plot_legend_name: str = "Class"
  169. def __init__(
  170. self,
  171. num_classes: int,
  172. min_sensitivity: float,
  173. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  174. ignore_index: Optional[int] = None,
  175. validate_args: bool = True,
  176. **kwargs: Any,
  177. ) -> None:
  178. super().__init__(
  179. num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  180. )
  181. if validate_args:
  182. _multiclass_specificity_at_sensitivity_arg_validation(
  183. num_classes, min_sensitivity, thresholds, ignore_index
  184. )
  185. self.validate_args = validate_args
  186. self.min_sensitivity = min_sensitivity
  187. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  188. """Compute metric."""
  189. state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
  190. return _multiclass_specificity_at_sensitivity_compute(
  191. state, self.num_classes, self.thresholds, self.min_sensitivity
  192. )
  193. class MultilabelSpecificityAtSensitivity(MultilabelPrecisionRecallCurve):
  194. r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.
  195. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  196. find the specificity for a given sensitivity level.
  197. Accepts the following input tensors:
  198. - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
  199. observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  200. sigmoid per element.
  201. - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
  202. only contain {0,1} values (except if `ignore_index` is specified).
  203. Additional dimension ``...`` will be flattened into the batch dimension.
  204. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  205. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  206. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  207. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  208. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  209. Args:
  210. num_labels: Integer specifying the number of labels
  211. min_sensitivity: float value specifying minimum sensitivity threshold.
  212. thresholds:
  213. Can be one of:
  214. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  215. all the data. Most accurate but also most memory consuming approach.
  216. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  217. 0 to 1 as bins for the calculation.
  218. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  219. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  220. bins for the calculation.
  221. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  222. Set to ``False`` for faster computations.
  223. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  224. Returns:
  225. (tuple): a tuple of either 2 tensors or 2 lists containing
  226. - specificity: an 1d tensor of size (n_classes, ) with the maximum specificity for the given
  227. sensitivity level per class
  228. - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class
  229. Example:
  230. >>> from torchmetrics.classification import MultilabelSpecificityAtSensitivity
  231. >>> from torch import tensor
  232. >>> preds = tensor([[0.75, 0.05, 0.35],
  233. ... [0.45, 0.75, 0.05],
  234. ... [0.05, 0.55, 0.75],
  235. ... [0.05, 0.65, 0.05]])
  236. >>> target = tensor([[1, 0, 1],
  237. ... [0, 0, 0],
  238. ... [0, 1, 1],
  239. ... [1, 1, 1]])
  240. >>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=None)
  241. >>> metric(preds, target)
  242. (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
  243. >>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=5)
  244. >>> metric(preds, target)
  245. (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
  246. """
  247. is_differentiable: bool = False
  248. higher_is_better: Optional[bool] = None
  249. full_state_update: bool = False
  250. plot_lower_bound: float = 0.0
  251. plot_upper_bound: float = 1.0
  252. plot_legend_name: str = "Label"
  253. def __init__(
  254. self,
  255. num_labels: int,
  256. min_sensitivity: float,
  257. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  258. ignore_index: Optional[int] = None,
  259. validate_args: bool = True,
  260. **kwargs: Any,
  261. ) -> None:
  262. super().__init__(
  263. num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  264. )
  265. if validate_args:
  266. _multilabel_specificity_at_sensitivity_arg_validation(num_labels, min_sensitivity, thresholds, ignore_index)
  267. self.validate_args = validate_args
  268. self.min_sensitivity = min_sensitivity
  269. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  270. """Compute metric."""
  271. state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
  272. return _multilabel_specificity_at_sensitivity_compute(
  273. state, self.num_labels, self.thresholds, self.ignore_index, self.min_sensitivity
  274. )
  275. class SpecificityAtSensitivity(_ClassificationTaskWrapper):
  276. r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.
  277. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the
  278. find the specificity for a given sensitivity level.
  279. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  280. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  281. :class:`~torchmetrics.classification.BinarySpecificityAtSensitivity`,
  282. :class:`~torchmetrics.classification.MulticlassSpecificityAtSensitivity` and
  283. :class:`~torchmetrics.classification.MultilabelSpecificityAtSensitivity` for the specific details of each argument
  284. influence and examples.
  285. """
  286. def __new__( # type: ignore[misc]
  287. cls: type["SpecificityAtSensitivity"],
  288. task: Literal["binary", "multiclass", "multilabel"],
  289. min_sensitivity: float,
  290. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  291. num_classes: Optional[int] = None,
  292. num_labels: Optional[int] = None,
  293. ignore_index: Optional[int] = None,
  294. validate_args: bool = True,
  295. **kwargs: Any,
  296. ) -> Metric:
  297. """Initialize task metric."""
  298. task = ClassificationTask.from_str(task)
  299. if task == ClassificationTask.BINARY:
  300. return BinarySpecificityAtSensitivity(min_sensitivity, thresholds, ignore_index, validate_args, **kwargs)
  301. if task == ClassificationTask.MULTICLASS:
  302. if not isinstance(num_classes, int):
  303. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  304. return MulticlassSpecificityAtSensitivity(
  305. num_classes, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs
  306. )
  307. if task == ClassificationTask.MULTILABEL:
  308. if not isinstance(num_labels, int):
  309. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  310. return MultilabelSpecificityAtSensitivity(
  311. num_labels, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs
  312. )
  313. raise ValueError(f"Task {task} not supported!")