# f_beta.py
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.functional.classification.stat_scores import (
  18. _binary_stat_scores_arg_validation,
  19. _binary_stat_scores_format,
  20. _binary_stat_scores_tensor_validation,
  21. _binary_stat_scores_update,
  22. _multiclass_stat_scores_arg_validation,
  23. _multiclass_stat_scores_format,
  24. _multiclass_stat_scores_tensor_validation,
  25. _multiclass_stat_scores_update,
  26. _multilabel_stat_scores_arg_validation,
  27. _multilabel_stat_scores_format,
  28. _multilabel_stat_scores_tensor_validation,
  29. _multilabel_stat_scores_update,
  30. )
  31. from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide
  32. from torchmetrics.utilities.enums import ClassificationTask
  33. def _fbeta_reduce(
  34. tp: Tensor,
  35. fp: Tensor,
  36. tn: Tensor,
  37. fn: Tensor,
  38. beta: float,
  39. average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
  40. multidim_average: Literal["global", "samplewise"] = "global",
  41. multilabel: bool = False,
  42. zero_division: float = 0,
  43. ) -> Tensor:
  44. beta2 = beta**2
  45. if average == "binary":
  46. return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp, zero_division)
  47. if average == "micro":
  48. tp = tp.sum(dim=0 if multidim_average == "global" else 1)
  49. fn = fn.sum(dim=0 if multidim_average == "global" else 1)
  50. fp = fp.sum(dim=0 if multidim_average == "global" else 1)
  51. return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp, zero_division)
  52. fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp, zero_division)
  53. return _adjust_weights_safe_divide(fbeta_score, average, multilabel, tp, fp, fn)
  54. def _binary_fbeta_score_arg_validation(
  55. beta: float,
  56. threshold: float = 0.5,
  57. multidim_average: Literal["global", "samplewise"] = "global",
  58. ignore_index: Optional[int] = None,
  59. zero_division: float = 0,
  60. ) -> None:
  61. if not (isinstance(beta, float) and beta > 0):
  62. raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.")
  63. _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
  64. def binary_fbeta_score(
  65. preds: Tensor,
  66. target: Tensor,
  67. beta: float,
  68. threshold: float = 0.5,
  69. multidim_average: Literal["global", "samplewise"] = "global",
  70. ignore_index: Optional[int] = None,
  71. validate_args: bool = True,
  72. zero_division: float = 0,
  73. ) -> Tensor:
  74. r"""Compute `F-score`_ metric for binary tasks.
  75. .. math::
  76. F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
  77. {(\beta^2 * \text{precision}) + \text{recall}}
  78. Accepts the following input tensors:
  79. - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
  80. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  81. we convert to int tensor with thresholding using the value in ``threshold``.
  82. - ``target`` (int tensor): ``(N, ...)``
  83. Args:
  84. preds: Tensor with predictions
  85. target: Tensor with true labels
  86. beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
  87. threshold: Threshold for transforming probability to binary {0,1} predictions
  88. multidim_average:
  89. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  90. - ``global``: Additional dimensions are flatted along the batch dimension
  91. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  92. The statistics in this case are calculated over the additional dimensions.
  93. ignore_index:
  94. Specifies a target value that is ignored and does not contribute to the metric calculation
  95. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  96. Set to ``False`` for faster computations.
  97. zero_division: Should be `0` or `1`. The value returned when
  98. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  99. Returns:
  100. If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average``
  101. is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample.
  102. Example (preds is int tensor):
  103. >>> from torch import tensor
  104. >>> from torchmetrics.functional.classification import binary_fbeta_score
  105. >>> target = tensor([0, 1, 0, 1, 0, 1])
  106. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  107. >>> binary_fbeta_score(preds, target, beta=2.0)
  108. tensor(0.6667)
  109. Example (preds is float tensor):
  110. >>> from torchmetrics.functional.classification import binary_fbeta_score
  111. >>> target = tensor([0, 1, 0, 1, 0, 1])
  112. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  113. >>> binary_fbeta_score(preds, target, beta=2.0)
  114. tensor(0.6667)
  115. Example (multidim tensors):
  116. >>> from torchmetrics.functional.classification import binary_fbeta_score
  117. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  118. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  119. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  120. >>> binary_fbeta_score(preds, target, beta=2.0, multidim_average='samplewise')
  121. tensor([0.5882, 0.0000])
  122. """
  123. if validate_args:
  124. _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index, zero_division)
  125. _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index)
  126. preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index)
  127. tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average)
  128. return _fbeta_reduce(
  129. tp, fp, tn, fn, beta, average="binary", multidim_average=multidim_average, zero_division=zero_division
  130. )
  131. def _multiclass_fbeta_score_arg_validation(
  132. beta: float,
  133. num_classes: int,
  134. top_k: int = 1,
  135. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  136. multidim_average: Literal["global", "samplewise"] = "global",
  137. ignore_index: Optional[int] = None,
  138. zero_division: float = 0,
  139. ) -> None:
  140. if not (isinstance(beta, float) and beta > 0):
  141. raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.")
  142. _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index, zero_division)
  143. def multiclass_fbeta_score(
  144. preds: Tensor,
  145. target: Tensor,
  146. beta: float,
  147. num_classes: int,
  148. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  149. top_k: int = 1,
  150. multidim_average: Literal["global", "samplewise"] = "global",
  151. ignore_index: Optional[int] = None,
  152. validate_args: bool = True,
  153. zero_division: float = 0,
  154. ) -> Tensor:
  155. r"""Compute `F-score`_ metric for multiclass tasks.
  156. .. math::
  157. F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
  158. {(\beta^2 * \text{precision}) + \text{recall}}
  159. Accepts the following input tensors:
  160. - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
  161. we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
  162. an int tensor.
  163. - ``target`` (int tensor): ``(N, ...)``
  164. Args:
  165. preds: Tensor with predictions
  166. target: Tensor with true labels
  167. beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
  168. num_classes: Integer specifying the number of classes
  169. average:
  170. Defines the reduction that is applied over labels. Should be one of the following:
  171. - ``micro``: Sum statistics over all labels
  172. - ``macro``: Calculate statistics for each label and average them
  173. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  174. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  175. top_k:
  176. Number of highest probability or logit score predictions considered to find the correct label.
  177. Only works when ``preds`` contain probabilities/logits.
  178. multidim_average:
  179. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  180. - ``global``: Additional dimensions are flatted along the batch dimension
  181. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  182. The statistics in this case are calculated over the additional dimensions.
  183. ignore_index:
  184. Specifies a target value that is ignored and does not contribute to the metric calculation
  185. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  186. Set to ``False`` for faster computations.
  187. zero_division: Should be `0` or `1`. The value returned when
  188. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  189. Returns:
  190. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  191. - If ``multidim_average`` is set to ``global``:
  192. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  193. - If ``average=None/'none'``, the shape will be ``(C,)``
  194. - If ``multidim_average`` is set to ``samplewise``:
  195. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  196. - If ``average=None/'none'``, the shape will be ``(N, C)``
  197. Example (preds is int tensor):
  198. >>> from torch import tensor
  199. >>> from torchmetrics.functional.classification import multiclass_fbeta_score
  200. >>> target = tensor([2, 1, 0, 0])
  201. >>> preds = tensor([2, 1, 0, 1])
  202. >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3)
  203. tensor(0.7963)
  204. >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, average=None)
  205. tensor([0.5556, 0.8333, 1.0000])
  206. Example (preds is float tensor):
  207. >>> from torchmetrics.functional.classification import multiclass_fbeta_score
  208. >>> target = tensor([2, 1, 0, 0])
  209. >>> preds = tensor([[0.16, 0.26, 0.58],
  210. ... [0.22, 0.61, 0.17],
  211. ... [0.71, 0.09, 0.20],
  212. ... [0.05, 0.82, 0.13]])
  213. >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3)
  214. tensor(0.7963)
  215. >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, average=None)
  216. tensor([0.5556, 0.8333, 1.0000])
  217. Example (multidim tensors):
  218. >>> from torchmetrics.functional.classification import multiclass_fbeta_score
  219. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  220. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  221. >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise')
  222. tensor([0.4697, 0.2706])
  223. >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise', average=None)
  224. tensor([[0.9091, 0.0000, 0.5000],
  225. [0.0000, 0.3571, 0.4545]])
  226. """
  227. if validate_args:
  228. _multiclass_fbeta_score_arg_validation(
  229. beta, num_classes, top_k, average, multidim_average, ignore_index, zero_division
  230. )
  231. _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
  232. preds, target = _multiclass_stat_scores_format(preds, target, top_k)
  233. tp, fp, tn, fn = _multiclass_stat_scores_update(
  234. preds, target, num_classes, top_k, average, multidim_average, ignore_index
  235. )
  236. return _fbeta_reduce(
  237. tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average, zero_division=zero_division
  238. )
  239. def _multilabel_fbeta_score_arg_validation(
  240. beta: float,
  241. num_labels: int,
  242. threshold: float = 0.5,
  243. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  244. multidim_average: Literal["global", "samplewise"] = "global",
  245. ignore_index: Optional[int] = None,
  246. zero_division: float = 0,
  247. ) -> None:
  248. if not (isinstance(beta, float) and beta > 0):
  249. raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.")
  250. _multilabel_stat_scores_arg_validation(
  251. num_labels, threshold, average, multidim_average, ignore_index, zero_division
  252. )
  253. def multilabel_fbeta_score(
  254. preds: Tensor,
  255. target: Tensor,
  256. beta: float,
  257. num_labels: int,
  258. threshold: float = 0.5,
  259. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  260. multidim_average: Literal["global", "samplewise"] = "global",
  261. ignore_index: Optional[int] = None,
  262. validate_args: bool = True,
  263. zero_division: float = 0,
  264. ) -> Tensor:
  265. r"""Compute `F-score`_ metric for multilabel tasks.
  266. .. math::
  267. F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
  268. {(\beta^2 * \text{precision}) + \text{recall}}
  269. Accepts the following input tensors:
  270. - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
  271. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  272. we convert to int tensor with thresholding using the value in ``threshold``.
  273. - ``target`` (int tensor): ``(N, C, ...)``
  274. Args:
  275. preds: Tensor with predictions
  276. target: Tensor with true labels
  277. beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
  278. num_labels: Integer specifying the number of labels
  279. threshold: Threshold for transforming probability to binary (0,1) predictions
  280. average:
  281. Defines the reduction that is applied over labels. Should be one of the following:
  282. - ``micro``: Sum statistics over all labels
  283. - ``macro``: Calculate statistics for each label and average them
  284. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  285. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  286. multidim_average:
  287. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  288. - ``global``: Additional dimensions are flatted along the batch dimension
  289. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  290. The statistics in this case are calculated over the additional dimensions.
  291. ignore_index:
  292. Specifies a target value that is ignored and does not contribute to the metric calculation
  293. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  294. Set to ``False`` for faster computations.
  295. zero_division: Should be `0` or `1`. The value returned when
  296. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  297. Returns:
  298. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  299. - If ``multidim_average`` is set to ``global``:
  300. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  301. - If ``average=None/'none'``, the shape will be ``(C,)``
  302. - If ``multidim_average`` is set to ``samplewise``:
  303. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  304. - If ``average=None/'none'``, the shape will be ``(N, C)``
  305. Example (preds is int tensor):
  306. >>> from torch import tensor
  307. >>> from torchmetrics.functional.classification import multilabel_fbeta_score
  308. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  309. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  310. >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3)
  311. tensor(0.6111)
  312. >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3, average=None)
  313. tensor([1.0000, 0.0000, 0.8333])
  314. Example (preds is float tensor):
  315. >>> from torchmetrics.functional.classification import multilabel_fbeta_score
  316. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  317. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  318. >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3)
  319. tensor(0.6111)
  320. >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3, average=None)
  321. tensor([1.0000, 0.0000, 0.8333])
  322. Example (multidim tensors):
  323. >>> from torchmetrics.functional.classification import multilabel_fbeta_score
  324. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  325. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  326. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  327. >>> multilabel_fbeta_score(preds, target, num_labels=3, beta=2.0, multidim_average='samplewise')
  328. tensor([0.5556, 0.0000])
  329. >>> multilabel_fbeta_score(preds, target, num_labels=3, beta=2.0, multidim_average='samplewise', average=None)
  330. tensor([[0.8333, 0.8333, 0.0000],
  331. [0.0000, 0.0000, 0.0000]])
  332. """
  333. if validate_args:
  334. _multilabel_fbeta_score_arg_validation(
  335. beta, num_labels, threshold, average, multidim_average, ignore_index, zero_division
  336. )
  337. _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
  338. preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
  339. tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
  340. return _fbeta_reduce(
  341. tp,
  342. fp,
  343. tn,
  344. fn,
  345. beta,
  346. average=average,
  347. multidim_average=multidim_average,
  348. multilabel=True,
  349. zero_division=zero_division,
  350. )
  351. def binary_f1_score(
  352. preds: Tensor,
  353. target: Tensor,
  354. threshold: float = 0.5,
  355. multidim_average: Literal["global", "samplewise"] = "global",
  356. ignore_index: Optional[int] = None,
  357. validate_args: bool = True,
  358. zero_division: float = 0,
  359. ) -> Tensor:
  360. r"""Compute F-1 score for binary tasks.
  361. .. math::
  362. F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
  363. Accepts the following input tensors:
  364. - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
  365. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  366. we convert to int tensor with thresholding using the value in ``threshold``.
  367. - ``target`` (int tensor): ``(N, ...)``
  368. Args:
  369. preds: Tensor with predictions
  370. target: Tensor with true labels
  371. threshold: Threshold for transforming probability to binary {0,1} predictions
  372. multidim_average:
  373. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  374. - ``global``: Additional dimensions are flatted along the batch dimension
  375. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  376. The statistics in this case are calculated over the additional dimensions.
  377. ignore_index:
  378. Specifies a target value that is ignored and does not contribute to the metric calculation
  379. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  380. Set to ``False`` for faster computations.
  381. zero_division: Should be `0` or `1`. The value returned when
  382. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  383. Returns:
  384. If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average``
  385. is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample.
  386. Example (preds is int tensor):
  387. >>> from torch import tensor
  388. >>> from torchmetrics.functional.classification import binary_f1_score
  389. >>> target = tensor([0, 1, 0, 1, 0, 1])
  390. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  391. >>> binary_f1_score(preds, target)
  392. tensor(0.6667)
  393. Example (preds is float tensor):
  394. >>> from torchmetrics.functional.classification import binary_f1_score
  395. >>> target = tensor([0, 1, 0, 1, 0, 1])
  396. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  397. >>> binary_f1_score(preds, target)
  398. tensor(0.6667)
  399. Example (multidim tensors):
  400. >>> from torchmetrics.functional.classification import binary_f1_score
  401. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  402. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  403. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  404. >>> binary_f1_score(preds, target, multidim_average='samplewise')
  405. tensor([0.5000, 0.0000])
  406. """
  407. return binary_fbeta_score(
  408. preds=preds,
  409. target=target,
  410. beta=1.0,
  411. threshold=threshold,
  412. multidim_average=multidim_average,
  413. ignore_index=ignore_index,
  414. validate_args=validate_args,
  415. zero_division=zero_division,
  416. )
  417. def multiclass_f1_score(
  418. preds: Tensor,
  419. target: Tensor,
  420. num_classes: int,
  421. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  422. top_k: int = 1,
  423. multidim_average: Literal["global", "samplewise"] = "global",
  424. ignore_index: Optional[int] = None,
  425. validate_args: bool = True,
  426. zero_division: float = 0,
  427. ) -> Tensor:
  428. r"""Compute F-1 score for multiclass tasks.
  429. .. math::
  430. F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
  431. Accepts the following input tensors:
  432. - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
  433. we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
  434. an int tensor.
  435. - ``target`` (int tensor): ``(N, ...)``
  436. Args:
  437. preds: Tensor with predictions
  438. target: Tensor with true labels
  439. num_classes: Integer specifying the number of classes
  440. average:
  441. Defines the reduction that is applied over labels. Should be one of the following:
  442. - ``micro``: Sum statistics over all labels
  443. - ``macro``: Calculate statistics for each label and average them
  444. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  445. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  446. top_k:
  447. Number of highest probability or logit score predictions considered to find the correct label.
  448. Only works when ``preds`` contain probabilities/logits.
  449. multidim_average:
  450. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  451. - ``global``: Additional dimensions are flatted along the batch dimension
  452. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  453. The statistics in this case are calculated over the additional dimensions.
  454. ignore_index:
  455. Specifies a target value that is ignored and does not contribute to the metric calculation
  456. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  457. Set to ``False`` for faster computations.
  458. zero_division: Should be `0` or `1`. The value returned when
  459. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  460. Returns:
  461. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  462. - If ``multidim_average`` is set to ``global``:
  463. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  464. - If ``average=None/'none'``, the shape will be ``(C,)``
  465. - If ``multidim_average`` is set to ``samplewise``:
  466. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  467. - If ``average=None/'none'``, the shape will be ``(N, C)``
  468. Example (preds is int tensor):
  469. >>> from torch import tensor
  470. >>> from torchmetrics.functional.classification import multiclass_f1_score
  471. >>> target = tensor([2, 1, 0, 0])
  472. >>> preds = tensor([2, 1, 0, 1])
  473. >>> multiclass_f1_score(preds, target, num_classes=3)
  474. tensor(0.7778)
  475. >>> multiclass_f1_score(preds, target, num_classes=3, average=None)
  476. tensor([0.6667, 0.6667, 1.0000])
  477. Example (preds is float tensor):
  478. >>> from torchmetrics.functional.classification import multiclass_f1_score
  479. >>> target = tensor([2, 1, 0, 0])
  480. >>> preds = tensor([[0.16, 0.26, 0.58],
  481. ... [0.22, 0.61, 0.17],
  482. ... [0.71, 0.09, 0.20],
  483. ... [0.05, 0.82, 0.13]])
  484. >>> multiclass_f1_score(preds, target, num_classes=3)
  485. tensor(0.7778)
  486. >>> multiclass_f1_score(preds, target, num_classes=3, average=None)
  487. tensor([0.6667, 0.6667, 1.0000])
  488. Example (multidim tensors):
  489. >>> from torchmetrics.functional.classification import multiclass_f1_score
  490. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  491. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  492. >>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise')
  493. tensor([0.4333, 0.2667])
  494. >>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise', average=None)
  495. tensor([[0.8000, 0.0000, 0.5000],
  496. [0.0000, 0.4000, 0.4000]])
  497. """
  498. return multiclass_fbeta_score(
  499. preds=preds,
  500. target=target,
  501. beta=1.0,
  502. num_classes=num_classes,
  503. average=average,
  504. top_k=top_k,
  505. multidim_average=multidim_average,
  506. ignore_index=ignore_index,
  507. validate_args=validate_args,
  508. zero_division=zero_division,
  509. )
  510. def multilabel_f1_score(
  511. preds: Tensor,
  512. target: Tensor,
  513. num_labels: int,
  514. threshold: float = 0.5,
  515. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  516. multidim_average: Literal["global", "samplewise"] = "global",
  517. ignore_index: Optional[int] = None,
  518. validate_args: bool = True,
  519. zero_division: float = 0,
  520. ) -> Tensor:
  521. r"""Compute F-1 score for multilabel tasks.
  522. .. math::
  523. F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
  524. Accepts the following input tensors:
  525. - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
  526. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  527. we convert to int tensor with thresholding using the value in ``threshold``.
  528. - ``target`` (int tensor): ``(N, C, ...)``
  529. Args:
  530. preds: Tensor with predictions
  531. target: Tensor with true labels
  532. num_labels: Integer specifying the number of labels
  533. threshold: Threshold for transforming probability to binary (0,1) predictions
  534. average:
  535. Defines the reduction that is applied over labels. Should be one of the following:
  536. - ``micro``: Sum statistics over all labels
  537. - ``macro``: Calculate statistics for each label and average them
  538. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  539. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  540. multidim_average:
  541. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  542. - ``global``: Additional dimensions are flatted along the batch dimension
  543. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  544. The statistics in this case are calculated over the additional dimensions.
  545. ignore_index:
  546. Specifies a target value that is ignored and does not contribute to the metric calculation
  547. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  548. Set to ``False`` for faster computations.
  549. zero_division: Should be `0` or `1`. The value returned when
  550. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  551. Returns:
  552. The returned shape depends on the ``average`` and ``multidim_average`` arguments:
  553. - If ``multidim_average`` is set to ``global``:
  554. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  555. - If ``average=None/'none'``, the shape will be ``(C,)``
  556. - If ``multidim_average`` is set to ``samplewise``:
  557. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  558. - If ``average=None/'none'``, the shape will be ``(N, C)``
  559. Example (preds is int tensor):
  560. >>> from torch import tensor
  561. >>> from torchmetrics.functional.classification import multilabel_f1_score
  562. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  563. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  564. >>> multilabel_f1_score(preds, target, num_labels=3)
  565. tensor(0.5556)
  566. >>> multilabel_f1_score(preds, target, num_labels=3, average=None)
  567. tensor([1.0000, 0.0000, 0.6667])
  568. Example (preds is float tensor):
  569. >>> from torchmetrics.functional.classification import multilabel_f1_score
  570. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  571. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  572. >>> multilabel_f1_score(preds, target, num_labels=3)
  573. tensor(0.5556)
  574. >>> multilabel_f1_score(preds, target, num_labels=3, average=None)
  575. tensor([1.0000, 0.0000, 0.6667])
  576. Example (multidim tensors):
  577. >>> from torchmetrics.functional.classification import multilabel_f1_score
  578. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  579. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  580. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  581. >>> multilabel_f1_score(preds, target, num_labels=3, multidim_average='samplewise')
  582. tensor([0.4444, 0.0000])
  583. >>> multilabel_f1_score(preds, target, num_labels=3, multidim_average='samplewise', average=None)
  584. tensor([[0.6667, 0.6667, 0.0000],
  585. [0.0000, 0.0000, 0.0000]])
  586. """
  587. return multilabel_fbeta_score(
  588. preds=preds,
  589. target=target,
  590. beta=1.0,
  591. num_labels=num_labels,
  592. threshold=threshold,
  593. average=average,
  594. multidim_average=multidim_average,
  595. ignore_index=ignore_index,
  596. validate_args=validate_args,
  597. zero_division=zero_division,
  598. )
  599. def fbeta_score(
  600. preds: Tensor,
  601. target: Tensor,
  602. task: Literal["binary", "multiclass", "multilabel"],
  603. beta: float = 1.0,
  604. threshold: float = 0.5,
  605. num_classes: Optional[int] = None,
  606. num_labels: Optional[int] = None,
  607. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  608. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  609. top_k: Optional[int] = 1,
  610. ignore_index: Optional[int] = None,
  611. validate_args: bool = True,
  612. zero_division: float = 0,
  613. ) -> Tensor:
  614. r"""Compute `F-score`_ metric.
  615. .. math::
  616. F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
  617. {(\beta^2 * \text{precision}) + \text{recall}}
  618. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  619. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  620. :func:`~torchmetrics.functional.classification.binary_fbeta_score`,
  621. :func:`~torchmetrics.functional.classification.multiclass_fbeta_score` and
  622. :func:`~torchmetrics.functional.classification.multilabel_fbeta_score` for the specific
  623. details of each argument influence and examples.
  624. Legacy Example:
  625. >>> from torch import tensor
  626. >>> target = tensor([0, 1, 2, 0, 1, 2])
  627. >>> preds = tensor([0, 2, 1, 0, 0, 1])
  628. >>> fbeta_score(preds, target, task="multiclass", num_classes=3, beta=0.5)
  629. tensor(0.3333)
  630. """
  631. task = ClassificationTask.from_str(task)
  632. assert multidim_average is not None # noqa: S101 # needed for mypy
  633. if task == ClassificationTask.BINARY:
  634. return binary_fbeta_score(
  635. preds, target, beta, threshold, multidim_average, ignore_index, validate_args, zero_division
  636. )
  637. if task == ClassificationTask.MULTICLASS:
  638. if not isinstance(num_classes, int):
  639. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  640. if not isinstance(top_k, int):
  641. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  642. return multiclass_fbeta_score(
  643. preds,
  644. target,
  645. beta,
  646. num_classes,
  647. average,
  648. top_k,
  649. multidim_average,
  650. ignore_index,
  651. validate_args,
  652. zero_division,
  653. )
  654. if task == ClassificationTask.MULTILABEL:
  655. if not isinstance(num_labels, int):
  656. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  657. return multilabel_fbeta_score(
  658. preds,
  659. target,
  660. beta,
  661. num_labels,
  662. threshold,
  663. average,
  664. multidim_average,
  665. ignore_index,
  666. validate_args,
  667. zero_division,
  668. )
  669. raise ValueError(f"Unsupported task `{task}` passed.")
  670. def f1_score(
  671. preds: Tensor,
  672. target: Tensor,
  673. task: Literal["binary", "multiclass", "multilabel"],
  674. threshold: float = 0.5,
  675. num_classes: Optional[int] = None,
  676. num_labels: Optional[int] = None,
  677. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  678. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  679. top_k: Optional[int] = 1,
  680. ignore_index: Optional[int] = None,
  681. validate_args: bool = True,
  682. zero_division: float = 0,
  683. ) -> Tensor:
  684. r"""Compute F-1 score.
  685. .. math::
  686. F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
  687. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  688. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  689. :func:`~torchmetrics.functional.classification.binary_f1_score`,
  690. :func:`~torchmetrics.functional.classification.multiclass_f1_score` and
  691. :func:`~torchmetrics.functional.classification.multilabel_f1_score` for the specific
  692. details of each argument influence and examples.
  693. Legacy Example:
  694. >>> from torch import tensor
  695. >>> target = tensor([0, 1, 2, 0, 1, 2])
  696. >>> preds = tensor([0, 2, 1, 0, 0, 1])
  697. >>> f1_score(preds, target, task="multiclass", num_classes=3)
  698. tensor(0.3333)
  699. """
  700. task = ClassificationTask.from_str(task)
  701. assert multidim_average is not None # noqa: S101 # needed for mypy
  702. if task == ClassificationTask.BINARY:
  703. return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args, zero_division)
  704. if task == ClassificationTask.MULTICLASS:
  705. if not isinstance(num_classes, int):
  706. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  707. if not isinstance(top_k, int):
  708. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  709. return multiclass_f1_score(
  710. preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args, zero_division
  711. )
  712. if task == ClassificationTask.MULTILABEL:
  713. if not isinstance(num_labels, int):
  714. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  715. return multilabel_f1_score(
  716. preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args, zero_division
  717. )
  718. raise ValueError(f"Unsupported task `{task}` passed.")