f_beta.py 53 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.classification.base import _ClassificationTaskWrapper
  19. from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
  20. from torchmetrics.functional.classification.f_beta import (
  21. _binary_fbeta_score_arg_validation,
  22. _fbeta_reduce,
  23. _multiclass_fbeta_score_arg_validation,
  24. _multilabel_fbeta_score_arg_validation,
  25. )
  26. from torchmetrics.metric import Metric
  27. from torchmetrics.utilities.enums import ClassificationTask
  28. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  29. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  30. if not _MATPLOTLIB_AVAILABLE:
  31. __doctest_skip__ = [
  32. "BinaryFBetaScore.plot",
  33. "MulticlassFBetaScore.plot",
  34. "MultilabelFBetaScore.plot",
  35. "BinaryF1Score.plot",
  36. "MulticlassF1Score.plot",
  37. "MultilabelF1Score.plot",
  38. ]
  39. class BinaryFBetaScore(BinaryStatScores):
  40. r"""Compute `F-score`_ metric for binary tasks.
  41. .. math::
  42. F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
  43. {(\beta^2 * \text{precision}) + \text{recall}}
  44. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`
  45. where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
  46. positives and false negatives respectively. If this case is encountered a score of `zero_division`
  47. (0 or 1, default is 0) is returned.
  48. As input to ``forward`` and ``update`` the metric accepts the following input:
  49. - ``preds`` (:class:`~torch.Tensor`): An int tensor or float tensor of shape ``(N, ...)``. If preds is a floating
  50. point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid
  51. per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  52. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
  53. As output to ``forward`` and ``compute`` the metric returns the following output:
  54. - ``bfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument:
  55. - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
  56. - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` consisting of
  57. a scalar value per sample.
  58. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  59. which the reduction will then be applied over instead of the sample dimension ``N``.
  60. Args:
  61. beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
  62. threshold: Threshold for transforming probability to binary {0,1} predictions
  63. multidim_average:
  64. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  65. - ``global``: Additional dimensions are flatted along the batch dimension
  66. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  67. The statistics in this case are calculated over the additional dimensions.
  68. ignore_index:
  69. Specifies a target value that is ignored and does not contribute to the metric calculation
  70. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  71. Set to ``False`` for faster computations.
  72. zero_division: Should be `0` or `1`. The value returned when
  73. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  74. Example (preds is int tensor):
  75. >>> from torch import tensor
  76. >>> from torchmetrics.classification import BinaryFBetaScore
  77. >>> target = tensor([0, 1, 0, 1, 0, 1])
  78. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  79. >>> metric = BinaryFBetaScore(beta=2.0)
  80. >>> metric(preds, target)
  81. tensor(0.6667)
  82. Example (preds is float tensor):
  83. >>> from torchmetrics.classification import BinaryFBetaScore
  84. >>> target = tensor([0, 1, 0, 1, 0, 1])
  85. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  86. >>> metric = BinaryFBetaScore(beta=2.0)
  87. >>> metric(preds, target)
  88. tensor(0.6667)
  89. Example (multidim tensors):
  90. >>> from torchmetrics.classification import BinaryFBetaScore
  91. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  92. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  93. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  94. >>> metric = BinaryFBetaScore(beta=2.0, multidim_average='samplewise')
  95. >>> metric(preds, target)
  96. tensor([0.5882, 0.0000])
  97. """
  98. is_differentiable: bool = False
  99. higher_is_better: Optional[bool] = True
  100. full_state_update: bool = False
  101. plot_lower_bound: float = 0.0
  102. plot_upper_bound: float = 1.0
  103. def __init__(
  104. self,
  105. beta: float,
  106. threshold: float = 0.5,
  107. multidim_average: Literal["global", "samplewise"] = "global",
  108. ignore_index: Optional[int] = None,
  109. validate_args: bool = True,
  110. zero_division: float = 0,
  111. **kwargs: Any,
  112. ) -> None:
  113. super().__init__(
  114. threshold=threshold,
  115. multidim_average=multidim_average,
  116. ignore_index=ignore_index,
  117. validate_args=False,
  118. **kwargs,
  119. )
  120. if validate_args:
  121. _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index, zero_division)
  122. self.validate_args = validate_args
  123. self.zero_division = zero_division
  124. self.beta = beta
  125. def compute(self) -> Tensor:
  126. """Compute metric."""
  127. tp, fp, tn, fn = self._final_state()
  128. return _fbeta_reduce(
  129. tp,
  130. fp,
  131. tn,
  132. fn,
  133. self.beta,
  134. average="binary",
  135. multidim_average=self.multidim_average,
  136. zero_division=self.zero_division,
  137. )
  138. def plot(
  139. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  140. ) -> _PLOT_OUT_TYPE:
  141. """Plot a single or multiple values from the metric.
  142. Args:
  143. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  144. If no value is provided, will automatically call `metric.compute` and plot that result.
  145. ax: An matplotlib axis object. If provided will add plot to that axis
  146. Returns:
  147. Figure object and Axes object
  148. Raises:
  149. ModuleNotFoundError:
  150. If `matplotlib` is not installed
  151. .. plot::
  152. :scale: 75
  153. >>> from torch import rand, randint
  154. >>> # Example plotting a single value
  155. >>> from torchmetrics.classification import BinaryFBetaScore
  156. >>> metric = BinaryFBetaScore(beta=2.0)
  157. >>> metric.update(rand(10), randint(2,(10,)))
  158. >>> fig_, ax_ = metric.plot()
  159. .. plot::
  160. :scale: 75
  161. >>> from torch import rand, randint
  162. >>> # Example plotting multiple values
  163. >>> from torchmetrics.classification import BinaryFBetaScore
  164. >>> metric = BinaryFBetaScore(beta=2.0)
  165. >>> values = [ ]
  166. >>> for _ in range(10):
  167. ... values.append(metric(rand(10), randint(2,(10,))))
  168. >>> fig_, ax_ = metric.plot(values)
  169. """
  170. return self._plot(val, ax)
  171. class MulticlassFBetaScore(MulticlassStatScores):
  172. r"""Compute `F-score`_ metric for multiclass tasks.
  173. .. math::
  174. F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
  175. {(\beta^2 * \text{precision}) + \text{recall}}
  176. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`
  177. where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
  178. positives and false negatives respectively. If this case is encountered for any class, the metric for that class
  179. will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be affected in turn.
  180. As input to ``forward`` and ``update`` the metric accepts the following input:
  181. - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``.
  182. If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert
  183. probabilities/logits into an int tensor.
  184. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
  185. As output to ``forward`` and ``compute`` the metric returns the following output:
  186. - ``mcfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
  187. ``multidim_average`` arguments:
  188. - If ``multidim_average`` is set to ``global``:
  189. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  190. - If ``average=None/'none'``, the shape will be ``(C,)``
  191. - If ``multidim_average`` is set to ``samplewise``:
  192. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  193. - If ``average=None/'none'``, the shape will be ``(N, C)``
  194. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  195. which the reduction will then be applied over instead of the sample dimension ``N``.
  196. Args:
  197. beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
  198. num_classes: Integer specifying the number of classes
  199. average:
  200. Defines the reduction that is applied over labels. Should be one of the following:
  201. - ``micro``: Sum statistics over all labels
  202. - ``macro``: Calculate statistics for each label and average them
  203. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  204. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  205. top_k:
  206. Number of highest probability or logit score predictions considered to find the correct label.
  207. Only works when ``preds`` contain probabilities/logits.
  208. multidim_average:
  209. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  210. - ``global``: Additional dimensions are flatted along the batch dimension
  211. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  212. The statistics in this case are calculated over the additional dimensions.
  213. ignore_index:
  214. Specifies a target value that is ignored and does not contribute to the metric calculation
  215. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  216. Set to ``False`` for faster computations.
  217. zero_division: Should be `0` or `1`. The value returned when
  218. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  219. Example (preds is int tensor):
  220. >>> from torch import tensor
  221. >>> from torchmetrics.classification import MulticlassFBetaScore
  222. >>> target = tensor([2, 1, 0, 0])
  223. >>> preds = tensor([2, 1, 0, 1])
  224. >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3)
  225. >>> metric(preds, target)
  226. tensor(0.7963)
  227. >>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None)
  228. >>> mcfbs(preds, target)
  229. tensor([0.5556, 0.8333, 1.0000])
  230. Example (preds is float tensor):
  231. >>> from torchmetrics.classification import MulticlassFBetaScore
  232. >>> target = tensor([2, 1, 0, 0])
  233. >>> preds = tensor([[0.16, 0.26, 0.58],
  234. ... [0.22, 0.61, 0.17],
  235. ... [0.71, 0.09, 0.20],
  236. ... [0.05, 0.82, 0.13]])
  237. >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3)
  238. >>> metric(preds, target)
  239. tensor(0.7963)
  240. >>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None)
  241. >>> mcfbs(preds, target)
  242. tensor([0.5556, 0.8333, 1.0000])
  243. Example (multidim tensors):
  244. >>> from torchmetrics.classification import MulticlassFBetaScore
  245. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  246. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  247. >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise')
  248. >>> metric(preds, target)
  249. tensor([0.4697, 0.2706])
  250. >>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise', average=None)
  251. >>> mcfbs(preds, target)
  252. tensor([[0.9091, 0.0000, 0.5000],
  253. [0.0000, 0.3571, 0.4545]])
  254. """
  255. is_differentiable: bool = False
  256. higher_is_better: Optional[bool] = True
  257. full_state_update: bool = False
  258. plot_lower_bound: float = 0.0
  259. plot_upper_bound: float = 1.0
  260. plot_legend_name: str = "Class"
  261. def __init__(
  262. self,
  263. beta: float,
  264. num_classes: int,
  265. top_k: int = 1,
  266. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  267. multidim_average: Literal["global", "samplewise"] = "global",
  268. ignore_index: Optional[int] = None,
  269. validate_args: bool = True,
  270. zero_division: float = 0,
  271. **kwargs: Any,
  272. ) -> None:
  273. super().__init__(
  274. num_classes=num_classes,
  275. top_k=top_k,
  276. average=average,
  277. multidim_average=multidim_average,
  278. ignore_index=ignore_index,
  279. validate_args=False,
  280. **kwargs,
  281. )
  282. if validate_args:
  283. _multiclass_fbeta_score_arg_validation(
  284. beta, num_classes, top_k, average, multidim_average, ignore_index, zero_division
  285. )
  286. self.validate_args = validate_args
  287. self.zero_division = zero_division
  288. self.beta = beta
  289. def compute(self) -> Tensor:
  290. """Compute metric."""
  291. tp, fp, tn, fn = self._final_state()
  292. return _fbeta_reduce(
  293. tp,
  294. fp,
  295. tn,
  296. fn,
  297. self.beta,
  298. average=self.average,
  299. multidim_average=self.multidim_average,
  300. zero_division=self.zero_division,
  301. )
  302. def plot(
  303. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  304. ) -> _PLOT_OUT_TYPE:
  305. """Plot a single or multiple values from the metric.
  306. Args:
  307. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  308. If no value is provided, will automatically call `metric.compute` and plot that result.
  309. ax: An matplotlib axis object. If provided will add plot to that axis
  310. Returns:
  311. Figure object and Axes object
  312. Raises:
  313. ModuleNotFoundError:
  314. If `matplotlib` is not installed
  315. .. plot::
  316. :scale: 75
  317. >>> from torch import randint
  318. >>> # Example plotting a single value per class
  319. >>> from torchmetrics.classification import MulticlassFBetaScore
  320. >>> metric = MulticlassFBetaScore(num_classes=3, beta=2.0, average=None)
  321. >>> metric.update(randint(3, (20,)), randint(3, (20,)))
  322. >>> fig_, ax_ = metric.plot()
  323. .. plot::
  324. :scale: 75
  325. >>> from torch import randint
  326. >>> # Example plotting a multiple values per class
  327. >>> from torchmetrics.classification import MulticlassFBetaScore
  328. >>> metric = MulticlassFBetaScore(num_classes=3, beta=2.0, average=None)
  329. >>> values = []
  330. >>> for _ in range(20):
  331. ... values.append(metric(randint(3, (20,)), randint(3, (20,))))
  332. >>> fig_, ax_ = metric.plot(values)
  333. """
  334. return self._plot(val, ax)
  335. class MultilabelFBetaScore(MultilabelStatScores):
  336. r"""Compute `F-score`_ metric for multilabel tasks.
  337. .. math::
  338. F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
  339. {(\beta^2 * \text{precision}) + \text{recall}}
  340. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`
  341. where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
  342. positives and false negatives respectively. If this case is encountered for any label, the metric for that label
  343. will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be affected in turn.
  344. As input to ``forward`` and ``update`` the metric accepts the following input:
  345. - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, C, ...)``. If preds is a floating
  346. point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid
  347. per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  348. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.
  349. As output to ``forward`` and ``compute`` the metric returns the following output:
  350. - ``mlfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
  351. ``multidim_average`` arguments:
  352. - If ``multidim_average`` is set to ``global``:
  353. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  354. - If ``average=None/'none'``, the shape will be ``(C,)``
  355. - If ``multidim_average`` is set to ``samplewise``:
  356. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  357. - If ``average=None/'none'``, the shape will be ``(N, C)``
  358. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  359. which the reduction will then be applied over instead of the sample dimension ``N``.
  360. Args:
  361. beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
  362. num_labels: Integer specifying the number of labels
  363. threshold: Threshold for transforming probability to binary (0,1) predictions
  364. average:
  365. Defines the reduction that is applied over labels. Should be one of the following:
  366. - ``micro``: Sum statistics over all labels
  367. - ``macro``: Calculate statistics for each label and average them
  368. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  369. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  370. multidim_average:
  371. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  372. - ``global``: Additional dimensions are flatted along the batch dimension
  373. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  374. The statistics in this case are calculated over the additional dimensions.
  375. ignore_index:
  376. Specifies a target value that is ignored and does not contribute to the metric calculation
  377. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  378. Set to ``False`` for faster computations.
  379. zero_division: Should be `0` or `1`. The value returned when
  380. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  381. Example (preds is int tensor):
  382. >>> from torch import tensor
  383. >>> from torchmetrics.classification import MultilabelFBetaScore
  384. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  385. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  386. >>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3)
  387. >>> metric(preds, target)
  388. tensor(0.6111)
  389. >>> mlfbs = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None)
  390. >>> mlfbs(preds, target)
  391. tensor([1.0000, 0.0000, 0.8333])
  392. Example (preds is float tensor):
  393. >>> from torchmetrics.classification import MultilabelFBetaScore
  394. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  395. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  396. >>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3)
  397. >>> metric(preds, target)
  398. tensor(0.6111)
  399. >>> mlfbs = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None)
  400. >>> mlfbs(preds, target)
  401. tensor([1.0000, 0.0000, 0.8333])
  402. Example (multidim tensors):
  403. >>> from torchmetrics.classification import MultilabelFBetaScore
  404. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  405. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  406. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  407. >>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise')
  408. >>> metric(preds, target)
  409. tensor([0.5556, 0.0000])
  410. >>> mlfbs = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise', average=None)
  411. >>> mlfbs(preds, target)
  412. tensor([[0.8333, 0.8333, 0.0000],
  413. [0.0000, 0.0000, 0.0000]])
  414. """
  415. is_differentiable: bool = False
  416. higher_is_better: Optional[bool] = True
  417. full_state_update: bool = False
  418. plot_lower_bound: float = 0.0
  419. plot_upper_bound: float = 1.0
  420. plot_legend_name: str = "Label"
  421. def __init__(
  422. self,
  423. beta: float,
  424. num_labels: int,
  425. threshold: float = 0.5,
  426. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  427. multidim_average: Literal["global", "samplewise"] = "global",
  428. ignore_index: Optional[int] = None,
  429. validate_args: bool = True,
  430. zero_division: float = 0,
  431. **kwargs: Any,
  432. ) -> None:
  433. super().__init__(
  434. num_labels=num_labels,
  435. threshold=threshold,
  436. average=average,
  437. multidim_average=multidim_average,
  438. ignore_index=ignore_index,
  439. validate_args=False,
  440. **kwargs,
  441. )
  442. if validate_args:
  443. _multilabel_fbeta_score_arg_validation(
  444. beta, num_labels, threshold, average, multidim_average, ignore_index, zero_division
  445. )
  446. self.validate_args = validate_args
  447. self.zero_division = zero_division
  448. self.beta = beta
  449. def compute(self) -> Tensor:
  450. """Compute metric."""
  451. tp, fp, tn, fn = self._final_state()
  452. return _fbeta_reduce(
  453. tp,
  454. fp,
  455. tn,
  456. fn,
  457. self.beta,
  458. average=self.average,
  459. multidim_average=self.multidim_average,
  460. multilabel=True,
  461. zero_division=self.zero_division,
  462. )
  463. def plot(
  464. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  465. ) -> _PLOT_OUT_TYPE:
  466. """Plot a single or multiple values from the metric.
  467. Args:
  468. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  469. If no value is provided, will automatically call `metric.compute` and plot that result.
  470. ax: An matplotlib axis object. If provided will add plot to that axis
  471. Returns:
  472. Figure and Axes object
  473. Raises:
  474. ModuleNotFoundError:
  475. If `matplotlib` is not installed
  476. .. plot::
  477. :scale: 75
  478. >>> from torch import rand, randint
  479. >>> # Example plotting a single value
  480. >>> from torchmetrics.classification import MultilabelFBetaScore
  481. >>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0)
  482. >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
  483. >>> fig_, ax_ = metric.plot()
  484. .. plot::
  485. :scale: 75
  486. >>> from torch import rand, randint
  487. >>> # Example plotting multiple values
  488. >>> from torchmetrics.classification import MultilabelFBetaScore
  489. >>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0)
  490. >>> values = [ ]
  491. >>> for _ in range(10):
  492. ... values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
  493. >>> fig_, ax_ = metric.plot(values)
  494. """
  495. return self._plot(val, ax)
  496. class BinaryF1Score(BinaryFBetaScore):
  497. r"""Compute F-1 score for binary tasks.
  498. .. math::
  499. F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
  500. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`
  501. where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
  502. positives and false negatives respectively. If this case is encountered a score of `zero_division`
  503. (0 or 1, default is 0) is returned.
  504. As input to ``forward`` and ``update`` the metric accepts the following input:
  505. - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
  506. tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per
  507. element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
  508. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
  509. As output to ``forward`` and ``compute`` the metric returns the following output:
  510. - ``bf1s`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument:
  511. - If ``multidim_average`` is set to ``global``, the metric returns a scalar value.
  512. - If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
  513. value per sample.
  514. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  515. which the reduction will then be applied over instead of the sample dimension ``N``.
  516. Args:
  517. threshold: Threshold for transforming probability to binary {0,1} predictions
  518. multidim_average:
  519. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  520. - ``global``: Additional dimensions are flatted along the batch dimension
  521. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  522. The statistics in this case are calculated over the additional dimensions.
  523. ignore_index:
  524. Specifies a target value that is ignored and does not contribute to the metric calculation
  525. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  526. Set to ``False`` for faster computations.
  527. zero_division: Should be `0` or `1`. The value returned when
  528. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  529. Example (preds is int tensor):
  530. >>> from torch import tensor
  531. >>> from torchmetrics.classification import BinaryF1Score
  532. >>> target = tensor([0, 1, 0, 1, 0, 1])
  533. >>> preds = tensor([0, 0, 1, 1, 0, 1])
  534. >>> metric = BinaryF1Score()
  535. >>> metric(preds, target)
  536. tensor(0.6667)
  537. Example (preds is float tensor):
  538. >>> from torchmetrics.classification import BinaryF1Score
  539. >>> target = tensor([0, 1, 0, 1, 0, 1])
  540. >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
  541. >>> metric = BinaryF1Score()
  542. >>> metric(preds, target)
  543. tensor(0.6667)
  544. Example (multidim tensors):
  545. >>> from torchmetrics.classification import BinaryF1Score
  546. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  547. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  548. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  549. >>> metric = BinaryF1Score(multidim_average='samplewise')
  550. >>> metric(preds, target)
  551. tensor([0.5000, 0.0000])
  552. """
  553. is_differentiable: bool = False
  554. higher_is_better: Optional[bool] = True
  555. full_state_update: bool = False
  556. plot_lower_bound: float = 0.0
  557. plot_upper_bound: float = 1.0
  558. def __init__(
  559. self,
  560. threshold: float = 0.5,
  561. multidim_average: Literal["global", "samplewise"] = "global",
  562. ignore_index: Optional[int] = None,
  563. validate_args: bool = True,
  564. zero_division: float = 0,
  565. **kwargs: Any,
  566. ) -> None:
  567. super().__init__(
  568. beta=1.0,
  569. threshold=threshold,
  570. multidim_average=multidim_average,
  571. ignore_index=ignore_index,
  572. validate_args=validate_args,
  573. zero_division=zero_division,
  574. **kwargs,
  575. )
  576. def plot(
  577. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  578. ) -> _PLOT_OUT_TYPE:
  579. """Plot a single or multiple values from the metric.
  580. Args:
  581. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  582. If no value is provided, will automatically call `metric.compute` and plot that result.
  583. ax: An matplotlib axis object. If provided will add plot to that axis
  584. Returns:
  585. Figure object and Axes object
  586. Raises:
  587. ModuleNotFoundError:
  588. If `matplotlib` is not installed
  589. .. plot::
  590. :scale: 75
  591. >>> from torch import rand, randint
  592. >>> # Example plotting a single value
  593. >>> from torchmetrics.classification import BinaryF1Score
  594. >>> metric = BinaryF1Score()
  595. >>> metric.update(rand(10), randint(2,(10,)))
  596. >>> fig_, ax_ = metric.plot()
  597. .. plot::
  598. :scale: 75
  599. >>> from torch import rand, randint
  600. >>> # Example plotting multiple values
  601. >>> from torchmetrics.classification import BinaryF1Score
  602. >>> metric = BinaryF1Score()
  603. >>> values = [ ]
  604. >>> for _ in range(10):
  605. ... values.append(metric(rand(10), randint(2,(10,))))
  606. >>> fig_, ax_ = metric.plot(values)
  607. """
  608. return self._plot(val, ax)
  609. class MulticlassF1Score(MulticlassFBetaScore):
  610. r"""Compute F-1 score for multiclass tasks.
  611. .. math::
  612. F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
  613. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`
  614. where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
  615. positives and false negatives respectively. If this case is encountered for any class, the metric for that class
  616. will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be affected in turn.
  617. As input to ``forward`` and ``update`` the metric accepts the following input:
  618. - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``.
  619. If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert
  620. probabilities/logits into an int tensor.
  621. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
  622. As output to ``forward`` and ``compute`` the metric returns the following output:
  623. - ``mcf1s`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
  624. ``multidim_average`` arguments:
  625. - If ``multidim_average`` is set to ``global``:
  626. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  627. - If ``average=None/'none'``, the shape will be ``(C,)``
  628. - If ``multidim_average`` is set to ``samplewise``:
  629. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  630. - If ``average=None/'none'``, the shape will be ``(N, C)``
  631. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  632. which the reduction will then be applied over instead of the sample dimension ``N``.
  633. Args:
  634. preds: Tensor with predictions
  635. target: Tensor with true labels
  636. num_classes: Integer specifying the number of classes
  637. average:
  638. Defines the reduction that is applied over labels. Should be one of the following:
  639. - ``micro``: Sum statistics over all labels
  640. - ``macro``: Calculate statistics for each label and average them
  641. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  642. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  643. top_k:
  644. Number of highest probability or logit score predictions considered to find the correct label.
  645. Only works when ``preds`` contain probabilities/logits.
  646. multidim_average:
  647. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  648. - ``global``: Additional dimensions are flatted along the batch dimension
  649. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  650. The statistics in this case are calculated over the additional dimensions.
  651. ignore_index:
  652. Specifies a target value that is ignored and does not contribute to the metric calculation
  653. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  654. Set to ``False`` for faster computations.
  655. zero_division: Should be `0` or `1`. The value returned when
  656. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  657. Example (preds is int tensor):
  658. >>> from torch import tensor
  659. >>> from torchmetrics.classification import MulticlassF1Score
  660. >>> target = tensor([2, 1, 0, 0])
  661. >>> preds = tensor([2, 1, 0, 1])
  662. >>> metric = MulticlassF1Score(num_classes=3)
  663. >>> metric(preds, target)
  664. tensor(0.7778)
  665. >>> mcf1s = MulticlassF1Score(num_classes=3, average=None)
  666. >>> mcf1s(preds, target)
  667. tensor([0.6667, 0.6667, 1.0000])
  668. Example (preds is float tensor):
  669. >>> from torchmetrics.classification import MulticlassF1Score
  670. >>> target = tensor([2, 1, 0, 0])
  671. >>> preds = tensor([[0.16, 0.26, 0.58],
  672. ... [0.22, 0.61, 0.17],
  673. ... [0.71, 0.09, 0.20],
  674. ... [0.05, 0.82, 0.13]])
  675. >>> metric = MulticlassF1Score(num_classes=3)
  676. >>> metric(preds, target)
  677. tensor(0.7778)
  678. >>> mcf1s = MulticlassF1Score(num_classes=3, average=None)
  679. >>> mcf1s(preds, target)
  680. tensor([0.6667, 0.6667, 1.0000])
  681. Example (multidim tensors):
  682. >>> from torchmetrics.classification import MulticlassF1Score
  683. >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
  684. >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
  685. >>> metric = MulticlassF1Score(num_classes=3, multidim_average='samplewise')
  686. >>> metric(preds, target)
  687. tensor([0.4333, 0.2667])
  688. >>> mcf1s = MulticlassF1Score(num_classes=3, multidim_average='samplewise', average=None)
  689. >>> mcf1s(preds, target)
  690. tensor([[0.8000, 0.0000, 0.5000],
  691. [0.0000, 0.4000, 0.4000]])
  692. """
  693. is_differentiable: bool = False
  694. higher_is_better: Optional[bool] = True
  695. full_state_update: bool = False
  696. plot_lower_bound: float = 0.0
  697. plot_upper_bound: float = 1.0
  698. plot_legend_name: str = "Class"
  699. def __init__(
  700. self,
  701. num_classes: int,
  702. top_k: int = 1,
  703. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  704. multidim_average: Literal["global", "samplewise"] = "global",
  705. ignore_index: Optional[int] = None,
  706. validate_args: bool = True,
  707. zero_division: float = 0,
  708. **kwargs: Any,
  709. ) -> None:
  710. super().__init__(
  711. beta=1.0,
  712. num_classes=num_classes,
  713. top_k=top_k,
  714. average=average,
  715. multidim_average=multidim_average,
  716. ignore_index=ignore_index,
  717. validate_args=validate_args,
  718. zero_division=zero_division,
  719. **kwargs,
  720. )
  721. def plot(
  722. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  723. ) -> _PLOT_OUT_TYPE:
  724. """Plot a single or multiple values from the metric.
  725. Args:
  726. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  727. If no value is provided, will automatically call `metric.compute` and plot that result.
  728. ax: An matplotlib axis object. If provided will add plot to that axis
  729. Returns:
  730. Figure object and Axes object
  731. Raises:
  732. ModuleNotFoundError:
  733. If `matplotlib` is not installed
  734. .. plot::
  735. :scale: 75
  736. >>> from torch import randint
  737. >>> # Example plotting a single value per class
  738. >>> from torchmetrics.classification import MulticlassF1Score
  739. >>> metric = MulticlassF1Score(num_classes=3, average=None)
  740. >>> metric.update(randint(3, (20,)), randint(3, (20,)))
  741. >>> fig_, ax_ = metric.plot()
  742. .. plot::
  743. :scale: 75
  744. >>> from torch import randint
  745. >>> # Example plotting a multiple values per class
  746. >>> from torchmetrics.classification import MulticlassF1Score
  747. >>> metric = MulticlassF1Score(num_classes=3, average=None)
  748. >>> values = []
  749. >>> for _ in range(20):
  750. ... values.append(metric(randint(3, (20,)), randint(3, (20,))))
  751. >>> fig_, ax_ = metric.plot(values)
  752. """
  753. return self._plot(val, ax)
  754. class MultilabelF1Score(MultilabelFBetaScore):
  755. r"""Compute F-1 score for multilabel tasks.
  756. .. math::
  757. F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
  758. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`
  759. where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
  760. positives and false negatives respectively. If this case is encountered for any label, the metric for that label
  761. will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be affected in turn.
  762. As input to ``forward`` and ``update`` the metric accepts the following input:
  763. - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, C, ...)``.
  764. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and
  765. will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value
  766. in ``threshold``.
  767. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.
  768. As output to ``forward`` and ``compute`` the metric returns the following output:
  769. - ``mlf1s`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
  770. ``multidim_average`` arguments:
  771. - If ``multidim_average`` is set to ``global``:
  772. - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
  773. - If ``average=None/'none'``, the shape will be ``(C,)``
  774. - If ``multidim_average`` is set to ``samplewise``:
  775. - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
  776. - If ``average=None/'none'``, the shape will be ``(N, C)```
  777. If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
  778. which the reduction will then be applied over instead of the sample dimension ``N``.
  779. Args:
  780. num_labels: Integer specifying the number of labels
  781. threshold: Threshold for transforming probability to binary (0,1) predictions
  782. average:
  783. Defines the reduction that is applied over labels. Should be one of the following:
  784. - ``micro``: Sum statistics over all labels
  785. - ``macro``: Calculate statistics for each label and average them
  786. - ``weighted``: calculates statistics for each label and computes weighted average using their support
  787. - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
  788. multidim_average:
  789. Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
  790. - ``global``: Additional dimensions are flatted along the batch dimension
  791. - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
  792. The statistics in this case are calculated over the additional dimensions.
  793. ignore_index:
  794. Specifies a target value that is ignored and does not contribute to the metric calculation
  795. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  796. Set to ``False`` for faster computations.
  797. zero_division: Should be `0` or `1`. The value returned when
  798. :math:`\text{TP} + \text{FP} = 0 \wedge \text{TP} + \text{FN} = 0`.
  799. Example (preds is int tensor):
  800. >>> from torch import tensor
  801. >>> from torchmetrics.classification import MultilabelF1Score
  802. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  803. >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
  804. >>> metric = MultilabelF1Score(num_labels=3)
  805. >>> metric(preds, target)
  806. tensor(0.5556)
  807. >>> mlf1s = MultilabelF1Score(num_labels=3, average=None)
  808. >>> mlf1s(preds, target)
  809. tensor([1.0000, 0.0000, 0.6667])
  810. Example (preds is float tensor):
  811. >>> from torchmetrics.classification import MultilabelF1Score
  812. >>> target = tensor([[0, 1, 0], [1, 0, 1]])
  813. >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
  814. >>> metric = MultilabelF1Score(num_labels=3)
  815. >>> metric(preds, target)
  816. tensor(0.5556)
  817. >>> mlf1s = MultilabelF1Score(num_labels=3, average=None)
  818. >>> mlf1s(preds, target)
  819. tensor([1.0000, 0.0000, 0.6667])
  820. Example (multidim tensors):
  821. >>> from torchmetrics.classification import MultilabelF1Score
  822. >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
  823. >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
  824. ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
  825. >>> metric = MultilabelF1Score(num_labels=3, multidim_average='samplewise')
  826. >>> metric(preds, target)
  827. tensor([0.4444, 0.0000])
  828. >>> mlf1s = MultilabelF1Score(num_labels=3, multidim_average='samplewise', average=None)
  829. >>> mlf1s(preds, target)
  830. tensor([[0.6667, 0.6667, 0.0000],
  831. [0.0000, 0.0000, 0.0000]])
  832. """
  833. is_differentiable: bool = False
  834. higher_is_better: Optional[bool] = True
  835. full_state_update: bool = False
  836. plot_lower_bound: float = 0.0
  837. plot_upper_bound: float = 1.0
  838. plot_legend_name: str = "Label"
  839. def __init__(
  840. self,
  841. num_labels: int,
  842. threshold: float = 0.5,
  843. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
  844. multidim_average: Literal["global", "samplewise"] = "global",
  845. ignore_index: Optional[int] = None,
  846. validate_args: bool = True,
  847. zero_division: float = 0,
  848. **kwargs: Any,
  849. ) -> None:
  850. super().__init__(
  851. beta=1.0,
  852. num_labels=num_labels,
  853. threshold=threshold,
  854. average=average,
  855. multidim_average=multidim_average,
  856. ignore_index=ignore_index,
  857. validate_args=validate_args,
  858. zero_division=zero_division,
  859. **kwargs,
  860. )
  861. def plot(
  862. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  863. ) -> _PLOT_OUT_TYPE:
  864. """Plot a single or multiple values from the metric.
  865. Args:
  866. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  867. If no value is provided, will automatically call `metric.compute` and plot that result.
  868. ax: An matplotlib axis object. If provided will add plot to that axis
  869. Returns:
  870. Figure and Axes object
  871. Raises:
  872. ModuleNotFoundError:
  873. If `matplotlib` is not installed
  874. .. plot::
  875. :scale: 75
  876. >>> from torch import rand, randint
  877. >>> # Example plotting a single value
  878. >>> from torchmetrics.classification import MultilabelF1Score
  879. >>> metric = MultilabelF1Score(num_labels=3)
  880. >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
  881. >>> fig_, ax_ = metric.plot()
  882. .. plot::
  883. :scale: 75
  884. >>> from torch import rand, randint
  885. >>> # Example plotting multiple values
  886. >>> from torchmetrics.classification import MultilabelF1Score
  887. >>> metric = MultilabelF1Score(num_labels=3)
  888. >>> values = [ ]
  889. >>> for _ in range(10):
  890. ... values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
  891. >>> fig_, ax_ = metric.plot(values)
  892. """
  893. return self._plot(val, ax)
  894. class FBetaScore(_ClassificationTaskWrapper):
  895. r"""Compute `F-score`_ metric.
  896. .. math::
  897. F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
  898. {(\beta^2 * \text{precision}) + \text{recall}}
  899. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`
  900. where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
  901. positives and false negatives respectively. If this case is encountered for any class/label, the metric for that
  902. class/label will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be
  903. affected in turn.
  904. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  905. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  906. :class:`~torchmetrics.classification.BinaryFBetaScore`,
  907. :class:`~torchmetrics.classification.MulticlassFBetaScore` and
  908. :class:`~torchmetrics.classification.MultilabelFBetaScore` for the specific details of each argument influence
  909. and examples.
  910. Legcy Example:
  911. >>> from torch import tensor
  912. >>> target = tensor([0, 1, 2, 0, 1, 2])
  913. >>> preds = tensor([0, 2, 1, 0, 0, 1])
  914. >>> f_beta = FBetaScore(task="multiclass", num_classes=3, beta=0.5)
  915. >>> f_beta(preds, target)
  916. tensor(0.3333)
  917. """
  918. def __new__( # type: ignore[misc]
  919. cls: type["FBetaScore"],
  920. task: Literal["binary", "multiclass", "multilabel"],
  921. beta: float = 1.0,
  922. threshold: float = 0.5,
  923. num_classes: Optional[int] = None,
  924. num_labels: Optional[int] = None,
  925. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  926. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  927. top_k: Optional[int] = 1,
  928. ignore_index: Optional[int] = None,
  929. validate_args: bool = True,
  930. zero_division: float = 0,
  931. **kwargs: Any,
  932. ) -> Metric:
  933. """Initialize task metric."""
  934. task = ClassificationTask.from_str(task)
  935. assert multidim_average is not None # noqa: S101 # needed for mypy
  936. kwargs.update({
  937. "multidim_average": multidim_average,
  938. "ignore_index": ignore_index,
  939. "validate_args": validate_args,
  940. "zero_division": zero_division,
  941. })
  942. if task == ClassificationTask.BINARY:
  943. return BinaryFBetaScore(beta, threshold, **kwargs)
  944. if task == ClassificationTask.MULTICLASS:
  945. if not isinstance(num_classes, int):
  946. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  947. if not isinstance(top_k, int):
  948. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  949. return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs)
  950. if task == ClassificationTask.MULTILABEL:
  951. if not isinstance(num_labels, int):
  952. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  953. return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs)
  954. raise ValueError(f"Task {task} not supported!")
  955. class F1Score(_ClassificationTaskWrapper):
  956. r"""Compute F-1 score.
  957. .. math::
  958. F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}}
  959. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0`
  960. where :math:`\text{TP}`, :math:`\text{FP}` and :math:`\text{FN}` represent the number of true positives, false
  961. positives and false negatives respectively. If this case is encountered for any class/label, the metric for that
  962. class/label will be set to `zero_division` (0 or 1, default is 0) and the overall metric may therefore be
  963. affected in turn.
  964. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  965. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  966. :class:`~torchmetrics.classification.BinaryF1Score`, :class:`~torchmetrics.classification.MulticlassF1Score` and
  967. :class:`~torchmetrics.classification.MultilabelF1Score` for the specific details of each argument influence and
  968. examples.
  969. Legacy Example:
  970. >>> from torch import tensor
  971. >>> target = tensor([0, 1, 2, 0, 1, 2])
  972. >>> preds = tensor([0, 2, 1, 0, 0, 1])
  973. >>> f1 = F1Score(task="multiclass", num_classes=3)
  974. >>> f1(preds, target)
  975. tensor(0.3333)
  976. """
  977. def __new__( # type: ignore[misc]
  978. cls: type["F1Score"],
  979. task: Literal["binary", "multiclass", "multilabel"],
  980. threshold: float = 0.5,
  981. num_classes: Optional[int] = None,
  982. num_labels: Optional[int] = None,
  983. average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
  984. multidim_average: Optional[Literal["global", "samplewise"]] = "global",
  985. top_k: Optional[int] = 1,
  986. ignore_index: Optional[int] = None,
  987. validate_args: bool = True,
  988. zero_division: float = 0,
  989. **kwargs: Any,
  990. ) -> Metric:
  991. """Initialize task metric."""
  992. task = ClassificationTask.from_str(task)
  993. assert multidim_average is not None # noqa: S101 # needed for mypy
  994. kwargs.update({
  995. "multidim_average": multidim_average,
  996. "ignore_index": ignore_index,
  997. "validate_args": validate_args,
  998. "zero_division": zero_division,
  999. })
  1000. if task == ClassificationTask.BINARY:
  1001. return BinaryF1Score(threshold, **kwargs)
  1002. if task == ClassificationTask.MULTICLASS:
  1003. if not isinstance(num_classes, int):
  1004. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  1005. if not isinstance(top_k, int):
  1006. raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
  1007. return MulticlassF1Score(num_classes, top_k, average, **kwargs)
  1008. if task == ClassificationTask.MULTILABEL:
  1009. if not isinstance(num_labels, int):
  1010. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  1011. return MultilabelF1Score(num_labels, threshold, average, **kwargs)
  1012. raise ValueError(f"Task {task} not supported!")