precision_recall.py 47 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.classification.base import _ClassificationTaskWrapper
  19. from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
  20. from torchmetrics.functional.classification.precision_recall import (
  21. _precision_recall_reduce,
  22. )
  23. from torchmetrics.metric import Metric
  24. from torchmetrics.utilities.enums import ClassificationTask
  25. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  26. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  27. if not _MATPLOTLIB_AVAILABLE:
  28. __doctest_skip__ = [
  29. "BinaryPrecision.plot",
  30. "MulticlassPrecision.plot",
  31. "MultilabelPrecision.plot",
  32. "BinaryRecall.plot",
  33. "MulticlassRecall.plot",
  34. "MultilabelRecall.plot",
  35. ]
  36. class BinaryPrecision(BinaryStatScores):
  37. r"""Compute `Precision`_ for binary tasks.
  38. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
  39. Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
  40. respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
  41. encountered a score of `zero_division` (0 or 1, default is 0) is returned.
  42. As input to ``forward`` and ``update`` the metric accepts the following input:
  43. - ``preds`` (:class:`~torch.Tensor`): A int or float tensor of shape ``(N, ...)``. If preds is a floating point
  44. tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per
  45. element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  46. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
  47. As output to ``forward`` and ``compute`` the metric returns the following output:
  48. - ``bp`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, the metric returns a scalar
  49. value. If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a
  50. scalar value per sample.
  51. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  52. which the reduction will then be applied over instead of the sample dimension ``N``.
  53. Args:
  54. threshold: Threshold for transforming probability to binary {0,1} predictions
  55. multidim_average:
  56. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  57. - ``global``: Additional dimensions are flatted along the batch dimension
  58. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  59. The statistics in this case are calculated over the additional dimensions.
  60. ignore_index:
  61. Specifies a target value that is ignored and does not contribute to the metric calculation
  62. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  63. Set to ``False`` for faster computations.
  64. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
  65. Example (preds is int tensor):
  66. >>> from torch import tensor
  67. >>> from torchmetrics.classification import BinaryPrecision
  68. >>> target = tensor([0, 1, 0, 1, 0, 1])
  69. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  70. >>> metric = BinaryPrecision()
  71. >>> metric(preds, target)
  72. tensor(0.6667)
  73. Example (preds is float tensor):
  74. >>> from torchmetrics.classification import BinaryPrecision
  75. >>> target = tensor([0, 1, 0, 1, 0, 1])
  76. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  77. >>> metric = BinaryPrecision()
  78. >>> metric(preds, target)
  79. tensor(0.6667)
  80. Example (multidim tensors):
  81. >>> from torchmetrics.classification import BinaryPrecision
  82. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  83. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  84. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  85. >>> metric = BinaryPrecision(multidim_average='samplewise')
  86. >>> metric(preds, target)
  87. tensor([0.4000, 0.0000])
  88. """
  89. is_differentiable: bool = False
  90. higher_is_better: Optional[bool] = True
  91. full_state_update: bool = False
  92. plot_lower_bound: float = 0.0
  93. plot_upper_bound: float = 1.0
  94. def compute(self) -> Tensor:
  95. """Compute metric."""
  96. tp, fp, tn, fn = self._final_state()
  97. return _precision_recall_reduce(
  98. "precision",
  99. tp,
  100. fp,
  101. tn,
  102. fn,
  103. average="binary",
  104. multidim_average=self.multidim_average,
  105. zero_division=self.zero_division,
  106. )
  107. def plot(
  108. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  109. ) -> _PLOT_OUT_TYPE:
  110. """Plot a single or multiple values from the metric.
  111. Args:
  112. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  113. If no value is provided, will automatically call `metric.compute` and plot that result.
  114. ax: An matplotlib axis object. If provided will add plot to that axis
  115. Returns:
  116. Figure object and Axes object
  117. Raises:
  118. ModuleNotFoundError:
  119. If `matplotlib` is not installed
  120. .. plot::
  121. :scale: 75
  122. >>> from torch import rand, randint
  123. >>> # Example plotting a single value
  124. >>> from torchmetrics.classification import BinaryPrecision
  125. >>> metric = BinaryPrecision()
  126. >>> metric.update(rand(10), randint(2,(10,)))
  127. >>> fig_, ax_ = metric.plot()
  128. .. plot::
  129. :scale: 75
  130. >>> from torch import rand, randint
  131. >>> # Example plotting multiple values
  132. >>> from torchmetrics.classification import BinaryPrecision
  133. >>> metric = BinaryPrecision()
  134. >>> values = [ ]
  135. >>> for _ in range(10):
  136. ... values.append(metric(rand(10), randint(2,(10,))))
  137. >>> fig_, ax_ = metric.plot(values)
  138. """
  139. return self._plot(val, ax)
  140. class MulticlassPrecision(MulticlassStatScores):
  141. r"""Compute `Precision`_ for multiclass tasks.
  142. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
  143. Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
  144. respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
  145. encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and
  146. the overall metric may therefore be affected in turn.
  147. As input to ``forward`` and ``update`` the metric accepts the following input:
  148. - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``.
  149. If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert
  150. probabilities/logits into an int tensor.
  151. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
  152. As output to ``forward`` and ``compute`` the metric returns the following output:
  153. - ``mcp`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
  154. arguments:
  155. - If ``multidim_average`` is set to ``global``:
  156. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  157. - If ``average=None/'none'``, the shape will be ``(C,)``
  158. - If ``multidim_average`` is set to ``samplewise``:
  159. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  160. - If ``average=None/'none'``, the shape will be ``(N, C)``
  161. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  162. which the reduction will then be applied over instead of the sample dimension ``N``.
  163. Args:
  164. num_classes: Integer specifying the number of classes
  165. average:
  166. Defines the reduction that is applied over labels. Should be one of the following:
  167. - ``micro``: Sum statistics over all labels
  168. - ``macro``: Calculate statistics for each label and average them
  169. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  170. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  171. top_k:
  172. Number of highest probability or logit score predictions considered to find the correct label.
  173. Only works when ``preds`` contain probabilities/logits.
  174. multidim_average:
  175. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  176. - ``global``: Additional dimensions are flatted along the batch dimension
  177. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  178. The statistics in this case are calculated over the additional dimensions.
  179. ignore_index:
  180. Specifies a target value that is ignored and does not contribute to the metric calculation
  181. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  182. Set to ``False`` for faster computations.
  183. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
  184. Example (preds is int tensor):
  185. >>> from torch import tensor
  186. >>> from torchmetrics.classification import MulticlassPrecision
  187. >>> target = tensor([2, 1, 0, 0])
  188. >>> preds = tensor([2, 1, 0, 1])
  189. >>> metric = MulticlassPrecision(num_classes=3)
  190. >>> metric(preds, target)
  191. tensor(0.8333)
  192. >>> mcp = MulticlassPrecision(num_classes=3, average=None)
  193. >>> mcp(preds, target)
  194. tensor([1.0000, 0.5000, 1.0000])
  195. Example (preds is float tensor):
  196. >>> from torchmetrics.classification import MulticlassPrecision
  197. >>> target = tensor([2, 1, 0, 0])
  198. >>> preds = tensor([[0.16, 0.26, 0.58],
  199. ... [0.22, 0.61, 0.17],
  200. ... [0.71, 0.09, 0.20],
  201. ... [0.05, 0.82, 0.13]])
  202. >>> metric = MulticlassPrecision(num_classes=3)
  203. >>> metric(preds, target)
  204. tensor(0.8333)
  205. >>> mcp = MulticlassPrecision(num_classes=3, average=None)
  206. >>> mcp(preds, target)
  207. tensor([1.0000, 0.5000, 1.0000])
  208. Example (multidim tensors):
  209. >>> from torchmetrics.classification import MulticlassPrecision
  210. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  211. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  212. >>> metric = MulticlassPrecision(num_classes=3, multidim_average='samplewise')
  213. >>> metric(preds, target)
  214. tensor([0.3889, 0.2778])
  215. >>> mcp = MulticlassPrecision(num_classes=3, multidim_average='samplewise', average=None)
  216. >>> mcp(preds, target)
  217. tensor([[0.6667, 0.0000, 0.5000],
  218. [0.0000, 0.5000, 0.3333]])
  219. """
  220. is_differentiable: bool = False
  221. higher_is_better: Optional[bool] = True
  222. full_state_update: bool = False
  223. plot_lower_bound: float = 0.0
  224. plot_upper_bound: float = 1.0
  225. plot_legend_name: str = "Class"
  226. def compute(self) -> Tensor:
  227. """Compute metric."""
  228. tp, fp, tn, fn = self._final_state()
  229. return _precision_recall_reduce(
  230. "precision",
  231. tp,
  232. fp,
  233. tn,
  234. fn,
  235. average=self.average,
  236. multidim_average=self.multidim_average,
  237. top_k=self.top_k,
  238. zero_division=self.zero_division,
  239. )
  240. def plot(
  241. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  242. ) -> _PLOT_OUT_TYPE:
  243. """Plot a single or multiple values from the metric.
  244. Args:
  245. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  246. If no value is provided, will automatically call `metric.compute` and plot that result.
  247. ax: An matplotlib axis object. If provided will add plot to that axis
  248. Returns:
  249. Figure object and Axes object
  250. Raises:
  251. ModuleNotFoundError:
  252. If `matplotlib` is not installed
  253. .. plot::
  254. :scale: 75
  255. >>> from torch import randint
  256. >>> # Example plotting a single value per class
  257. >>> from torchmetrics.classification import MulticlassPrecision
  258. >>> metric = MulticlassPrecision(num_classes=3, average=None)
  259. >>> metric.update(randint(3, (20,)), randint(3, (20,)))
  260. >>> fig_, ax_ = metric.plot()
  261. .. plot::
  262. :scale: 75
  263. >>> from torch import randint
  264. >>> # Example plotting a multiple values per class
  265. >>> from torchmetrics.classification import MulticlassPrecision
  266. >>> metric = MulticlassPrecision(num_classes=3, average=None)
  267. >>> values = []
  268. >>> for _ in range(20):
  269. ... values.append(metric(randint(3, (20,)), randint(3, (20,))))
  270. >>> fig_, ax_ = metric.plot(values)
  271. """
  272. return self._plot(val, ax)
  273. class MultilabelPrecision(MultilabelStatScores):
  274. r"""Compute `Precision`_ for multilabel tasks.
  275. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
  276. Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
  277. respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
  278. encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and
  279. the overall metric may therefore be affected in turn.
  280. As input to ``forward`` and ``update`` the metric accepts the following input:
  281. - ``preds`` (:class:`~torch.Tensor`): An int tensor or float tensor of shape ``(N, C, ...)``.
  282. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and
  283. will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value
  284. in ``threshold``.
  285. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.
  286. As output to ``forward`` and ``compute`` the metric returns the following output:
  287. - ``mlp`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
  288. arguments:
  289. - If ``multidim_average`` is set to ``global``:
  290. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  291. - If ``average=None/'none'``, the shape will be ``(C,)``
  292. - If ``multidim_average`` is set to ``samplewise``:
  293. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  294. - If ``average=None/'none'``, the shape will be ``(N, C)``
  295. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  296. which the reduction will then be applied over instead of the sample dimension ``N``.
  297. Args:
  298. num_labels: Integer specifying the number of labels
  299. threshold: Threshold for transforming probability to binary (0,1) predictions
  300. average:
  301. Defines the reduction that is applied over labels. Should be one of the following:
  302. - ``micro``: Sum statistics over all labels
  303. - ``macro``: Calculate statistics for each label and average them
  304. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  305. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  306. multidim_average:
  307. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  308. - ``global``: Additional dimensions are flatted along the batch dimension
  309. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  310. The statistics in this case are calculated over the additional dimensions.
  311. ignore_index:
  312. Specifies a target value that is ignored and does not contribute to the metric calculation
  313. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  314. Set to ``False`` for faster computations.
  315. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
  316. Example (preds is int tensor):
  317. >>> from torch import tensor
  318. >>> from torchmetrics.classification import MultilabelPrecision
  319. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  320. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  321. >>> metric = MultilabelPrecision(num_labels=3)
  322. >>> metric(preds, target)
  323. tensor(0.5000)
  324. >>> mlp = MultilabelPrecision(num_labels=3, average=None)
  325. >>> mlp(preds, target)
  326. tensor([1.0000, 0.0000, 0.5000])
  327. Example (preds is float tensor):
  328. >>> from torchmetrics.classification import MultilabelPrecision
  329. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  330. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  331. >>> metric = MultilabelPrecision(num_labels=3)
  332. >>> metric(preds, target)
  333. tensor(0.5000)
  334. >>> mlp = MultilabelPrecision(num_labels=3, average=None)
  335. >>> mlp(preds, target)
  336. tensor([1.0000, 0.0000, 0.5000])
  337. Example (multidim tensors):
  338. >>> from torchmetrics.classification import MultilabelPrecision
  339. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  340. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  341. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  342. >>> metric = MultilabelPrecision(num_labels=3, multidim_average='samplewise')
  343. >>> metric(preds, target)
  344. tensor([0.3333, 0.0000])
  345. >>> mlp = MultilabelPrecision(num_labels=3, multidim_average='samplewise', average=None)
  346. >>> mlp(preds, target)
  347. tensor([[0.5000, 0.5000, 0.0000],
  348. [0.0000, 0.0000, 0.0000]])
  349. """
  350. is_differentiable: bool = False
  351. higher_is_better: Optional[bool] = True
  352. full_state_update: bool = False
  353. plot_lower_bound: float = 0.0
  354. plot_upper_bound: float = 1.0
  355. plot_legend_name: str = "Label"
  356. def compute(self) -> Tensor:
  357. """Compute metric."""
  358. tp, fp, tn, fn = self._final_state()
  359. return _precision_recall_reduce(
  360. "precision",
  361. tp,
  362. fp,
  363. tn,
  364. fn,
  365. average=self.average,
  366. multidim_average=self.multidim_average,
  367. multilabel=True,
  368. zero_division=self.zero_division,
  369. )
  370. def plot(
  371. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  372. ) -> _PLOT_OUT_TYPE:
  373. """Plot a single or multiple values from the metric.
  374. Args:
  375. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  376. If no value is provided, will automatically call `metric.compute` and plot that result.
  377. ax: An matplotlib axis object. If provided will add plot to that axis
  378. Returns:
  379. Figure object and Axes object
  380. Raises:
  381. ModuleNotFoundError:
  382. If `matplotlib` is not installed
  383. .. plot::
  384. :scale: 75
  385. >>> from torch import rand, randint
  386. >>> # Example plotting a single value
  387. >>> from torchmetrics.classification import MultilabelPrecision
  388. >>> metric = MultilabelPrecision(num_labels=3)
  389. >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
  390. >>> fig_, ax_ = metric.plot()
  391. .. plot::
  392. :scale: 75
  393. >>> from torch import rand, randint
  394. >>> # Example plotting multiple values
  395. >>> from torchmetrics.classification import MultilabelPrecision
  396. >>> metric = MultilabelPrecision(num_labels=3)
  397. >>> values = [ ]
  398. >>> for _ in range(10):
  399. ... values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
  400. >>> fig_, ax_ = metric.plot(values)
  401. """
  402. return self._plot(val, ax)
  403. class BinaryRecall(BinaryStatScores):
  404. r"""Compute `Recall`_ for binary tasks.
  405. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
  406. Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
  407. respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
  408. encountered a score of `zero_division` (0 or 1, default is 0) is returned.
  409. As input to ``forward`` and ``update`` the metric accepts the following input:
  410. - ``preds`` (:class:`~torch.Tensor`): An int tensor or float tensor of shape ``(N, ...)``. If preds is a
  411. floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply
  412. sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  413. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
  414. As output to ``forward`` and ``compute`` the metric returns the following output:
  415. - ``br`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, the metric returns a scalar
  416. value. If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of
  417. a scalar value per sample.
  418. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  419. which the reduction will then be applied over instead of the sample dimension ``N``.
  420. Args:
  421. threshold: Threshold for transforming probability to binary {0,1} predictions
  422. multidim_average:
  423. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  424. - ``global``: Additional dimensions are flatted along the batch dimension
  425. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  426. The statistics in this case are calculated over the additional dimensions.
  427. ignore_index:
  428. Specifies a target value that is ignored and does not contribute to the metric calculation
  429. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  430. Set to ``False`` for faster computations.
  431. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
  432. Example (preds is int tensor):
  433. >>> from torch import tensor
  434. >>> from torchmetrics.classification import BinaryRecall
  435. >>> target = tensor([0, 1, 0, 1, 0, 1])
  436. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  437. >>> metric = BinaryRecall()
  438. >>> metric(preds, target)
  439. tensor(0.6667)
  440. Example (preds is float tensor):
  441. >>> from torchmetrics.classification import BinaryRecall
  442. >>> target = tensor([0, 1, 0, 1, 0, 1])
  443. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  444. >>> metric = BinaryRecall()
  445. >>> metric(preds, target)
  446. tensor(0.6667)
  447. Example (multidim tensors):
  448. >>> from torchmetrics.classification import BinaryRecall
  449. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  450. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  451. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  452. >>> metric = BinaryRecall(multidim_average='samplewise')
  453. >>> metric(preds, target)
  454. tensor([0.6667, 0.0000])
  455. """
  456. is_differentiable: bool = False
  457. higher_is_better: Optional[bool] = True
  458. full_state_update: bool = False
  459. plot_lower_bound: float = 0.0
  460. plot_upper_bound: float = 1.0
  461. def compute(self) -> Tensor:
  462. """Compute metric."""
  463. tp, fp, tn, fn = self._final_state()
  464. return _precision_recall_reduce(
  465. "recall",
  466. tp,
  467. fp,
  468. tn,
  469. fn,
  470. average="binary",
  471. multidim_average=self.multidim_average,
  472. zero_division=self.zero_division,
  473. )
  474. def plot(
  475. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  476. ) -> _PLOT_OUT_TYPE:
  477. """Plot a single or multiple values from the metric.
  478. Args:
  479. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  480. If no value is provided, will automatically call `metric.compute` and plot that result.
  481. ax: An matplotlib axis object. If provided will add plot to that axis
  482. Returns:
  483. Figure object and Axes object
  484. Raises:
  485. ModuleNotFoundError:
  486. If `matplotlib` is not installed
  487. .. plot::
  488. :scale: 75
  489. >>> from torch import rand, randint
  490. >>> # Example plotting a single value
  491. >>> from torchmetrics.classification import BinaryRecall
  492. >>> metric = BinaryRecall()
  493. >>> metric.update(rand(10), randint(2,(10,)))
  494. >>> fig_, ax_ = metric.plot()
  495. .. plot::
  496. :scale: 75
  497. >>> from torch import rand, randint
  498. >>> # Example plotting multiple values
  499. >>> from torchmetrics.classification import BinaryRecall
  500. >>> metric = BinaryRecall()
  501. >>> values = [ ]
  502. >>> for _ in range(10):
  503. ... values.append(metric(rand(10), randint(2,(10,))))
  504. >>> fig_, ax_ = metric.plot(values)
  505. """
  506. return self._plot(val, ax)
  507. class MulticlassRecall(MulticlassStatScores):
  508. r"""Compute `Recall`_ for multiclass tasks.
  509. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
  510. Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
  511. respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
  512. encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and
  513. the overall metric may therefore be affected in turn.
  514. As input to ``forward`` and ``update`` the metric accepts the following input:
  515. - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``
  516. If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert
  517. probabilities/logits into an int tensor.
  518. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
  519. As output to ``forward`` and ``compute`` the metric returns the following output:
  520. - ``mcr`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
  521. arguments:
  522. - If ``multidim_average`` is set to ``global``:
  523. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  524. - If ``average=None/'none'``, the shape will be ``(C,)``
  525. - If ``multidim_average`` is set to ``samplewise``:
  526. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  527. - If ``average=None/'none'``, the shape will be ``(N, C)``
  528. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  529. which the reduction will then be applied over instead of the sample dimension ``N``.
  530. Args:
  531. num_classes: Integer specifying the number of classes
  532. average:
  533. Defines the reduction that is applied over labels. Should be one of the following:
  534. - ``micro``: Sum statistics over all labels
  535. - ``macro``: Calculate statistics for each label and average them
  536. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  537. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  538. top_k:
  539. Number of highest probability or logit score predictions considered to find the correct label.
  540. Only works when ``preds`` contain probabilities/logits.
  541. multidim_average:
  542. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  543. - ``global``: Additional dimensions are flatted along the batch dimension
  544. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  545. The statistics in this case are calculated over the additional dimensions.
  546. ignore_index:
  547. Specifies a target value that is ignored and does not contribute to the metric calculation
  548. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  549. Set to ``False`` for faster computations.
  550. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
  551. Example (preds is int tensor):
  552. >>> from torch import tensor
  553. >>> from torchmetrics.classification import MulticlassRecall
  554. >>> target = tensor([2, 1, 0, 0])
  555. >>> preds = tensor([2, 1, 0, 1])
  556. >>> metric = MulticlassRecall(num_classes=3)
  557. >>> metric(preds, target)
  558. tensor(0.8333)
  559. >>> mcr = MulticlassRecall(num_classes=3, average=None)
  560. >>> mcr(preds, target)
  561. tensor([0.5000, 1.0000, 1.0000])
  562. Example (preds is float tensor):
  563. >>> from torchmetrics.classification import MulticlassRecall
  564. >>> target = tensor([2, 1, 0, 0])
  565. >>> preds = tensor([[0.16, 0.26, 0.58],
  566. ... [0.22, 0.61, 0.17],
  567. ... [0.71, 0.09, 0.20],
  568. ... [0.05, 0.82, 0.13]])
  569. >>> metric = MulticlassRecall(num_classes=3)
  570. >>> metric(preds, target)
  571. tensor(0.8333)
  572. >>> mcr = MulticlassRecall(num_classes=3, average=None)
  573. >>> mcr(preds, target)
  574. tensor([0.5000, 1.0000, 1.0000])
  575. Example (multidim tensors):
  576. >>> from torchmetrics.classification import MulticlassRecall
  577. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  578. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  579. >>> metric = MulticlassRecall(num_classes=3, multidim_average='samplewise')
  580. >>> metric(preds, target)
  581. tensor([0.5000, 0.2778])
  582. >>> mcr = MulticlassRecall(num_classes=3, multidim_average='samplewise', average=None)
  583. >>> mcr(preds, target)
  584. tensor([[1.0000, 0.0000, 0.5000],
  585. [0.0000, 0.3333, 0.5000]])
  586. """
  587. is_differentiable: bool = False
  588. higher_is_better: Optional[bool] = True
  589. full_state_update: bool = False
  590. plot_lower_bound: float = 0.0
  591. plot_upper_bound: float = 1.0
  592. plot_legend_name: str = "Class"
  593. def compute(self) -> Tensor:
  594. """Compute metric."""
  595. tp, fp, tn, fn = self._final_state()
  596. return _precision_recall_reduce(
  597. "recall",
  598. tp,
  599. fp,
  600. tn,
  601. fn,
  602. average=self.average,
  603. multidim_average=self.multidim_average,
  604. top_k=self.top_k,
  605. zero_division=self.zero_division,
  606. )
  607. def plot(
  608. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  609. ) -> _PLOT_OUT_TYPE:
  610. """Plot a single or multiple values from the metric.
  611. Args:
  612. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  613. If no value is provided, will automatically call `metric.compute` and plot that result.
  614. ax: An matplotlib axis object. If provided will add plot to that axis
  615. Returns:
  616. Figure object and Axes object
  617. Raises:
  618. ModuleNotFoundError:
  619. If `matplotlib` is not installed
  620. .. plot::
  621. :scale: 75
  622. >>> from torch import randint
  623. >>> # Example plotting a single value per class
  624. >>> from torchmetrics.classification import MulticlassRecall
  625. >>> metric = MulticlassRecall(num_classes=3, average=None)
  626. >>> metric.update(randint(3, (20,)), randint(3, (20,)))
  627. >>> fig_, ax_ = metric.plot()
  628. .. plot::
  629. :scale: 75
  630. >>> from torch import randint
  631. >>> # Example plotting a multiple values per class
  632. >>> from torchmetrics.classification import MulticlassRecall
  633. >>> metric = MulticlassRecall(num_classes=3, average=None)
  634. >>> values = []
  635. >>> for _ in range(20):
  636. ... values.append(metric(randint(3, (20,)), randint(3, (20,))))
  637. >>> fig_, ax_ = metric.plot(values)
  638. """
  639. return self._plot(val, ax)
  640. class MultilabelRecall(MultilabelStatScores):
  641. r"""Compute `Recall`_ for multilabel tasks.
  642. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
  643. Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
  644. respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
  645. encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and
  646. the overall metric may therefore be affected in turn.
  647. As input to ``forward`` and ``update`` the metric accepts the following input:
  648. - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, C, ...)``. If preds is a floating
  649. point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid
  650. per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  651. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``
  652. As output to ``forward`` and ``compute`` the metric returns the following output:
  653. - ``mlr`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
  654. arguments:
  655. - If ``multidim_average`` is set to ``global``:
  656. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  657. - If ``average=None/'none'``, the shape will be ``(C,)``
  658. - If ``multidim_average`` is set to ``samplewise``:
  659. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  660. - If ``average=None/'none'``, the shape will be ``(N, C)``
  661. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  662. which the reduction will then be applied over instead of the sample dimension ``N``.
  663. Args:
  664. num_labels: Integer specifying the number of labels
  665. threshold: Threshold for transforming probability to binary (0,1) predictions
  666. average:
  667. Defines the reduction that is applied over labels. Should be one of the following:
  668. - ``micro``: Sum statistics over all labels
  669. - ``macro``: Calculate statistics for each label and average them
  670. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  671. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  672. multidim_average:
  673. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  674. - ``global``: Additional dimensions are flatted along the batch dimension
  675. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  676. The statistics in this case are calculated over the additional dimensions.
  677. ignore_index:
  678. Specifies a target value that is ignored and does not contribute to the metric calculation
  679. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  680. Set to ``False`` for faster computations.
  681. zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
  682. Example (preds is int tensor):
  683. >>> from torch import tensor
  684. >>> from torchmetrics.classification import MultilabelRecall
  685. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  686. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  687. >>> metric = MultilabelRecall(num_labels=3)
  688. >>> metric(preds, target)
  689. tensor(0.6667)
  690. >>> mlr = MultilabelRecall(num_labels=3, average=None)
  691. >>> mlr(preds, target)
  692. tensor([1., 0., 1.])
  693. Example (preds is float tensor):
  694. >>> from torchmetrics.classification import MultilabelRecall
  695. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  696. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  697. >>> metric = MultilabelRecall(num_labels=3)
  698. >>> metric(preds, target)
  699. tensor(0.6667)
  700. >>> mlr = MultilabelRecall(num_labels=3, average=None)
  701. >>> mlr(preds, target)
  702. tensor([1., 0., 1.])
  703. Example (multidim tensors):
  704. >>> from torchmetrics.classification import MultilabelRecall
  705. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  706. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  707. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  708. >>> metric = MultilabelRecall(num_labels=3, multidim_average='samplewise')
  709. >>> metric(preds, target)
  710. tensor([0.6667, 0.0000])
  711. >>> mlr = MultilabelRecall(num_labels=3, multidim_average='samplewise', average=None)
  712. >>> mlr(preds, target)
  713. tensor([[1., 1., 0.],
  714. [0., 0., 0.]])
  715. """
  716. is_differentiable: bool = False
  717. higher_is_better: Optional[bool] = True
  718. full_state_update: bool = False
  719. plot_lower_bound: float = 0.0
  720. plot_upper_bound: float = 1.0
  721. plot_legend_name: str = "Label"
  722. def compute(self) -> Tensor:
  723. """Compute metric."""
  724. tp, fp, tn, fn = self._final_state()
  725. return _precision_recall_reduce(
  726. "recall",
  727. tp,
  728. fp,
  729. tn,
  730. fn,
  731. average=self.average,
  732. multidim_average=self.multidim_average,
  733. multilabel=True,
  734. zero_division=self.zero_division,
  735. )
  736. def plot(
  737. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  738. ) -> _PLOT_OUT_TYPE:
  739. """Plot a single or multiple values from the metric.
  740. Args:
  741. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  742. If no value is provided, will automatically call `metric.compute` and plot that result.
  743. ax: An matplotlib axis object. If provided will add plot to that axis
  744. Returns:
  745. Figure object and Axes object
  746. Raises:
  747. ModuleNotFoundError:
  748. If `matplotlib` is not installed
  749. .. plot::
  750. :scale: 75
  751. >>> from torch import rand, randint
  752. >>> # Example plotting a single value
  753. >>> from torchmetrics.classification import MultilabelRecall
  754. >>> metric = MultilabelRecall(num_labels=3)
  755. >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
  756. >>> fig_, ax_ = metric.plot()
  757. .. plot::
  758. :scale: 75
  759. >>> from torch import rand, randint
  760. >>> # Example plotting multiple values
  761. >>> from torchmetrics.classification import MultilabelRecall
  762. >>> metric = MultilabelRecall(num_labels=3)
  763. >>> values = [ ]
  764. >>> for _ in range(10):
  765. ... values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
  766. >>> fig_, ax_ = metric.plot(values)
  767. """
  768. return self._plot(val, ax)
  769. class Precision(_ClassificationTaskWrapper):
  770. r"""Compute `Precision`_.
  771. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
  772. Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
  773. respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
  774. encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may
  775. therefore be affected in turn.
  776. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  777. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  778. :class:`~torchmetrics.classification.BinaryPrecision`, :class:`~torchmetrics.classification.MulticlassPrecision` and
  779. :class:`~torchmetrics.classification.MultilabelPrecision` for the specific details of each argument influence and
  780. examples.
  781. Legacy Example:
  782. >>> from torch import tensor
  783. >>> preds = tensor([2, 0, 2, 1])
  784. >>> target = tensor([1, 1, 2, 0])
  785. >>> precision = Precision(task="multiclass", average='macro', num_classes=3)
  786. >>> precision(preds, target)
  787. tensor(0.1667)
  788. >>> precision = Precision(task="multiclass", average='micro', num_classes=3)
  789. >>> precision(preds, target)
  790. tensor(0.2500)
  791. """
  792. def __new__( # type: ignore[misc]
  793. cls: type["Precision"],
  794. task: Literal["binary", "multiclass", "multilabel"],
  795. threshold: float = 0.5,
  796. num_classes: Optional[int] = None,
  797. num_labels: Optional[int] = None,
  798. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  799. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  800. top_k: Optional[int] = 1,
  801. ignore_index: Optional[int] = None,
  802. validate_args: bool = True,
  803. **kwargs: Any,
  804. ) -> Metric:
  805. """Initialize task metric."""
  806. assert multidim_average is not None # noqa: S101 # needed for mypy
  807. kwargs.update({
  808. "multidim_average": multidim_average,
  809. "ignore_index": ignore_index,
  810. "validate_args": validate_args,
  811. })
  812. task = ClassificationTask.from_str(task)
  813. if task == ClassificationTask.BINARY:
  814. return BinaryPrecision(threshold, **kwargs)
  815. if task == ClassificationTask.MULTICLASS:
  816. if not isinstance(num_classes, int):
  817. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  818. if not isinstance(top_k, int):
  819. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  820. return MulticlassPrecision(num_classes, top_k, average, **kwargs)
  821. if task == ClassificationTask.MULTILABEL:
  822. if not isinstance(num_labels, int):
  823. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  824. return MultilabelPrecision(num_labels, threshold, average, **kwargs)
  825. raise ValueError(f"Task {task} not supported!")
  826. class Recall(_ClassificationTaskWrapper):
  827. r"""Compute `Recall`_.
  828. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
  829. Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and
  830. false negatives respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this
  831. case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may
  832. therefore be affected in turn.
  833. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  834. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  835. :class:`~torchmetrics.classification.BinaryRecall`,
  836. :class:`~torchmetrics.classification.MulticlassRecall` and :class:`~torchmetrics.classification.MultilabelRecall`
  837. for the specific details of each argument influence and examples.
  838. Legacy Example:
  839. >>> from torch import tensor
  840. >>> preds = tensor([2, 0, 2, 1])
  841. >>> target = tensor([1, 1, 2, 0])
  842. >>> recall = Recall(task="multiclass", average='macro', num_classes=3)
  843. >>> recall(preds, target)
  844. tensor(0.3333)
  845. >>> recall = Recall(task="multiclass", average='micro', num_classes=3)
  846. >>> recall(preds, target)
  847. tensor(0.2500)
  848. """
  849. def __new__( # type: ignore[misc]
  850. cls: type["Recall"],
  851. task: Literal["binary", "multiclass", "multilabel"],
  852. threshold: float = 0.5,
  853. num_classes: Optional[int] = None,
  854. num_labels: Optional[int] = None,
  855. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  856. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  857. top_k: Optional[int] = 1,
  858. ignore_index: Optional[int] = None,
  859. validate_args: bool = True,
  860. **kwargs: Any,
  861. ) -> Metric:
  862. """Initialize task metric."""
  863. task = ClassificationTask.from_str(task)
  864. assert multidim_average is not None # noqa: S101 # needed for mypy
  865. kwargs.update({
  866. "multidim_average": multidim_average,
  867. "ignore_index": ignore_index,
  868. "validate_args": validate_args,
  869. })
  870. if task == ClassificationTask.BINARY:
  871. return BinaryRecall(threshold, **kwargs)
  872. if task == ClassificationTask.MULTICLASS:
  873. if not isinstance(num_classes, int):
  874. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  875. if not isinstance(top_k, int):
  876. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  877. return MulticlassRecall(num_classes, top_k, average, **kwargs)
  878. if task == ClassificationTask.MULTILABEL:
  879. if not isinstance(num_labels, int):
  880. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  881. return MultilabelRecall(num_labels, threshold, average, **kwargs)
  882. return None # type: ignore[return-value]