recall_fixed_precision.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.classification.base import _ClassificationTaskWrapper
  19. from torchmetrics.classification.precision_recall_curve import (
  20. BinaryPrecisionRecallCurve,
  21. MulticlassPrecisionRecallCurve,
  22. MultilabelPrecisionRecallCurve,
  23. )
  24. from torchmetrics.functional.classification.recall_fixed_precision import (
  25. _binary_recall_at_fixed_precision_arg_validation,
  26. _binary_recall_at_fixed_precision_compute,
  27. _multiclass_recall_at_fixed_precision_arg_compute,
  28. _multiclass_recall_at_fixed_precision_arg_validation,
  29. _multilabel_recall_at_fixed_precision_arg_compute,
  30. _multilabel_recall_at_fixed_precision_arg_validation,
  31. )
  32. from torchmetrics.metric import Metric
  33. from torchmetrics.utilities.data import dim_zero_cat
  34. from torchmetrics.utilities.enums import ClassificationTask
  35. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  36. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  37. if not _MATPLOTLIB_AVAILABLE:
  38. __doctest_skip__ = [
  39. "BinaryRecallAtFixedPrecision.plot",
  40. "MulticlassRecallAtFixedPrecision.plot",
  41. "MultilabelRecallAtFixedPrecision.plot",
  42. ]
  43. class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve):
  44. r"""Compute the highest possible recall value given the minimum precision thresholds provided.
  45. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for
  46. a given precision level.
  47. As input to ``forward`` and ``update`` the metric accepts the following input:
  48. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``. Preds should be a tensor containing
  49. probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input
  50. to be logits and will auto apply sigmoid per element.
  51. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
  52. ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value
  53. 1 always encodes the positive class.
  54. .. tip::
  55. Additional dimension ``...`` will be flattened into the batch dimension.
  56. As output to ``forward`` and ``compute`` the metric returns the following output:
  57. - ``recall`` (:class:`~torch.Tensor`): A scalar tensor with the maximum recall for the given precision level
  58. - ``threshold`` (:class:`~torch.Tensor`): A scalar tensor with the corresponding threshold level
  59. .. note::
  60. The implementation both supports calculating the metric in a non-binned but accurate version and a
  61. binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to ``None``
  62. will activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting
  63. the `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory
  64. of size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  65. Args:
  66. min_precision: float value specifying minimum precision threshold.
  67. thresholds:
  68. Can be one of:
  69. - If set to ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  70. all the data. Most accurate but also most memory consuming approach.
  71. - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  72. 0 to 1 as bins for the calculation.
  73. - If set to an ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  74. - If set to an 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
  75. bins for the calculation.
  76. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  77. Set to ``False`` for faster computations.
  78. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  79. Example:
  80. >>> from torch import tensor
  81. >>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
  82. >>> preds = tensor([0, 0.5, 0.7, 0.8])
  83. >>> target = tensor([0, 1, 1, 0])
  84. >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=None)
  85. >>> metric(preds, target)
  86. (tensor(1.), tensor(0.5000))
  87. >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=5)
  88. >>> metric(preds, target)
  89. (tensor(1.), tensor(0.5000))
  90. """
  91. is_differentiable: bool = False
  92. higher_is_better: Optional[bool] = None
  93. full_state_update: bool = False
  94. plot_lower_bound: float = 0.0
  95. plot_upper_bound: float = 1.0
  96. def __init__(
  97. self,
  98. min_precision: float,
  99. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  100. ignore_index: Optional[int] = None,
  101. validate_args: bool = True,
  102. **kwargs: Any,
  103. ) -> None:
  104. super().__init__(thresholds, ignore_index, validate_args=False, **kwargs)
  105. if validate_args:
  106. _binary_recall_at_fixed_precision_arg_validation(min_precision, thresholds, ignore_index)
  107. self.validate_args = validate_args
  108. self.min_precision = min_precision
  109. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  110. """Compute metric."""
  111. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  112. return _binary_recall_at_fixed_precision_compute(state, self.thresholds, self.min_precision)
  113. def plot( # type: ignore[override]
  114. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  115. ) -> _PLOT_OUT_TYPE:
  116. """Plot a single or multiple values from the metric.
  117. Args:
  118. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  119. If no value is provided, will automatically call `metric.compute` and plot that result.
  120. ax: An matplotlib axis object. If provided will add plot to that axis
  121. Returns:
  122. Figure object and Axes object
  123. Raises:
  124. ModuleNotFoundError:
  125. If `matplotlib` is not installed
  126. .. plot::
  127. :scale: 75
  128. >>> from torch import rand, randint
  129. >>> # Example plotting a single value
  130. >>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
  131. >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5)
  132. >>> metric.update(rand(10), randint(2,(10,)))
  133. >>> fig_, ax_ = metric.plot() # the returned plot only shows the maximum recall value by default
  134. .. plot::
  135. :scale: 75
  136. >>> from torch import rand, randint
  137. >>> # Example plotting multiple values
  138. >>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
  139. >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5)
  140. >>> values = [ ]
  141. >>> for _ in range(10):
  142. ... # we index by 0 such that only the maximum recall value is plotted
  143. ... values.append(metric(rand(10), randint(2,(10,)))[0])
  144. >>> fig_, ax_ = metric.plot(values)
  145. """
  146. val = val or self.compute()[0] # by default we select the maximum recall value to plot
  147. return self._plot(val, ax)
  148. class MulticlassRecallAtFixedPrecision(MulticlassPrecisionRecallCurve):
  149. r"""Compute the highest possible recall value given the minimum precision thresholds provided.
  150. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for
  151. a given precision level.
  152. For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
  153. classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by
  154. this metric.
  155. As input to ``forward`` and ``update`` the metric accepts the following input:
  156. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
  157. containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
  158. the input to be logits and will auto apply softmax per sample.
  159. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
  160. ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index`
  161. is specified).
  162. .. tip::
  163. Additional dimension ``...`` will be flattened into the batch dimension.
  164. As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing:
  165. - ``recall`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the maximum recall for the
  166. given precision level per class
  167. - ``threshold`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the corresponding threshold
  168. level per class
  169. .. note::
  170. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  171. that is less accurate but more memory efficient. Setting the `thresholds` argument to ``None`` will activate the
  172. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  173. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  174. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  175. Args:
  176. num_classes: Integer specifying the number of classes
  177. min_precision: float value specifying minimum precision threshold.
  178. thresholds:
  179. Can be one of:
  180. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  181. all the data. Most accurate but also most memory consuming approach.
  182. - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  183. 0 to 1 as bins for the calculation.
  184. - If set to an ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  185. - If set to an 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
  186. bins for the calculation.
  187. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  188. Set to ``False`` for faster computations.
  189. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  190. Example:
  191. >>> from torch import tensor
  192. >>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
  193. >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  194. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  195. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  196. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  197. >>> target = tensor([0, 1, 3, 2])
  198. >>> metric = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=None)
  199. >>> metric(preds, target)
  200. (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, nan, nan, nan]))
  201. >>> mcrafp = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=5)
  202. >>> mcrafp(preds, target)
  203. (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, nan, nan, nan]))
  204. """
  205. is_differentiable: bool = False
  206. higher_is_better: Optional[bool] = None
  207. full_state_update: bool = False
  208. plot_lower_bound: float = 0.0
  209. plot_upper_bound: float = 1.0
  210. plot_legend_name: str = "Class"
  211. def __init__(
  212. self,
  213. num_classes: int,
  214. min_precision: float,
  215. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  216. ignore_index: Optional[int] = None,
  217. validate_args: bool = True,
  218. **kwargs: Any,
  219. ) -> None:
  220. super().__init__(
  221. num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  222. )
  223. if validate_args:
  224. _multiclass_recall_at_fixed_precision_arg_validation(num_classes, min_precision, thresholds, ignore_index)
  225. self.validate_args = validate_args
  226. self.min_precision = min_precision
  227. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  228. """Compute metric."""
  229. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  230. return _multiclass_recall_at_fixed_precision_arg_compute(
  231. state, self.num_classes, self.thresholds, self.min_precision
  232. )
  233. def plot( # type: ignore[override]
  234. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  235. ) -> _PLOT_OUT_TYPE:
  236. """Plot a single or multiple values from the metric.
  237. Args:
  238. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  239. If no value is provided, will automatically call `metric.compute` and plot that result.
  240. ax: An matplotlib axis object. If provided will add plot to that axis
  241. Returns:
  242. Figure object and Axes object
  243. Raises:
  244. ModuleNotFoundError:
  245. If `matplotlib` is not installed
  246. .. plot::
  247. :scale: 75
  248. >>> from torch import rand, randint
  249. >>> # Example plotting a single value per class
  250. >>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
  251. >>> metric = MulticlassRecallAtFixedPrecision(num_classes=3, min_precision=0.5)
  252. >>> metric.update(rand(20, 3).softmax(dim=-1), randint(3, (20,)))
  253. >>> fig_, ax_ = metric.plot() # the returned plot only shows the maximum recall value by default
  254. .. plot::
  255. :scale: 75
  256. >>> from torch import rand, randint
  257. >>> # Example plotting a multiple values per class
  258. >>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
  259. >>> metric = MulticlassRecallAtFixedPrecision(num_classes=3, min_precision=0.5)
  260. >>> values = []
  261. >>> for _ in range(20):
  262. ... # we index by 0 such that only the maximum recall value is plotted
  263. ... values.append(metric(rand(20, 3).softmax(dim=-1), randint(3, (20,)))[0])
  264. >>> fig_, ax_ = metric.plot(values)
  265. """
  266. val = val or self.compute()[0] # by default we select the maximum recall value to plot
  267. return self._plot(val, ax)
  268. class MultilabelRecallAtFixedPrecision(MultilabelPrecisionRecallCurve):
  269. r"""Compute the highest possible recall value given the minimum precision thresholds provided.
  270. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for
  271. a given precision level.
  272. As input to ``forward`` and ``update`` the metric accepts the following input:
  273. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
  274. containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
  275. the input to be logits and will auto apply sigmoid per element.
  276. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
  277. ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value
  278. 1 always encodes the positive class.
  279. .. tip::
  280. Additional dimension ``...`` will be flattened into the batch dimension.
  281. As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing:
  282. - ``recall`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the maximum recall for the
  283. given precision level per class
  284. - ``threshold`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the corresponding threshold
  285. level per class
  286. .. note::
  287. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  288. that is less accurate but more memory efficient. Setting the `thresholds` argument to ```None``` will activate
  289. the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
  290. `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  291. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  292. Args:
  293. num_labels: Integer specifying the number of labels
  294. min_precision: float value specifying minimum precision threshold.
  295. thresholds:
  296. Can be one of:
  297. - If set to ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  298. all the data. Most accurate but also most memory consuming approach.
  299. - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  300. 0 to 1 as bins for the calculation.
  301. - If set to an ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  302. - If set to an 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
  303. bins for the calculation.
  304. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  305. Set to ``False`` for faster computations.
  306. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  307. Example:
  308. >>> from torch import tensor
  309. >>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
  310. >>> preds = tensor([[0.75, 0.05, 0.35],
  311. ... [0.45, 0.75, 0.05],
  312. ... [0.05, 0.55, 0.75],
  313. ... [0.05, 0.65, 0.05]])
  314. >>> target = tensor([[1, 0, 1],
  315. ... [0, 0, 0],
  316. ... [0, 1, 1],
  317. ... [1, 1, 1]])
  318. >>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=None)
  319. >>> metric(preds, target)
  320. (tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500]))
  321. >>> mlrafp = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=5)
  322. >>> mlrafp(preds, target)
  323. (tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000]))
  324. """
  325. is_differentiable: bool = False
  326. higher_is_better: Optional[bool] = None
  327. full_state_update: bool = False
  328. plot_lower_bound: float = 0.0
  329. plot_upper_bound: float = 1.0
  330. plot_legend_name: str = "Label"
  331. def __init__(
  332. self,
  333. num_labels: int,
  334. min_precision: float,
  335. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  336. ignore_index: Optional[int] = None,
  337. validate_args: bool = True,
  338. **kwargs: Any,
  339. ) -> None:
  340. super().__init__(
  341. num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  342. )
  343. if validate_args:
  344. _multilabel_recall_at_fixed_precision_arg_validation(num_labels, min_precision, thresholds, ignore_index)
  345. self.validate_args = validate_args
  346. self.min_precision = min_precision
  347. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  348. """Compute metric."""
  349. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  350. return _multilabel_recall_at_fixed_precision_arg_compute(
  351. state, self.num_labels, self.thresholds, self.ignore_index, self.min_precision
  352. )
  353. def plot( # type: ignore[override]
  354. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  355. ) -> _PLOT_OUT_TYPE:
  356. """Plot a single or multiple values from the metric.
  357. Args:
  358. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  359. If no value is provided, will automatically call `metric.compute` and plot that result.
  360. ax: An matplotlib axis object. If provided will add plot to that axis
  361. Returns:
  362. Figure object and Axes object
  363. Raises:
  364. ModuleNotFoundError:
  365. If `matplotlib` is not installed
  366. .. plot::
  367. :scale: 75
  368. >>> from torch import rand, randint
  369. >>> # Example plotting a single value
  370. >>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
  371. >>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5)
  372. >>> metric.update(rand(20, 3), randint(2, (20, 3)))
  373. >>> fig_, ax_ = metric.plot() # the returned plot only shows the maximum recall value by default
  374. .. plot::
  375. :scale: 75
  376. >>> from torch import rand, randint
  377. >>> # Example plotting multiple values
  378. >>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
  379. >>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5)
  380. >>> values = [ ]
  381. >>> for _ in range(10):
  382. ... # we index by 0 such that only the maximum recall value is plotted
  383. ... values.append(metric(rand(20, 3), randint(2, (20, 3)))[0])
  384. >>> fig_, ax_ = metric.plot(values)
  385. """
  386. val = val or self.compute()[0] # by default we select the maximum recall value to plot
  387. return self._plot(val, ax)
  388. class RecallAtFixedPrecision(_ClassificationTaskWrapper):
  389. r"""Compute the highest possible recall value given the minimum precision thresholds provided.
  390. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for
  391. a given precision level.
  392. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  393. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  394. :class:`~torchmetrics.classification.BinaryRecallAtFixedPrecision`,
  395. :class:`~torchmetrics.classification.MulticlassRecallAtFixedPrecision` and
  396. :class:`~torchmetrics.classification.MultilabelRecallAtFixedPrecision` for the specific details of each argument
  397. influence and examples.
  398. """
  399. def __new__( # type: ignore[misc]
  400. cls: type["RecallAtFixedPrecision"],
  401. task: Literal["binary", "multiclass", "multilabel"],
  402. min_precision: float,
  403. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  404. num_classes: Optional[int] = None,
  405. num_labels: Optional[int] = None,
  406. ignore_index: Optional[int] = None,
  407. validate_args: bool = True,
  408. **kwargs: Any,
  409. ) -> Metric:
  410. """Initialize task metric."""
  411. task = ClassificationTask.from_str(task)
  412. if task == ClassificationTask.BINARY:
  413. return BinaryRecallAtFixedPrecision(min_precision, thresholds, ignore_index, validate_args, **kwargs)
  414. if task == ClassificationTask.MULTICLASS:
  415. if not isinstance(num_classes, int):
  416. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  417. return MulticlassRecallAtFixedPrecision(
  418. num_classes, min_precision, thresholds, ignore_index, validate_args, **kwargs
  419. )
  420. if task == ClassificationTask.MULTILABEL:
  421. if not isinstance(num_labels, int):
  422. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  423. return MultilabelRecallAtFixedPrecision(
  424. num_labels, min_precision, thresholds, ignore_index, validate_args, **kwargs
  425. )
  426. raise ValueError(f"Task {task} not supported!")