specificity.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.functional.classification.stat_scores import (
  18. _binary_stat_scores_arg_validation,
  19. _binary_stat_scores_format,
  20. _binary_stat_scores_tensor_validation,
  21. _binary_stat_scores_update,
  22. _multiclass_stat_scores_arg_validation,
  23. _multiclass_stat_scores_format,
  24. _multiclass_stat_scores_tensor_validation,
  25. _multiclass_stat_scores_update,
  26. _multilabel_stat_scores_arg_validation,
  27. _multilabel_stat_scores_format,
  28. _multilabel_stat_scores_tensor_validation,
  29. _multilabel_stat_scores_update,
  30. )
  31. from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
  32. from torchmetrics.utilities.enums import ClassificationTask
  33. def _specificity_reduce(
  34. tp: Tensor,
  35. fp: Tensor,
  36. tn: Tensor,
  37. fn: Tensor,
  38. average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
  39. multidim_average: Literal["global", "samplewise"] = "global",
  40. multilabel: bool = False,
  41. ) -> Tensor:
  42. if average == "binary":
  43. return _safe_divide(tn, tn + fp)
  44. if average == "micro":
  45. tn = tn.sum(dim=0 if multidim_average == "global" else 1)
  46. fp = fp.sum(dim=0 if multidim_average == "global" else 1)
  47. return _safe_divide(tn, tn + fp)
  48. specificity_score = _safe_divide(tn, tn + fp)
  49. return _adjust_weights_safe_divide(specificity_score, average, multilabel, tp, fp, fn)
  50. def binary_specificity(
  51. preds: Tensor,
  52. target: Tensor,
  53. threshold: float = 0.5,
  54. multidim_average: Literal["global", "samplewise"] = "global",
  55. ignore_index: Optional[int] = None,
  56. validate_args: bool = True,
  57. ) -> Tensor:
  58. r"""Compute `Specificity`_ for binary tasks.
  59. .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}
  60. Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and
  61. false positives respecitively.
  62. Accepts the following input tensors:
  63. - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
  64. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  65. we convert to int tensor with thresholding using the value in ``threshold``.
  66. - ``target`` (int tensor): ``(N, ...)``
  67. Args:
  68. preds: Tensor with predictions
  69. target: Tensor with true labels
  70. threshold: Threshold for transforming probability to binary {0,1} predictions
  71. multidim_average:
  72. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  73. - ``global``: Additional dimensions are flatted along the batch dimension
  74. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  75. The statistics in this case are calculated over the additional dimensions.
  76. ignore_index:
  77. Specifies a target value that is ignored and does not contribute to the metric calculation
  78. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  79. Set to ``False`` for faster computations.
  80. Returns:
  81. If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average``
  82. is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample.
  83. Example (preds is int tensor):
  84. >>> from torch import tensor
  85. >>> from torchmetrics.functional.classification import binary_specificity
  86. >>> target = tensor([0, 1, 0, 1, 0, 1])
  87. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  88. >>> binary_specificity(preds, target)
  89. tensor(0.6667)
  90. Example (preds is float tensor):
  91. >>> from torchmetrics.functional.classification import binary_specificity
  92. >>> target = tensor([0, 1, 0, 1, 0, 1])
  93. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  94. >>> binary_specificity(preds, target)
  95. tensor(0.6667)
  96. Example (multidim tensors):
  97. >>> from torchmetrics.functional.classification import binary_specificity
  98. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  99. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  100. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  101. >>> binary_specificity(preds, target, multidim_average='samplewise')
  102. tensor([0.0000, 0.3333])
  103. """
  104. if validate_args:
  105. _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
  106. _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index)
  107. preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index)
  108. tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average)
  109. return _specificity_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average)
  110. def multiclass_specificity(
  111. preds: Tensor,
  112. target: Tensor,
  113. num_classes: int,
  114. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  115. top_k: int = 1,
  116. multidim_average: Literal["global", "samplewise"] = "global",
  117. ignore_index: Optional[int] = None,
  118. validate_args: bool = True,
  119. ) -> Tensor:
  120. r"""Compute `Specificity`_ for multiclass tasks.
  121. .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}
  122. Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and
  123. false positives respecitively.
  124. Accepts the following input tensors:
  125. - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
  126. we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
  127. an int tensor.
  128. - ``target`` (int tensor): ``(N, ...)``
  129. Args:
  130. preds: Tensor with predictions
  131. target: Tensor with true labels
  132. num_classes: Integer specifying the number of classes
  133. average:
  134. Defines the reduction that is applied over labels. Should be one of the following:
  135. - ``micro``: Sum statistics over all labels
  136. - ``macro``: Calculate statistics for each label and average them
  137. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  138. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  139. top_k:
  140. Number of highest probability or logit score predictions considered to find the correct label.
  141. Only works when ``preds`` contain probabilities/logits.
  142. multidim_average:
  143. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  144. - ``global``: Additional dimensions are flatted along the batch dimension
  145. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  146. The statistics in this case are calculated over the additional dimensions.
  147. ignore_index:
  148. Specifies a target value that is ignored and does not contribute to the metric calculation
  149. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  150. Set to ``False`` for faster computations.
  151. Returns:
  152. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  153. - If ``multidim_average`` is set to ``global``:
  154. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  155. - If ``average=None/'none'``, the shape will be ``(C,)``
  156. - If ``multidim_average`` is set to ``samplewise``:
  157. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  158. - If ``average=None/'none'``, the shape will be ``(N, C)``
  159. Example (preds is int tensor):
  160. >>> from torch import tensor
  161. >>> from torchmetrics.functional.classification import multiclass_specificity
  162. >>> target = tensor([2, 1, 0, 0])
  163. >>> preds = tensor([2, 1, 0, 1])
  164. >>> multiclass_specificity(preds, target, num_classes=3)
  165. tensor(0.8889)
  166. >>> multiclass_specificity(preds, target, num_classes=3, average=None)
  167. tensor([1.0000, 0.6667, 1.0000])
  168. Example (preds is float tensor):
  169. >>> from torchmetrics.functional.classification import multiclass_specificity
  170. >>> target = tensor([2, 1, 0, 0])
  171. >>> preds = tensor([[0.16, 0.26, 0.58],
  172. ... [0.22, 0.61, 0.17],
  173. ... [0.71, 0.09, 0.20],
  174. ... [0.05, 0.82, 0.13]])
  175. >>> multiclass_specificity(preds, target, num_classes=3)
  176. tensor(0.8889)
  177. >>> multiclass_specificity(preds, target, num_classes=3, average=None)
  178. tensor([1.0000, 0.6667, 1.0000])
  179. Example (multidim tensors):
  180. >>> from torchmetrics.functional.classification import multiclass_specificity
  181. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  182. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  183. >>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise')
  184. tensor([0.7500, 0.6556])
  185. >>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise', average=None)
  186. tensor([[0.7500, 0.7500, 0.7500],
  187. [0.8000, 0.6667, 0.5000]])
  188. """
  189. if validate_args:
  190. _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
  191. _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
  192. preds, target = _multiclass_stat_scores_format(preds, target, top_k)
  193. tp, fp, tn, fn = _multiclass_stat_scores_update(
  194. preds, target, num_classes, top_k, average, multidim_average, ignore_index
  195. )
  196. return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average)
  197. def multilabel_specificity(
  198. preds: Tensor,
  199. target: Tensor,
  200. num_labels: int,
  201. threshold: float = 0.5,
  202. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  203. multidim_average: Literal["global", "samplewise"] = "global",
  204. ignore_index: Optional[int] = None,
  205. validate_args: bool = True,
  206. ) -> Tensor:
  207. r"""Compute `Specificity`_ for multilabel tasks.
  208. .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}
  209. Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and
  210. false positives respecitively.
  211. Accepts the following input tensors:
  212. - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
  213. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  214. we convert to int tensor with thresholding using the value in ``threshold``.
  215. - ``target`` (int tensor): ``(N, C, ...)``
  216. Args:
  217. preds: Tensor with predictions
  218. target: Tensor with true labels
  219. num_labels: Integer specifying the number of labels
  220. threshold: Threshold for transforming probability to binary (0,1) predictions
  221. average:
  222. Defines the reduction that is applied over labels. Should be one of the following:
  223. - ``micro``: Sum statistics over all labels
  224. - ``macro``: Calculate statistics for each label and average them
  225. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  226. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  227. multidim_average:
  228. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  229. - ``global``: Additional dimensions are flatted along the batch dimension
  230. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  231. The statistics in this case are calculated over the additional dimensions.
  232. ignore_index:
  233. Specifies a target value that is ignored and does not contribute to the metric calculation
  234. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  235. Set to ``False`` for faster computations.
  236. Returns:
  237. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  238. - If ``multidim_average`` is set to ``global``:
  239. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  240. - If ``average=None/'none'``, the shape will be ``(C,)``
  241. - If ``multidim_average`` is set to ``samplewise``:
  242. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  243. - If ``average=None/'none'``, the shape will be ``(N, C)``
  244. Example (preds is int tensor):
  245. >>> from torch import tensor
  246. >>> from torchmetrics.functional.classification import multilabel_specificity
  247. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  248. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  249. >>> multilabel_specificity(preds, target, num_labels=3)
  250. tensor(0.6667)
  251. >>> multilabel_specificity(preds, target, num_labels=3, average=None)
  252. tensor([1., 1., 0.])
  253. Example (preds is float tensor):
  254. >>> from torchmetrics.functional.classification import multilabel_specificity
  255. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  256. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  257. >>> multilabel_specificity(preds, target, num_labels=3)
  258. tensor(0.6667)
  259. >>> multilabel_specificity(preds, target, num_labels=3, average=None)
  260. tensor([1., 1., 0.])
  261. Example (multidim tensors):
  262. >>> from torchmetrics.functional.classification import multilabel_specificity
  263. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  264. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  265. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  266. >>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise')
  267. tensor([0.0000, 0.3333])
  268. >>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise', average=None)
  269. tensor([[0., 0., 0.],
  270. [0., 0., 1.]])
  271. """
  272. if validate_args:
  273. _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
  274. _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
  275. preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
  276. tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
  277. return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True)
  278. def specificity(
  279. preds: Tensor,
  280. target: Tensor,
  281. task: Literal["binary", "multiclass", "multilabel"],
  282. threshold: float = 0.5,
  283. num_classes: Optional[int] = None,
  284. num_labels: Optional[int] = None,
  285. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  286. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  287. top_k: Optional[int] = 1,
  288. ignore_index: Optional[int] = None,
  289. validate_args: bool = True,
  290. ) -> Tensor:
  291. r"""Compute `Specificity`_.
  292. .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}
  293. Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and
  294. false positives respecitively.
  295. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  296. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  297. :func:`~torchmetrics.functional.classification.binary_specificity`,
  298. :func:`~torchmetrics.functional.classification.multiclass_specificity` and
  299. :func:`~torchmetrics.functional.classification.multilabel_specificity` for the specific
  300. details of each argument influence and examples.
  301. LegacyExample:
  302. >>> from torch import tensor
  303. >>> preds = tensor([2, 0, 2, 1])
  304. >>> target = tensor([1, 1, 2, 0])
  305. >>> specificity(preds, target, task="multiclass", average='macro', num_classes=3)
  306. tensor(0.6111)
  307. >>> specificity(preds, target, task="multiclass", average='micro', num_classes=3)
  308. tensor(0.6250)
  309. """
  310. task = ClassificationTask.from_str(task)
  311. assert multidim_average is not None # noqa: S101 # needed for mypy
  312. if task == ClassificationTask.BINARY:
  313. return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args)
  314. if task == ClassificationTask.MULTICLASS:
  315. if not isinstance(num_classes, int):
  316. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  317. if not isinstance(top_k, int):
  318. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  319. return multiclass_specificity(
  320. preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args
  321. )
  322. if task == ClassificationTask.MULTILABEL:
  323. if not isinstance(num_labels, int):
  324. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  325. return multilabel_specificity(
  326. preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args
  327. )
  328. raise ValueError(f"Not handled value: {task}")