average_precision.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.classification.base import _ClassificationTaskWrapper
  19. from torchmetrics.classification.precision_recall_curve import (
  20. BinaryPrecisionRecallCurve,
  21. MulticlassPrecisionRecallCurve,
  22. MultilabelPrecisionRecallCurve,
  23. )
  24. from torchmetrics.functional.classification.average_precision import (
  25. _binary_average_precision_compute,
  26. _multiclass_average_precision_arg_validation,
  27. _multiclass_average_precision_compute,
  28. _multilabel_average_precision_arg_validation,
  29. _multilabel_average_precision_compute,
  30. )
  31. from torchmetrics.metric import Metric
  32. from torchmetrics.utilities.data import dim_zero_cat
  33. from torchmetrics.utilities.enums import ClassificationTask
  34. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  35. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  36. if not _MATPLOTLIB_AVAILABLE:
  37. __doctest_skip__ = [
  38. "BinaryAveragePrecision.plot",
  39. "MulticlassAveragePrecision.plot",
  40. "MultilabelAveragePrecision.plot",
  41. ]
  42. class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
  43. r"""Compute the average precision (AP) score for binary tasks.
  44. The AP score summarizes a precision-recall curve as an weighted mean of precisions at each threshold, with the
  45. difference in recall from the previous threshold as weight:
  46. .. math::
  47. AP = \sum_{n} (R_n - R_{n-1}) P_n
  48. where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is
  49. equivalent to the area under the precision-recall curve (AUPRC).
  50. As input to ``forward`` and ``update`` the metric accepts the following input:
  51. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
  52. each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  53. sigmoid per element.
  54. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
  55. therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the
  56. positive class.
  57. As output to ``forward`` and ``compute`` the metric returns the following output:
  58. - ``bap`` (:class:`~torch.Tensor`): A single scalar with the average precision score
  59. Additional dimension ``...`` will be flattened into the batch dimension.
  60. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  61. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  62. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  63. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  64. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  65. Args:
  66. thresholds:
  67. Can be one of:
  68. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  69. all the data. Most accurate but also most memory consuming approach.
  70. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  71. 0 to 1 as bins for the calculation.
  72. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  73. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  74. bins for the calculation.
  75. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  76. Set to ``False`` for faster computations.
  77. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  78. Example:
  79. >>> from torch import tensor
  80. >>> from torchmetrics.classification import BinaryAveragePrecision
  81. >>> preds = tensor([0, 0.5, 0.7, 0.8])
  82. >>> target = tensor([0, 1, 1, 0])
  83. >>> metric = BinaryAveragePrecision(thresholds=None)
  84. >>> metric(preds, target)
  85. tensor(0.5833)
  86. >>> bap = BinaryAveragePrecision(thresholds=5)
  87. >>> bap(preds, target)
  88. tensor(0.6667)
  89. """
  90. is_differentiable: bool = False
  91. higher_is_better: bool = True
  92. full_state_update: bool = False
  93. plot_lower_bound: float = 0.0
  94. plot_upper_bound: float = 1.0
  95. def compute(self) -> Tensor: # type: ignore[override]
  96. """Compute metric."""
  97. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  98. return _binary_average_precision_compute(state, self.thresholds)
  99. def plot( # type: ignore[override]
  100. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  101. ) -> _PLOT_OUT_TYPE:
  102. """Plot a single or multiple values from the metric.
  103. Args:
  104. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  105. If no value is provided, will automatically call `metric.compute` and plot that result.
  106. ax: An matplotlib axis object. If provided will add plot to that axis
  107. Returns:
  108. Figure and Axes object
  109. Raises:
  110. ModuleNotFoundError:
  111. If `matplotlib` is not installed
  112. .. plot::
  113. :scale: 75
  114. >>> # Example plotting a single
  115. >>> import torch
  116. >>> from torchmetrics.classification import BinaryAveragePrecision
  117. >>> metric = BinaryAveragePrecision()
  118. >>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
  119. >>> fig_, ax_ = metric.plot()
  120. .. plot::
  121. :scale: 75
  122. >>> # Example plotting multiple values
  123. >>> import torch
  124. >>> from torchmetrics.classification import BinaryAveragePrecision
  125. >>> metric = BinaryAveragePrecision()
  126. >>> values = [ ]
  127. >>> for _ in range(10):
  128. ... values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
  129. >>> fig_, ax_ = metric.plot(values)
  130. """
  131. return self._plot(val, ax)
  132. class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve):
  133. r"""Compute the average precision (AP) score for multiclass tasks.
  134. The AP score summarizes a precision-recall curve as an weighted mean of precisions at each threshold, with the
  135. difference in recall from the previous threshold as weight:
  136. .. math::
  137. AP = \sum_{n} (R_n - R_{n-1}) P_n
  138. where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is
  139. equivalent to the area under the precision-recall curve (AUPRC).
  140. For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
  141. classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by
  142. this metric. By default the reported metric is then the average over all classes, but this behavior can be changed
  143. by setting the ``average`` argument.
  144. As input to ``forward`` and ``update`` the metric accepts the following input:
  145. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
  146. for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto
  147. apply softmax per sample.
  148. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
  149. therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  150. As output to ``forward`` and ``compute`` the metric returns the following output:
  151. - ``mcap`` (:class:`~torch.Tensor`): If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be
  152. returned with AP score per class. If `average="macro"|"weighted"` then a single scalar is returned.
  153. Additional dimension ``...`` will be flattened into the batch dimension.
  154. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  155. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  156. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  157. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  158. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  159. Args:
  160. num_classes: Integer specifying the number of classes
  161. average:
  162. Defines the reduction that is applied over classes. Should be one of the following:
  163. - ``macro``: Calculate score for each class and average them
  164. - ``weighted``: calculates score for each class and computes weighted average using their support
  165. - ``"none"`` or ``None``: calculates score for each class and applies no reduction
  166. thresholds:
  167. Can be one of:
  168. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  169. all the data. Most accurate but also most memory consuming approach.
  170. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  171. 0 to 1 as bins for the calculation.
  172. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  173. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  174. bins for the calculation.
  175. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  176. Set to ``False`` for faster computations.
  177. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  178. Example:
  179. >>> from torch import tensor
  180. >>> from torchmetrics.classification import MulticlassAveragePrecision
  181. >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  182. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  183. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  184. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  185. >>> target = tensor([0, 1, 3, 2])
  186. >>> metric = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=None)
  187. >>> metric(preds, target)
  188. tensor(0.6250)
  189. >>> mcap = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=None)
  190. >>> mcap(preds, target)
  191. tensor([1.0000, 1.0000, 0.2500, 0.2500, nan])
  192. >>> mcap = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=5)
  193. >>> mcap(preds, target)
  194. tensor(0.5000)
  195. >>> mcap = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=5)
  196. >>> mcap(preds, target)
  197. tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000])
  198. """
  199. is_differentiable: bool = False
  200. higher_is_better: bool = True
  201. full_state_update: bool = False
  202. plot_lower_bound: float = 0.0
  203. plot_upper_bound: float = 1.0
  204. plot_legend_name: str = "Class"
  205. def __init__(
  206. self,
  207. num_classes: int,
  208. average: Optional[Literal["macro", "weighted", "none"]] = "macro",
  209. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  210. ignore_index: Optional[int] = None,
  211. validate_args: bool = True,
  212. **kwargs: Any,
  213. ) -> None:
  214. super().__init__(
  215. num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  216. )
  217. if validate_args:
  218. _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index)
  219. self.average = average # type: ignore[assignment]
  220. self.validate_args = validate_args
  221. def compute(self) -> Tensor: # type: ignore[override]
  222. """Compute metric."""
  223. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  224. return _multiclass_average_precision_compute(
  225. state,
  226. self.num_classes,
  227. self.average, # type: ignore[arg-type]
  228. self.thresholds,
  229. )
  230. def plot( # type: ignore[override]
  231. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  232. ) -> _PLOT_OUT_TYPE:
  233. """Plot a single or multiple values from the metric.
  234. Args:
  235. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  236. If no value is provided, will automatically call `metric.compute` and plot that result.
  237. ax: An matplotlib axis object. If provided will add plot to that axis
  238. Returns:
  239. Figure and Axes object
  240. Raises:
  241. ModuleNotFoundError:
  242. If `matplotlib` is not installed
  243. .. plot::
  244. :scale: 75
  245. >>> # Example plotting a single
  246. >>> import torch
  247. >>> from torchmetrics.classification import MulticlassAveragePrecision
  248. >>> metric = MulticlassAveragePrecision(num_classes=3)
  249. >>> metric.update(torch.randn(20, 3), torch.randint(3,(20,)))
  250. >>> fig_, ax_ = metric.plot()
  251. .. plot::
  252. :scale: 75
  253. >>> # Example plotting multiple values
  254. >>> import torch
  255. >>> from torchmetrics.classification import MulticlassAveragePrecision
  256. >>> metric = MulticlassAveragePrecision(num_classes=3)
  257. >>> values = [ ]
  258. >>> for _ in range(10):
  259. ... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
  260. >>> fig_, ax_ = metric.plot(values)
  261. """
  262. return self._plot(val, ax)
  263. class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve):
  264. r"""Compute the average precision (AP) score for multilabel tasks.
  265. The AP score summarizes a precision-recall curve as an weighted mean of precisions at each threshold, with the
  266. difference in recall from the previous threshold as weight:
  267. .. math::
  268. AP = \sum_{n} (R_n - R_{n-1}) P_n
  269. where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is
  270. equivalent to the area under the precision-recall curve (AUPRC).
  271. As input to ``forward`` and ``update`` the metric accepts the following input:
  272. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
  273. for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto
  274. apply sigmoid per element.
  275. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` containing ground truth labels, and
  276. therefore only contain {0,1} values (except if `ignore_index` is specified).
  277. As output to ``forward`` and ``compute`` the metric returns the following output:
  278. - ``mlap`` (:class:`~torch.Tensor`): If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be
  279. returned with AP score per class. If `average="micro|macro"|"weighted"` then a single scalar is returned.
  280. Additional dimension ``...`` will be flattened into the batch dimension.
  281. The implementation both supports calculating the metric in a non-binned but accurate version and a binned
  282. version that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate
  283. the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
  284. `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  285. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  286. Args:
  287. num_labels: Integer specifying the number of labels
  288. average:
  289. Defines the reduction that is applied over labels. Should be one of the following:
  290. - ``micro``: Sum score over all labels
  291. - ``macro``: Calculate score for each label and average them
  292. - ``weighted``: calculates score for each label and computes weighted average using their support
  293. - ``"none"`` or ``None``: calculates score for each label and applies no reduction
  294. thresholds:
  295. Can be one of:
  296. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  297. all the data. Most accurate but also most memory consuming approach.
  298. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  299. 0 to 1 as bins for the calculation.
  300. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  301. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  302. bins for the calculation.
  303. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  304. Set to ``False`` for faster computations.
  305. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  306. Example:
  307. >>> from torch import tensor
  308. >>> from torchmetrics.classification import MultilabelAveragePrecision
  309. >>> preds = tensor([[0.75, 0.05, 0.35],
  310. ... [0.45, 0.75, 0.05],
  311. ... [0.05, 0.55, 0.75],
  312. ... [0.05, 0.65, 0.05]])
  313. >>> target = tensor([[1, 0, 1],
  314. ... [0, 0, 0],
  315. ... [0, 1, 1],
  316. ... [1, 1, 1]])
  317. >>> metric = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=None)
  318. >>> metric(preds, target)
  319. tensor(0.7500)
  320. >>> mlap = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=None)
  321. >>> mlap(preds, target)
  322. tensor([0.7500, 0.5833, 0.9167])
  323. >>> mlap = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=5)
  324. >>> mlap(preds, target)
  325. tensor(0.7778)
  326. >>> mlap = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=5)
  327. >>> mlap(preds, target)
  328. tensor([0.7500, 0.6667, 0.9167])
  329. """
  330. is_differentiable: bool = False
  331. higher_is_better: bool = True
  332. full_state_update: bool = False
  333. plot_lower_bound: float = 0.0
  334. plot_upper_bound: float = 1.0
  335. plot_legend_name: str = "Label"
  336. def __init__(
  337. self,
  338. num_labels: int,
  339. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  340. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  341. ignore_index: Optional[int] = None,
  342. validate_args: bool = True,
  343. **kwargs: Any,
  344. ) -> None:
  345. super().__init__(
  346. num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
  347. )
  348. if validate_args:
  349. _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index)
  350. self.average = average
  351. self.validate_args = validate_args
  352. def compute(self) -> Tensor: # type: ignore[override]
  353. """Compute metric."""
  354. state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
  355. return _multilabel_average_precision_compute(
  356. state, self.num_labels, self.average, self.thresholds, self.ignore_index
  357. )
  358. def plot( # type: ignore[override]
  359. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  360. ) -> _PLOT_OUT_TYPE:
  361. """Plot a single or multiple values from the metric.
  362. Args:
  363. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  364. If no value is provided, will automatically call `metric.compute` and plot that result.
  365. ax: An matplotlib axis object. If provided will add plot to that axis
  366. Returns:
  367. Figure and Axes object
  368. Raises:
  369. ModuleNotFoundError:
  370. If `matplotlib` is not installed
  371. .. plot::
  372. :scale: 75
  373. >>> # Example plotting a single
  374. >>> import torch
  375. >>> from torchmetrics.classification import MultilabelAveragePrecision
  376. >>> metric = MultilabelAveragePrecision(num_labels=3)
  377. >>> metric.update(torch.rand(20,3), torch.randint(2, (20,3)))
  378. >>> fig_, ax_ = metric.plot()
  379. .. plot::
  380. :scale: 75
  381. >>> # Example plotting multiple values
  382. >>> import torch
  383. >>> from torchmetrics.classification import MultilabelAveragePrecision
  384. >>> metric = MultilabelAveragePrecision(num_labels=3)
  385. >>> values = [ ]
  386. >>> for _ in range(10):
  387. ... values.append(metric(torch.rand(20,3), torch.randint(2, (20,3))))
  388. >>> fig_, ax_ = metric.plot(values)
  389. """
  390. return self._plot(val, ax)
  391. class AveragePrecision(_ClassificationTaskWrapper):
  392. r"""Compute the average precision (AP) score.
  393. The AP score summarizes a precision-recall curve as an weighted mean of precisions at each threshold, with the
  394. difference in recall from the previous threshold as weight:
  395. .. math::
  396. AP = \sum_{n} (R_n - R_{n-1}) P_n
  397. where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is
  398. equivalent to the area under the precision-recall curve (AUPRC).
  399. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  400. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  401. :class:`~torchmetrics.classification.BinaryAveragePrecision`,
  402. :class:`~torchmetrics.classification.MulticlassAveragePrecision` and
  403. :class:`~torchmetrics.classification.MultilabelAveragePrecision` for the specific details of each argument
  404. influence and examples.
  405. Legacy Example:
  406. >>> from torch import tensor
  407. >>> pred = tensor([0, 0.1, 0.8, 0.4])
  408. >>> target = tensor([0, 1, 1, 1])
  409. >>> average_precision = AveragePrecision(task="binary")
  410. >>> average_precision(pred, target)
  411. tensor(1.)
  412. >>> pred = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  413. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  414. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  415. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  416. >>> target = tensor([0, 1, 3, 2])
  417. >>> average_precision = AveragePrecision(task="multiclass", num_classes=5, average=None)
  418. >>> average_precision(pred, target)
  419. tensor([1.0000, 1.0000, 0.2500, 0.2500, nan])
  420. """
  421. def __new__( # type: ignore[misc]
  422. cls: type["AveragePrecision"],
  423. task: Literal["binary", "multiclass", "multilabel"],
  424. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  425. num_classes: Optional[int] = None,
  426. num_labels: Optional[int] = None,
  427. average: Optional[Literal["macro", "weighted", "none"]] = "macro",
  428. ignore_index: Optional[int] = None,
  429. validate_args: bool = True,
  430. **kwargs: Any,
  431. ) -> Metric:
  432. """Initialize task metric."""
  433. task = ClassificationTask.from_str(task)
  434. kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
  435. if task == ClassificationTask.BINARY:
  436. return BinaryAveragePrecision(**kwargs)
  437. if task == ClassificationTask.MULTICLASS:
  438. if not isinstance(num_classes, int):
  439. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  440. return MulticlassAveragePrecision(num_classes, average, **kwargs)
  441. if task == ClassificationTask.MULTILABEL:
  442. if not isinstance(num_labels, int):
  443. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  444. return MultilabelAveragePrecision(num_labels, average, **kwargs)
  445. raise ValueError(f"Task {task} not supported!")