precision_fixed_recall.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.classification.base import _ClassificationTaskWrapper
  19. from torchmetrics.classification.precision_recall_curve import (
  20. BinaryPrecisionRecallCurve,
  21. MulticlassPrecisionRecallCurve,
  22. MultilabelPrecisionRecallCurve,
  23. )
  24. from torchmetrics.functional.classification.precision_fixed_recall import _precision_at_recall
  25. from torchmetrics.functional.classification.recall_fixed_precision import (
  26. _binary_recall_at_fixed_precision_arg_validation,
  27. _binary_recall_at_fixed_precision_compute,
  28. _multiclass_recall_at_fixed_precision_arg_compute,
  29. _multiclass_recall_at_fixed_precision_arg_validation,
  30. _multilabel_recall_at_fixed_precision_arg_compute,
  31. _multilabel_recall_at_fixed_precision_arg_validation,
  32. )
  33. from torchmetrics.metric import Metric
  34. from torchmetrics.utilities.data import dim_zero_cat
  35. from torchmetrics.utilities.enums import ClassificationTask
  36. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  37. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  38. if not _MATPLOTLIB_AVAILABLE:
  39. __doctest_skip__ = [
  40. "BinaryPrecisionAtFixedRecall.plot",
  41. "MulticlassPrecisionAtFixedRecall.plot",
  42. "MultilabelPrecisionAtFixedRecall.plot",
  43. ]
  44. class BinaryPrecisionAtFixedRecall(BinaryPrecisionRecallCurve):
  45. r"""Compute the highest possible precision value given the minimum recall thresholds provided.
  46. This is done by first calculating the precision-recall curve for different thresholds and the find the precision for
  47. a given recall level.
  48. As input to ``forward`` and ``update`` the metric accepts the following input:
  49. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``. Preds should be a tensor containing
  50. probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input
  51. to be logits and will auto apply sigmoid per element.
  52. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
  53. ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value
  54. 1 always encodes the positive class.
  55. .. tip::
  56. Additional dimension ``...`` will be flattened into the batch dimension.
  57. As output to ``forward`` and ``compute`` the metric returns the following output:
  58. - ``precision`` (:class:`~torch.Tensor`): A scalar tensor with the maximum precision for the given recall level
  59. - ``threshold`` (:class:`~torch.Tensor`): A scalar tensor with the corresponding threshold level
  60. .. note::
  61. The implementation both supports calculating the metric in a non-binned but accurate version and a
  62. binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to ``None``
  63. will activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting
  64. the `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory
  65. of size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  66. Args:
  67. min_recall: float value specifying minimum recall threshold.
  68. thresholds:
  69. Can be one of:
  70. - If set to ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  71. all the data. Most accurate but also most memory consuming approach.
  72. - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  73. 0 to 1 as bins for the calculation.
  74. - If set to an ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  75. - If set to an 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
  76. bins for the calculation.
  77. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  78. Set to ``False`` for faster computations.
  79. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  80. Example:
  81. >>> from torch import tensor
  82. >>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
  83. >>> preds = tensor([0, 0.5, 0.7, 0.8])
  84. >>> target = tensor([0, 1, 1, 0])
  85. >>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5, thresholds=None)
  86. >>> metric(preds, target)
  87. (tensor(0.6667), tensor(0.5000))
  88. >>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5, thresholds=5)
  89. >>> metric(preds, target)
  90. (tensor(0.6667), tensor(0.5000))
  91. """
  92. is_differentiable: bool = False
  93. higher_is_better: Optional[bool] = None
  94. full_state_update: bool = False
  95. plot_lower_bound: float = 0.0
  96. plot_upper_bound: float = 1.0
  97. def __init__(
  98. self,
  99. min_recall: float,
  100. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  101. ignore_index: Optional[int] = None,
  102. validate_args: bool = True,
  103. **kwargs: Any,
  104. ) -> None:
  105. super().__init__(thresholds, ignore_index, validate_args=False, **kwargs)
  106. if validate_args:
  107. _binary_recall_at_fixed_precision_arg_validation(min_recall, thresholds, ignore_index)
  108. self.validate_args = validate_args
  109. self.min_recall = min_recall
  110. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  111. """Compute metric."""
  112. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  113. return _binary_recall_at_fixed_precision_compute(
  114. state, self.thresholds, self.min_recall, reduce_fn=_precision_at_recall
  115. )
  116. def plot( # type: ignore[override]
  117. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  118. ) -> _PLOT_OUT_TYPE:
  119. """Plot a single or multiple values from the metric.
  120. Args:
  121. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  122. If no value is provided, will automatically call `metric.compute` and plot that result.
  123. ax: An matplotlib axis object. If provided will add plot to that axis
  124. Returns:
  125. Figure object and Axes object
  126. Raises:
  127. ModuleNotFoundError:
  128. If `matplotlib` is not installed
  129. .. plot::
  130. :scale: 75
  131. >>> from torch import rand, randint
  132. >>> # Example plotting a single value
  133. >>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
  134. >>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
  135. >>> metric.update(rand(10), randint(2,(10,)))
  136. >>> fig_, ax_ = metric.plot() # the returned plot only shows the maximum recall value by default
  137. .. plot::
  138. :scale: 75
  139. >>> from torch import rand, randint
  140. >>> # Example plotting multiple values
  141. >>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
  142. >>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
  143. >>> values = [ ]
  144. >>> for _ in range(10):
  145. ... # we index by 0 such that only the maximum recall value is plotted
  146. ... values.append(metric(rand(10), randint(2,(10,)))[0])
  147. >>> fig_, ax_ = metric.plot(values)
  148. """
  149. val = val or self.compute()[0] # by default we select the maximum recall value to plot
  150. return self._plot(val, ax)
  151. class MulticlassPrecisionAtFixedRecall(MulticlassPrecisionRecallCurve):
  152. r"""Compute the highest possible precision value given the minimum recall thresholds provided.
  153. This is done by first calculating the precision-recall curve for different thresholds and the find the precision for
  154. a given recall level.
  155. As input to ``forward`` and ``update`` the metric accepts the following input:
  156. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
  157. containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
  158. the input to be logits and will auto apply softmax per sample.
  159. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
  160. ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index`
  161. is specified).
  162. .. tip::
  163. Additional dimension ``...`` will be flattened into the batch dimension.
  164. As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing:
  165. - ``precision`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the maximum precision for the
  166. given recall level per class
  167. - ``threshold`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the corresponding threshold
  168. level per class
  169. .. note::
  170. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  171. that is less accurate but more memory efficient. Setting the `thresholds` argument to ``None`` will activate the
  172. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  173. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  174. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  175. Args:
  176. num_classes: Integer specifying the number of classes
  177. min_recall: float value specifying minimum recall threshold.
  178. thresholds:
  179. Can be one of:
  180. - If set to ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  181. all the data. Most accurate but also most memory consuming approach.
  182. - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  183. 0 to 1 as bins for the calculation.
  184. - If set to an ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  185. - If set to an 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
  186. bins for the calculation.
  187. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  188. Set to ``False`` for faster computations.
  189. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  190. Example:
  191. >>> from torch import tensor
  192. >>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
  193. >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  194. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  195. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  196. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  197. >>> target = tensor([0, 1, 3, 2])
  198. >>> metric = MulticlassPrecisionAtFixedRecall(num_classes=5, min_recall=0.5, thresholds=None)
  199. >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE
  200. (tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]),
  201. tensor([0.7500, 0.7500, 0.0500, 0.0500, nan]))
  202. >>> mcrafp = MulticlassPrecisionAtFixedRecall(num_classes=5, min_recall=0.5, thresholds=5)
  203. >>> mcrafp(preds, target) # doctest: +NORMALIZE_WHITESPACE
  204. (tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]),
  205. tensor([0.7500, 0.7500, 0.0000, 0.0000, nan]))
  206. """
  207. is_differentiable: bool = False
  208. higher_is_better: Optional[bool] = None
  209. full_state_update: bool = False
  210. plot_lower_bound: float = 0.0
  211. plot_upper_bound: float = 1.0
  212. plot_legend_name: str = "Class"
  213. def __init__(
  214. self,
  215. num_classes: int,
  216. min_recall: float,
  217. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  218. ignore_index: Optional[int] = None,
  219. validate_args: bool = True,
  220. **kwargs: Any,
  221. ) -> None:
  222. super().__init__(
  223. num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  224. )
  225. if validate_args:
  226. _multiclass_recall_at_fixed_precision_arg_validation(num_classes, min_recall, thresholds, ignore_index)
  227. self.validate_args = validate_args
  228. self.min_recall = min_recall
  229. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  230. """Compute metric."""
  231. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  232. return _multiclass_recall_at_fixed_precision_arg_compute(
  233. state, self.num_classes, self.thresholds, self.min_recall, reduce_fn=_precision_at_recall
  234. )
  235. def plot( # type: ignore[override]
  236. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  237. ) -> _PLOT_OUT_TYPE:
  238. """Plot a single or multiple values from the metric.
  239. Args:
  240. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  241. If no value is provided, will automatically call `metric.compute` and plot that result.
  242. ax: An matplotlib axis object. If provided will add plot to that axis
  243. Returns:
  244. Figure object and Axes object
  245. Raises:
  246. ModuleNotFoundError:
  247. If `matplotlib` is not installed
  248. .. plot::
  249. :scale: 75
  250. >>> from torch import rand, randint
  251. >>> # Example plotting a single value per class
  252. >>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
  253. >>> metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)
  254. >>> metric.update(rand(20, 3).softmax(dim=-1), randint(3, (20,)))
  255. >>> fig_, ax_ = metric.plot() # the returned plot only shows the maximum recall value by default
  256. .. plot::
  257. :scale: 75
  258. >>> from torch import rand, randint
  259. >>> # Example plotting a multiple values per class
  260. >>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
  261. >>> metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)
  262. >>> values = []
  263. >>> for _ in range(20):
  264. ... # we index by 0 such that only the maximum recall value is plotted
  265. ... values.append(metric(rand(20, 3).softmax(dim=-1), randint(3, (20,)))[0])
  266. >>> fig_, ax_ = metric.plot(values)
  267. """
  268. val = val or self.compute()[0] # by default we select the maximum recall value to plot
  269. return self._plot(val, ax)
  270. class MultilabelPrecisionAtFixedRecall(MultilabelPrecisionRecallCurve):
  271. r"""Compute the highest possible precision value given the minimum recall thresholds provided.
  272. This is done by first calculating the precision-recall curve for different thresholds and the find the precision for
  273. a given recall level.
  274. As input to ``forward`` and ``update`` the metric accepts the following input:
  275. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``. Preds should be a tensor
  276. containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider
  277. the input to be logits and will auto apply sigmoid per element.
  278. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
  279. ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value
  280. 1 always encodes the positive class.
  281. .. tip::
  282. Additional dimension ``...`` will be flattened into the batch dimension.
  283. As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing:
  284. - ``precision`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the maximum precision for the
  285. given recall level per class
  286. - ``threshold`` (:class:`~torch.Tensor`): A 1d tensor of size ``(n_classes, )`` with the corresponding threshold
  287. level per class
  288. .. note::
  289. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  290. that is less accurate but more memory efficient. Setting the `thresholds` argument to ``None`` will activate the
  291. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  292. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  293. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  294. Args:
  295. num_labels: Integer specifying the number of labels
  296. min_recall: float value specifying minimum recall threshold.
  297. thresholds:
  298. Can be one of:
  299. - If set to ``None``, will use a non-binned approach where thresholds are dynamically calculated from
  300. all the data. Most accurate but also most memory consuming approach.
  301. - If set to an ``int`` (larger than 1), will use that number of thresholds linearly spaced from
  302. 0 to 1 as bins for the calculation.
  303. - If set to an ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation
  304. - If set to an 1d :class:`~torch.Tensor` of floats, will use the indicated thresholds in the tensor as
  305. bins for the calculation.
  306. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  307. Set to ``False`` for faster computations.
  308. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  309. Example:
  310. >>> from torch import tensor
  311. >>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
  312. >>> preds = tensor([[0.75, 0.05, 0.35],
  313. ... [0.45, 0.75, 0.05],
  314. ... [0.05, 0.55, 0.75],
  315. ... [0.05, 0.65, 0.05]])
  316. >>> target = tensor([[1, 0, 1],
  317. ... [0, 0, 0],
  318. ... [0, 1, 1],
  319. ... [1, 1, 1]])
  320. >>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5, thresholds=None)
  321. >>> metric(preds, target)
  322. (tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5500, 0.3500]))
  323. >>> mlrafp = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5, thresholds=5)
  324. >>> mlrafp(preds, target)
  325. (tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
  326. """
  327. is_differentiable: bool = False
  328. higher_is_better: Optional[bool] = None
  329. full_state_update: bool = False
  330. plot_lower_bound: float = 0.0
  331. plot_upper_bound: float = 1.0
  332. plot_legend_name: str = "Label"
  333. def __init__(
  334. self,
  335. num_labels: int,
  336. min_recall: float,
  337. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  338. ignore_index: Optional[int] = None,
  339. validate_args: bool = True,
  340. **kwargs: Any,
  341. ) -> None:
  342. super().__init__(
  343. num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  344. )
  345. if validate_args:
  346. _multilabel_recall_at_fixed_precision_arg_validation(num_labels, min_recall, thresholds, ignore_index)
  347. self.validate_args = validate_args
  348. self.min_recall = min_recall
  349. def compute(self) -> tuple[Tensor, Tensor]: # type: ignore[override]
  350. """Compute metric."""
  351. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  352. return _multilabel_recall_at_fixed_precision_arg_compute(
  353. state, self.num_labels, self.thresholds, self.ignore_index, self.min_recall, reduce_fn=_precision_at_recall
  354. )
  355. def plot( # type: ignore[override]
  356. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  357. ) -> _PLOT_OUT_TYPE:
  358. """Plot a single or multiple values from the metric.
  359. Args:
  360. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  361. If no value is provided, will automatically call `metric.compute` and plot that result.
  362. ax: An matplotlib axis object. If provided will add plot to that axis
  363. Returns:
  364. Figure object and Axes object
  365. Raises:
  366. ModuleNotFoundError:
  367. If `matplotlib` is not installed
  368. .. plot::
  369. :scale: 75
  370. >>> from torch import rand, randint
  371. >>> # Example plotting a single value
  372. >>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
  373. >>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)
  374. >>> metric.update(rand(20, 3), randint(2, (20, 3)))
  375. >>> fig_, ax_ = metric.plot() # the returned plot only shows the maximum recall value by default
  376. .. plot::
  377. :scale: 75
  378. >>> from torch import rand, randint
  379. >>> # Example plotting multiple values
  380. >>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
  381. >>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)
  382. >>> values = [ ]
  383. >>> for _ in range(10):
  384. ... # we index by 0 such that only the maximum recall value is plotted
  385. ... values.append(metric(rand(20, 3), randint(2, (20, 3)))[0])
  386. >>> fig_, ax_ = metric.plot(values)
  387. """
  388. val = val or self.compute()[0] # by default we select the maximum recall value to plot
  389. return self._plot(val, ax)
  390. class PrecisionAtFixedRecall(_ClassificationTaskWrapper):
  391. r"""Compute the highest possible recall value given the minimum precision thresholds provided.
  392. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for
  393. a given precision level.
  394. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  395. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  396. :class:`~torchmetrics.classification.BinaryPrecisionAtFixedRecall`,
  397. :class:`~torchmetrics.classification.MulticlassPrecisionAtFixedRecall` and
  398. :class:`~torchmetrics.classification.MultilabelPrecisionAtFixedRecall` for the specific details of each argument
  399. influence and examples.
  400. """
  401. def __new__( # type: ignore[misc]
  402. cls: type["PrecisionAtFixedRecall"],
  403. task: Literal["binary", "multiclass", "multilabel"],
  404. min_recall: float,
  405. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  406. num_classes: Optional[int] = None,
  407. num_labels: Optional[int] = None,
  408. ignore_index: Optional[int] = None,
  409. validate_args: bool = True,
  410. **kwargs: Any,
  411. ) -> Metric:
  412. """Initialize task metric."""
  413. task = ClassificationTask.from_str(task)
  414. if task == ClassificationTask.BINARY:
  415. return BinaryPrecisionAtFixedRecall(min_recall, thresholds, ignore_index, validate_args, **kwargs)
  416. if task == ClassificationTask.MULTICLASS:
  417. if not isinstance(num_classes, int):
  418. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  419. return MulticlassPrecisionAtFixedRecall(
  420. num_classes, min_recall, thresholds, ignore_index, validate_args, **kwargs
  421. )
  422. if task == ClassificationTask.MULTILABEL:
  423. if not isinstance(num_labels, int):
  424. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  425. return MultilabelPrecisionAtFixedRecall(
  426. num_labels, min_recall, thresholds, ignore_index, validate_args, **kwargs
  427. )
  428. raise ValueError(f"Task {task} not supported!")