precision_recall.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.functional.classification.stat_scores import (
  18. _binary_stat_scores_arg_validation,
  19. _binary_stat_scores_format,
  20. _binary_stat_scores_tensor_validation,
  21. _binary_stat_scores_update,
  22. _multiclass_stat_scores_arg_validation,
  23. _multiclass_stat_scores_format,
  24. _multiclass_stat_scores_tensor_validation,
  25. _multiclass_stat_scores_update,
  26. _multilabel_stat_scores_arg_validation,
  27. _multilabel_stat_scores_format,
  28. _multilabel_stat_scores_tensor_validation,
  29. _multilabel_stat_scores_update,
  30. )
  31. from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
  32. from torchmetrics.utilities.enums import ClassificationTask
  33. def _precision_recall_reduce(
  34. stat: Literal["precision", "recall"],
  35. tp: Tensor,
  36. fp: Tensor,
  37. tn: Tensor,
  38. fn: Tensor,
  39. average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
  40. multidim_average: Literal["global", "samplewise"] = "global",
  41. multilabel: bool = False,
  42. top_k: int = 1,
  43. zero_division: float = 0,
  44. ) -> Tensor:
  45. different_stat = fp if stat == "precision" else fn # this is what differs between the two scores
  46. if average == "binary":
  47. return _safe_divide(tp, tp + different_stat, zero_division)
  48. if average == "micro":
  49. tp = tp.sum(dim=0 if multidim_average == "global" else 1)
  50. fn = fn.sum(dim=0 if multidim_average == "global" else 1)
  51. different_stat = different_stat.sum(dim=0 if multidim_average == "global" else 1)
  52. return _safe_divide(tp, tp + different_stat, zero_division)
  53. score = _safe_divide(tp, tp + different_stat, zero_division)
  54. return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)
  55. def binary_precision(
  56. preds: Tensor,
  57. target: Tensor,
  58. threshold: float = 0.5,
  59. multidim_average: Literal["global", "samplewise"] = "global",
  60. ignore_index: Optional[int] = None,
  61. validate_args: bool = True,
  62. zero_division: float = 0,
  63. ) -> Tensor:
  64. r"""Compute `Precision`_ for binary tasks.
  65. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
  66. Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and
  67. false positives respecitively.
  68. Accepts the following input tensors:
  69. - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
  70. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  71. we convert to int tensor with thresholding using the value in ``threshold``.
  72. - ``target`` (int tensor): ``(N, ...)``
  73. Args:
  74. preds: Tensor with predictions
  75. target: Tensor with true labels
  76. threshold: Threshold for transforming probability to binary {0,1} predictions
  77. multidim_average:
  78. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  79. - ``global``: Additional dimensions are flatted along the batch dimension
  80. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  81. The statistics in this case are calculated over the additional dimensions.
  82. ignore_index:
  83. Specifies a target value that is ignored and does not contribute to the metric calculation
  84. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  85. Set to ``False`` for faster computations.
  86. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
  87. Returns:
  88. If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average``
  89. is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample.
  90. Example (preds is int tensor):
  91. >>> from torch import tensor
  92. >>> from torchmetrics.functional.classification import binary_precision
  93. >>> target = tensor([0, 1, 0, 1, 0, 1])
  94. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  95. >>> binary_precision(preds, target)
  96. tensor(0.6667)
  97. Example (preds is float tensor):
  98. >>> from torchmetrics.functional.classification import binary_precision
  99. >>> target = tensor([0, 1, 0, 1, 0, 1])
  100. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  101. >>> binary_precision(preds, target)
  102. tensor(0.6667)
  103. Example (multidim tensors):
  104. >>> from torchmetrics.functional.classification import binary_precision
  105. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  106. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  107. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  108. >>> binary_precision(preds, target, multidim_average='samplewise')
  109. tensor([0.4000, 0.0000])
  110. """
  111. if validate_args:
  112. _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
  113. _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index)
  114. preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index)
  115. tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average)
  116. return _precision_recall_reduce(
  117. "precision", tp, fp, tn, fn, average="binary", multidim_average=multidim_average, zero_division=zero_division
  118. )
  119. def multiclass_precision(
  120. preds: Tensor,
  121. target: Tensor,
  122. num_classes: int,
  123. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  124. top_k: int = 1,
  125. multidim_average: Literal["global", "samplewise"] = "global",
  126. ignore_index: Optional[int] = None,
  127. validate_args: bool = True,
  128. zero_division: float = 0,
  129. ) -> Tensor:
  130. r"""Compute `Precision`_ for multiclass tasks.
  131. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
  132. Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and
  133. false positives respecitively.
  134. Accepts the following input tensors:
  135. - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
  136. we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
  137. an int tensor.
  138. - ``target`` (int tensor): ``(N, ...)``
  139. Args:
  140. preds: Tensor with predictions
  141. target: Tensor with true labels
  142. num_classes: Integer specifying the number of classes
  143. average:
  144. Defines the reduction that is applied over labels. Should be one of the following:
  145. - ``micro``: Sum statistics over all labels
  146. - ``macro``: Calculate statistics for each label and average them
  147. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  148. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  149. top_k:
  150. Number of highest probability or logit score predictions considered to find the correct label.
  151. Only works when ``preds`` contain probabilities/logits.
  152. multidim_average:
  153. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  154. - ``global``: Additional dimensions are flatted along the batch dimension
  155. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  156. The statistics in this case are calculated over the additional dimensions.
  157. ignore_index:
  158. Specifies a target value that is ignored and does not contribute to the metric calculation
  159. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  160. Set to ``False`` for faster computations.
  161. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
  162. Returns:
  163. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  164. - If ``multidim_average`` is set to ``global``:
  165. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  166. - If ``average=None/'none'``, the shape will be ``(C,)``
  167. - If ``multidim_average`` is set to ``samplewise``:
  168. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  169. - If ``average=None/'none'``, the shape will be ``(N, C)``
  170. Example (preds is int tensor):
  171. >>> from torch import tensor
  172. >>> from torchmetrics.functional.classification import multiclass_precision
  173. >>> target = tensor([2, 1, 0, 0])
  174. >>> preds = tensor([2, 1, 0, 1])
  175. >>> multiclass_precision(preds, target, num_classes=3)
  176. tensor(0.8333)
  177. >>> multiclass_precision(preds, target, num_classes=3, average=None)
  178. tensor([1.0000, 0.5000, 1.0000])
  179. Example (preds is float tensor):
  180. >>> from torchmetrics.functional.classification import multiclass_precision
  181. >>> target = tensor([2, 1, 0, 0])
  182. >>> preds = tensor([[0.16, 0.26, 0.58],
  183. ... [0.22, 0.61, 0.17],
  184. ... [0.71, 0.09, 0.20],
  185. ... [0.05, 0.82, 0.13]])
  186. >>> multiclass_precision(preds, target, num_classes=3)
  187. tensor(0.8333)
  188. >>> multiclass_precision(preds, target, num_classes=3, average=None)
  189. tensor([1.0000, 0.5000, 1.0000])
  190. Example (multidim tensors):
  191. >>> from torchmetrics.functional.classification import multiclass_precision
  192. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  193. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  194. >>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise')
  195. tensor([0.3889, 0.2778])
  196. >>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise', average=None)
  197. tensor([[0.6667, 0.0000, 0.5000],
  198. [0.0000, 0.5000, 0.3333]])
  199. """
  200. if validate_args:
  201. _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
  202. _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
  203. preds, target = _multiclass_stat_scores_format(preds, target, top_k)
  204. tp, fp, tn, fn = _multiclass_stat_scores_update(
  205. preds, target, num_classes, top_k, average, multidim_average, ignore_index
  206. )
  207. return _precision_recall_reduce(
  208. "precision",
  209. tp,
  210. fp,
  211. tn,
  212. fn,
  213. average=average,
  214. multidim_average=multidim_average,
  215. top_k=top_k,
  216. zero_division=zero_division,
  217. )
  218. def multilabel_precision(
  219. preds: Tensor,
  220. target: Tensor,
  221. num_labels: int,
  222. threshold: float = 0.5,
  223. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  224. multidim_average: Literal["global", "samplewise"] = "global",
  225. ignore_index: Optional[int] = None,
  226. validate_args: bool = True,
  227. zero_division: float = 0,
  228. ) -> Tensor:
  229. r"""Compute `Precision`_ for multilabel tasks.
  230. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
  231. Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and
  232. false positives respecitively.
  233. Accepts the following input tensors:
  234. - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
  235. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  236. we convert to int tensor with thresholding using the value in ``threshold``.
  237. - ``target`` (int tensor): ``(N, C, ...)``
  238. Args:
  239. preds: Tensor with predictions
  240. target: Tensor with true labels
  241. num_labels: Integer specifying the number of labels
  242. threshold: Threshold for transforming probability to binary (0,1) predictions
  243. average:
  244. Defines the reduction that is applied over labels. Should be one of the following:
  245. - ``micro``: Sum statistics over all labels
  246. - ``macro``: Calculate statistics for each label and average them
  247. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  248. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  249. multidim_average:
  250. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  251. - ``global``: Additional dimensions are flatted along the batch dimension
  252. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  253. The statistics in this case are calculated over the additional dimensions.
  254. ignore_index:
  255. Specifies a target value that is ignored and does not contribute to the metric calculation
  256. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  257. Set to ``False`` for faster computations.
  258. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
  259. Returns:
  260. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  261. - If ``multidim_average`` is set to ``global``:
  262. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  263. - If ``average=None/'none'``, the shape will be ``(C,)``
  264. - If ``multidim_average`` is set to ``samplewise``:
  265. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  266. - If ``average=None/'none'``, the shape will be ``(N, C)``
  267. Example (preds is int tensor):
  268. >>> from torch import tensor
  269. >>> from torchmetrics.functional.classification import multilabel_precision
  270. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  271. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  272. >>> multilabel_precision(preds, target, num_labels=3)
  273. tensor(0.5000)
  274. >>> multilabel_precision(preds, target, num_labels=3, average=None)
  275. tensor([1.0000, 0.0000, 0.5000])
  276. Example (preds is float tensor):
  277. >>> from torchmetrics.functional.classification import multilabel_precision
  278. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  279. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  280. >>> multilabel_precision(preds, target, num_labels=3)
  281. tensor(0.5000)
  282. >>> multilabel_precision(preds, target, num_labels=3, average=None)
  283. tensor([1.0000, 0.0000, 0.5000])
  284. Example (multidim tensors):
  285. >>> from torchmetrics.functional.classification import multilabel_precision
  286. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  287. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  288. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  289. >>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise')
  290. tensor([0.3333, 0.0000])
  291. >>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise', average=None)
  292. tensor([[0.5000, 0.5000, 0.0000],
  293. [0.0000, 0.0000, 0.0000]])
  294. """
  295. if validate_args:
  296. _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
  297. _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
  298. preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
  299. tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
  300. return _precision_recall_reduce(
  301. "precision",
  302. tp,
  303. fp,
  304. tn,
  305. fn,
  306. average=average,
  307. multidim_average=multidim_average,
  308. multilabel=True,
  309. zero_division=zero_division,
  310. )
  311. def binary_recall(
  312. preds: Tensor,
  313. target: Tensor,
  314. threshold: float = 0.5,
  315. multidim_average: Literal["global", "samplewise"] = "global",
  316. ignore_index: Optional[int] = None,
  317. validate_args: bool = True,
  318. zero_division: float = 0,
  319. ) -> Tensor:
  320. r"""Compute `Recall`_ for binary tasks.
  321. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
  322. Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and
  323. false negatives respecitively.
  324. Accepts the following input tensors:
  325. - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
  326. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  327. we convert to int tensor with thresholding using the value in ``threshold``.
  328. - ``target`` (int tensor): ``(N, ...)``
  329. Args:
  330. preds: Tensor with predictions
  331. target: Tensor with true labels
  332. threshold: Threshold for transforming probability to binary {0,1} predictions
  333. multidim_average:
  334. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  335. - ``global``: Additional dimensions are flatted along the batch dimension
  336. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  337. The statistics in this case are calculated over the additional dimensions.
  338. ignore_index:
  339. Specifies a target value that is ignored and does not contribute to the metric calculation
  340. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  341. Set to ``False`` for faster computations.
  342. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
  343. Returns:
  344. If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average``
  345. is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample.
  346. Example (preds is int tensor):
  347. >>> from torch import tensor
  348. >>> from torchmetrics.functional.classification import binary_recall
  349. >>> target = tensor([0, 1, 0, 1, 0, 1])
  350. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  351. >>> binary_recall(preds, target)
  352. tensor(0.6667)
  353. Example (preds is float tensor):
  354. >>> from torchmetrics.functional.classification import binary_recall
  355. >>> target = tensor([0, 1, 0, 1, 0, 1])
  356. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  357. >>> binary_recall(preds, target)
  358. tensor(0.6667)
  359. Example (multidim tensors):
  360. >>> from torchmetrics.functional.classification import binary_recall
  361. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  362. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  363. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  364. >>> binary_recall(preds, target, multidim_average='samplewise')
  365. tensor([0.6667, 0.0000])
  366. """
  367. if validate_args:
  368. _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
  369. _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index)
  370. preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index)
  371. tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average)
  372. return _precision_recall_reduce(
  373. "recall", tp, fp, tn, fn, average="binary", multidim_average=multidim_average, zero_division=zero_division
  374. )
  375. def multiclass_recall(
  376. preds: Tensor,
  377. target: Tensor,
  378. num_classes: int,
  379. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  380. top_k: int = 1,
  381. multidim_average: Literal["global", "samplewise"] = "global",
  382. ignore_index: Optional[int] = None,
  383. validate_args: bool = True,
  384. zero_division: float = 0,
  385. ) -> Tensor:
  386. r"""Compute `Recall`_ for multiclass tasks.
  387. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
  388. Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and
  389. false negatives respecitively.
  390. Accepts the following input tensors:
  391. - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
  392. we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
  393. an int tensor.
  394. - ``target`` (int tensor): ``(N, ...)``
  395. Args:
  396. preds: Tensor with predictions
  397. target: Tensor with true labels
  398. num_classes: Integer specifying the number of classes
  399. average:
  400. Defines the reduction that is applied over labels. Should be one of the following:
  401. - ``micro``: Sum statistics over all labels
  402. - ``macro``: Calculate statistics for each label and average them
  403. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  404. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  405. top_k:
  406. Number of highest probability or logit score predictions considered to find the correct label.
  407. Only works when ``preds`` contain probabilities/logits.
  408. multidim_average:
  409. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  410. - ``global``: Additional dimensions are flatted along the batch dimension
  411. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  412. The statistics in this case are calculated over the additional dimensions.
  413. ignore_index:
  414. Specifies a target value that is ignored and does not contribute to the metric calculation
  415. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  416. Set to ``False`` for faster computations.
  417. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
  418. Returns:
  419. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  420. - If ``multidim_average`` is set to ``global``:
  421. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  422. - If ``average=None/'none'``, the shape will be ``(C,)``
  423. - If ``multidim_average`` is set to ``samplewise``:
  424. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  425. - If ``average=None/'none'``, the shape will be ``(N, C)``
  426. Example (preds is int tensor):
  427. >>> from torch import tensor
  428. >>> from torchmetrics.functional.classification import multiclass_recall
  429. >>> target = tensor([2, 1, 0, 0])
  430. >>> preds = tensor([2, 1, 0, 1])
  431. >>> multiclass_recall(preds, target, num_classes=3)
  432. tensor(0.8333)
  433. >>> multiclass_recall(preds, target, num_classes=3, average=None)
  434. tensor([0.5000, 1.0000, 1.0000])
  435. Example (preds is float tensor):
  436. >>> from torchmetrics.functional.classification import multiclass_recall
  437. >>> target = tensor([2, 1, 0, 0])
  438. >>> preds = tensor([[0.16, 0.26, 0.58],
  439. ... [0.22, 0.61, 0.17],
  440. ... [0.71, 0.09, 0.20],
  441. ... [0.05, 0.82, 0.13]])
  442. >>> multiclass_recall(preds, target, num_classes=3)
  443. tensor(0.8333)
  444. >>> multiclass_recall(preds, target, num_classes=3, average=None)
  445. tensor([0.5000, 1.0000, 1.0000])
  446. Example (multidim tensors):
  447. >>> from torchmetrics.functional.classification import multiclass_recall
  448. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  449. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  450. >>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise')
  451. tensor([0.5000, 0.2778])
  452. >>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise', average=None)
  453. tensor([[1.0000, 0.0000, 0.5000],
  454. [0.0000, 0.3333, 0.5000]])
  455. """
  456. if validate_args:
  457. _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
  458. _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
  459. preds, target = _multiclass_stat_scores_format(preds, target, top_k)
  460. tp, fp, tn, fn = _multiclass_stat_scores_update(
  461. preds, target, num_classes, top_k, average, multidim_average, ignore_index
  462. )
  463. return _precision_recall_reduce(
  464. "recall",
  465. tp,
  466. fp,
  467. tn,
  468. fn,
  469. average=average,
  470. multidim_average=multidim_average,
  471. top_k=top_k,
  472. zero_division=zero_division,
  473. )
  474. def multilabel_recall(
  475. preds: Tensor,
  476. target: Tensor,
  477. num_labels: int,
  478. threshold: float = 0.5,
  479. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  480. multidim_average: Literal["global", "samplewise"] = "global",
  481. ignore_index: Optional[int] = None,
  482. validate_args: bool = True,
  483. zero_division: float = 0,
  484. ) -> Tensor:
  485. r"""Compute `Recall`_ for multilabel tasks.
  486. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
  487. Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and
  488. false negatives respecitively.
  489. Accepts the following input tensors:
  490. - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
  491. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  492. we convert to int tensor with thresholding using the value in ``threshold``.
  493. - ``target`` (int tensor): ``(N, C, ...)``
  494. Args:
  495. preds: Tensor with predictions
  496. target: Tensor with true labels
  497. num_labels: Integer specifying the number of labels
  498. threshold: Threshold for transforming probability to binary (0,1) predictions
  499. average:
  500. Defines the reduction that is applied over labels. Should be one of the following:
  501. - ``micro``: Sum statistics over all labels
  502. - ``macro``: Calculate statistics for each label and average them
  503. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  504. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  505. multidim_average:
  506. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  507. - ``global``: Additional dimensions are flatted along the batch dimension
  508. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  509. The statistics in this case are calculated over the additional dimensions.
  510. ignore_index:
  511. Specifies a target value that is ignored and does not contribute to the metric calculation
  512. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  513. Set to ``False`` for faster computations.
  514. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
  515. Returns:
  516. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  517. - If ``multidim_average`` is set to ``global``:
  518. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  519. - If ``average=None/'none'``, the shape will be ``(C,)``
  520. - If ``multidim_average`` is set to ``samplewise``:
  521. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  522. - If ``average=None/'none'``, the shape will be ``(N, C)``
  523. Example (preds is int tensor):
  524. >>> from torch import tensor
  525. >>> from torchmetrics.functional.classification import multilabel_recall
  526. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  527. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  528. >>> multilabel_recall(preds, target, num_labels=3)
  529. tensor(0.6667)
  530. >>> multilabel_recall(preds, target, num_labels=3, average=None)
  531. tensor([1., 0., 1.])
  532. Example (preds is float tensor):
  533. >>> from torchmetrics.functional.classification import multilabel_recall
  534. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  535. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  536. >>> multilabel_recall(preds, target, num_labels=3)
  537. tensor(0.6667)
  538. >>> multilabel_recall(preds, target, num_labels=3, average=None)
  539. tensor([1., 0., 1.])
  540. Example (multidim tensors):
  541. >>> from torchmetrics.functional.classification import multilabel_recall
  542. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  543. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  544. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  545. >>> multilabel_recall(preds, target, num_labels=3, multidim_average='samplewise')
  546. tensor([0.6667, 0.0000])
  547. >>> multilabel_recall(preds, target, num_labels=3, multidim_average='samplewise', average=None)
  548. tensor([[1., 1., 0.],
  549. [0., 0., 0.]])
  550. """
  551. if validate_args:
  552. _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
  553. _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
  554. preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
  555. tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
  556. return _precision_recall_reduce(
  557. "recall",
  558. tp,
  559. fp,
  560. tn,
  561. fn,
  562. average=average,
  563. multidim_average=multidim_average,
  564. multilabel=True,
  565. zero_division=zero_division,
  566. )
  567. def precision(
  568. preds: Tensor,
  569. target: Tensor,
  570. task: Literal["binary", "multiclass", "multilabel"],
  571. threshold: float = 0.5,
  572. num_classes: Optional[int] = None,
  573. num_labels: Optional[int] = None,
  574. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  575. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  576. top_k: Optional[int] = 1,
  577. ignore_index: Optional[int] = None,
  578. validate_args: bool = True,
  579. zero_division: float = 0,
  580. ) -> Tensor:
  581. r"""Compute `Precision`_.
  582. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
  583. Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and
  584. false positives respecitively.
  585. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  586. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  587. :func:`~torchmetrics.functional.classification.binary_precision`,
  588. :func:`~torchmetrics.functional.classification.multiclass_precision` and
  589. :func:`~torchmetrics.functional.classification.multilabel_precision` for the specific details of
  590. each argument influence and examples.
  591. Legacy Example:
  592. >>> from torch import tensor
  593. >>> preds = tensor([2, 0, 2, 1])
  594. >>> target = tensor([1, 1, 2, 0])
  595. >>> precision(preds, target, task="multiclass", average='macro', num_classes=3)
  596. tensor(0.1667)
  597. >>> precision(preds, target, task="multiclass", average='micro', num_classes=3)
  598. tensor(0.2500)
  599. """
  600. assert multidim_average is not None # noqa: S101 # needed for mypy
  601. if task == ClassificationTask.BINARY:
  602. return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args, zero_division)
  603. if task == ClassificationTask.MULTICLASS:
  604. if not isinstance(num_classes, int):
  605. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  606. if not isinstance(top_k, int):
  607. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  608. return multiclass_precision(
  609. preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args, zero_division
  610. )
  611. if task == ClassificationTask.MULTILABEL:
  612. if not isinstance(num_labels, int):
  613. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  614. return multilabel_precision(
  615. preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args, zero_division
  616. )
  617. raise ValueError(
  618. f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
  619. )
  620. def recall(
  621. preds: Tensor,
  622. target: Tensor,
  623. task: Literal["binary", "multiclass", "multilabel"],
  624. threshold: float = 0.5,
  625. num_classes: Optional[int] = None,
  626. num_labels: Optional[int] = None,
  627. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  628. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  629. top_k: Optional[int] = 1,
  630. ignore_index: Optional[int] = None,
  631. validate_args: bool = True,
  632. zero_division: float = 0,
  633. ) -> Tensor:
  634. r"""Compute `Recall`_.
  635. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
  636. Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and
  637. false negatives respecitively.
  638. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  639. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  640. :func:`~torchmetrics.functional.classification.binary_recall`,
  641. :func:`~torchmetrics.functional.classification.multiclass_recall` and
  642. :func:`~torchmetrics.functional.classification.multilabel_recall` for the specific details of
  643. each argument influence and examples.
  644. Legacy Example:
  645. >>> from torch import tensor
  646. >>> preds = tensor([2, 0, 2, 1])
  647. >>> target = tensor([1, 1, 2, 0])
  648. >>> recall(preds, target, task="multiclass", average='macro', num_classes=3)
  649. tensor(0.3333)
  650. >>> recall(preds, target, task="multiclass", average='micro', num_classes=3)
  651. tensor(0.2500)
  652. """
  653. task = ClassificationTask.from_str(task)
  654. assert multidim_average is not None # noqa: S101 # needed for mypy
  655. if task == ClassificationTask.BINARY:
  656. return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args, zero_division)
  657. if task == ClassificationTask.MULTICLASS:
  658. if not isinstance(num_classes, int):
  659. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  660. if not isinstance(top_k, int):
  661. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  662. return multiclass_recall(
  663. preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args, zero_division
  664. )
  665. if task == ClassificationTask.MULTILABEL:
  666. if not isinstance(num_labels, int):
  667. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  668. return multilabel_recall(
  669. preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args, zero_division
  670. )
  671. raise ValueError(f"Not handled value: {task}")