eer.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.classification.base import _ClassificationTaskWrapper
  19. from torchmetrics.classification.roc import (
  20. BinaryROC,
  21. MulticlassROC,
  22. MultilabelROC,
  23. )
  24. from torchmetrics.functional.classification.eer import _eer_compute
  25. from torchmetrics.metric import Metric
  26. from torchmetrics.utilities.enums import ClassificationTask
  27. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  28. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  29. if not _MATPLOTLIB_AVAILABLE:
  30. __doctest_skip__ = ["BinaryEER.plot", "MulticlassEER.plot", "MultilabelEER.plot"]
  31. class BinaryEER(BinaryROC):
  32. r"""Compute Equal Error Rate (EER) for multiclass classification task.
  33. .. math::
  34. \text{EER} = \frac{\text{FAR} + \text{FRR}}{2}, \text{where} \min_t abs(FAR_t-FRR_t)
  35. The Equal Error Rate (EER) is the point where the False Positive Rate (FPR) and True Positive Rate (TPR) are
  36. equal, or in practise minimized. A lower EER value signifies higher system accuracy.
  37. As input to ``forward`` and ``update`` the metric accepts the following input:
  38. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
  39. each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
  40. sigmoid per element.
  41. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
  42. therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the
  43. positive class.
  44. As output to ``forward`` and ``compute`` the metric returns the following output:
  45. - ``b_eer`` (:class:`~torch.Tensor`): A single scalar with the eer score.
  46. Additional dimension ``...`` will be flattened into the batch dimension.
  47. The implementation both supports calculating the metric in a non-binned but accurate version and a
  48. binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will
  49. activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
  50. `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  51. size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
  52. Args:
  53. thresholds: Can be one of:
  54. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  55. all the data. Most accurate but also most memory consuming approach.
  56. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  57. 0 to 1 as bins for the calculation.
  58. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  59. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  60. bins for the calculation.
  61. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  62. Set to ``False`` for faster computations.
  63. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  64. Example:
  65. >>> from torch import tensor
  66. >>> from torchmetrics.classification import BinaryEER
  67. >>> preds = tensor([0, 0.5, 0.7, 0.8])
  68. >>> target = tensor([0, 1, 1, 0])
  69. >>> metric = BinaryEER(thresholds=None)
  70. >>> metric(preds, target)
  71. tensor(0.5000)
  72. >>> b_eer = BinaryEER(thresholds=5)
  73. >>> b_eer(preds, target)
  74. tensor(0.7500)
  75. """
  76. def compute(self) -> Tensor: # type: ignore[override]
  77. """Compute metric."""
  78. fpr, tpr, _ = super().compute()
  79. return _eer_compute(fpr, tpr)
  80. def plot( # type: ignore[override]
  81. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  82. ) -> _PLOT_OUT_TYPE:
  83. """Plot a single or multiple values from the metric.
  84. Args:
  85. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  86. If no value is provided, will automatically call `metric.compute` and plot that result.
  87. ax: An matplotlib axis object. If provided will add plot to that axis
  88. Returns:
  89. Figure and Axes object
  90. Raises:
  91. ModuleNotFoundError:
  92. If `matplotlib` is not installed
  93. .. plot::
  94. :scale: 75
  95. >>> # Example plotting a single
  96. >>> import torch
  97. >>> from torchmetrics.classification import BinaryEER
  98. >>> metric = BinaryEER()
  99. >>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
  100. >>> fig_, ax_ = metric.plot()
  101. .. plot::
  102. :scale: 75
  103. >>> # Example plotting multiple values
  104. >>> import torch
  105. >>> from torchmetrics.classification import BinaryEER
  106. >>> metric = BinaryEER()
  107. >>> values = [ ]
  108. >>> for _ in range(10):
  109. ... values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
  110. >>> fig_, ax_ = metric.plot(values)
  111. """
  112. return self._plot(val, ax)
  113. class MulticlassEER(MulticlassROC):
  114. r"""Compute Equal Error Rate (EER) for multiclass classification task.
  115. .. math::
  116. \text{EER} = \frac{\text{FAR} + (1 - \text{FRR})}{2}, \text{where} \min_t abs(FAR_t-FRR_t)
  117. The Equal Error Rate (EER) is the point where the False Positive Rate (FPR) and True Positive Rate (TPR) are
  118. equal, or in practise minimized. A lower EER value signifies higher system accuracy.
  119. As input to ``forward`` and ``update`` the metric accepts the following input:
  120. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
  121. for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto
  122. apply softmax per sample.
  123. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
  124. therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
  125. As output to ``forward`` and ``compute`` the metric returns the following output:
  126. - ``mc_eer`` (:class:`~torch.Tensor`): If `average=None` then a 1d tensor of shape (n_classes, ) will
  127. be returned with eer score per class. If `average="macro"|"micro"` then a single scalar will be returned.
  128. Additional dimension ``...`` will be flattened into the batch dimension.
  129. The implementation both supports calculating the metric in a non-binned but accurate version and a
  130. binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will
  131. activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
  132. `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  133. size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
  134. Args:
  135. num_classes: Integer specifying the number of classes
  136. thresholds: Can be one of:
  137. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  138. all the data. Most accurate but also most memory consuming approach.
  139. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  140. 0 to 1 as bins for the calculation.
  141. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  142. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  143. bins for the calculation.
  144. average:
  145. If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for
  146. each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot
  147. encoding the targets and flattening the predictions, considering all classes jointly as a binary problem.
  148. If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves
  149. from each class at a combined set of thresholds and then average over the classwise interpolated curves.
  150. See `averaging curve objects`_ for more info on the different averaging methods.
  151. ignore_index:
  152. Specifies a target value that is ignored and does not contribute to the metric calculation
  153. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  154. Set to ``False`` for faster computations.
  155. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  156. Examples:
  157. >>> from torch import tensor
  158. >>> from torchmetrics.classification import MulticlassEER
  159. >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  160. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  161. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  162. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  163. >>> target = tensor([0, 1, 3, 2])
  164. >>> metric = MulticlassEER(num_classes=5, average="macro", thresholds=None)
  165. >>> metric(preds, target)
  166. tensor(0.4667)
  167. >>> mc_eer = MulticlassEER(num_classes=5, average=None, thresholds=None)
  168. >>> mc_eer(preds, target)
  169. tensor([0.0000, 0.0000, 0.6667, 0.6667, 1.0000])
  170. """
  171. def compute(self) -> Tensor: # type: ignore[override]
  172. """Compute metric."""
  173. fpr, tpr, _ = super().compute()
  174. return _eer_compute(fpr, tpr)
  175. def plot( # type: ignore[override]
  176. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  177. ) -> _PLOT_OUT_TYPE:
  178. """Plot a single or multiple values from the metric.
  179. Args:
  180. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  181. If no value is provided, will automatically call `metric.compute` and plot that result.
  182. ax: An matplotlib axis object. If provided will add plot to that axis
  183. Returns:
  184. Figure and Axes object
  185. Raises:
  186. ModuleNotFoundError:
  187. If `matplotlib` is not installed
  188. .. plot::
  189. :scale: 75
  190. >>> # Example plotting a single
  191. >>> import torch
  192. >>> from torchmetrics.classification import MulticlassEER
  193. >>> metric = MulticlassEER(num_classes=3)
  194. >>> metric.update(torch.randn(20, 3), torch.randint(3,(20,)))
  195. >>> fig_, ax_ = metric.plot()
  196. .. plot::
  197. :scale: 75
  198. >>> # Example plotting multiple values
  199. >>> import torch
  200. >>> from torchmetrics.classification import MulticlassEER
  201. >>> metric = MulticlassEER(num_classes=3)
  202. >>> values = [ ]
  203. >>> for _ in range(10):
  204. ... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
  205. >>> fig_, ax_ = metric.plot(values)
  206. """
  207. return self._plot(val, ax)
  208. class MultilabelEER(MultilabelROC):
  209. r"""Compute Equal Error Rate (EER) for multiclass classification task.
  210. .. math::
  211. \text{EER} = \frac{\text{FAR} + (1 - \text{FRR})}{2}, \text{where} \min_t abs(FAR_t-FRR_t)
  212. The Equal Error Rate (EER) is the point where the False Positive Rate (FPR) and True Positive Rate (TPR) are
  213. equal, or in practise minimized. A lower EER value signifies higher system accuracy.
  214. As input to ``forward`` and ``update`` the metric accepts the following input:
  215. - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits
  216. for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto
  217. apply sigmoid per element.
  218. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` containing ground truth labels, and
  219. therefore only contain {0,1} values (except if `ignore_index` is specified).
  220. As output to ``forward`` and ``compute`` the metric returns the following output:
  221. - ``ml_eer`` (:class:`~torch.Tensor`): A 1d tensor of shape (n_classes, ) will be returned with eer score per label.
  222. Additional dimension ``...`` will be flattened into the batch dimension.
  223. The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
  224. that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
  225. non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
  226. argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
  227. size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
  228. Args:
  229. num_labels: Integer specifying the number of labels
  230. average: Defines the reduction that is applied over labels. Should be one of the following:
  231. - ``micro``: Sum score over all labels
  232. - ``macro``: Calculate score for each label and average them
  233. - ``weighted``: calculates score for each label and computes weighted average using their support
  234. - ``"none"`` or ``None``: calculates score for each label and applies no reduction
  235. thresholds: Can be one of:
  236. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  237. all the data. Most accurate but also most memory consuming approach.
  238. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  239. 0 to 1 as bins for the calculation.
  240. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  241. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  242. bins for the calculation.
  243. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  244. Set to ``False`` for faster computations.
  245. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  246. Example:
  247. >>> from torch import tensor
  248. >>> from torchmetrics.classification import MultilabelEER
  249. >>> preds = tensor([[0.75, 0.05, 0.35],
  250. ... [0.45, 0.75, 0.05],
  251. ... [0.05, 0.55, 0.75],
  252. ... [0.05, 0.65, 0.05]])
  253. >>> target = tensor([[1, 0, 1],
  254. ... [0, 0, 0],
  255. ... [0, 1, 1],
  256. ... [1, 1, 1]])
  257. >>> ml_eer = MultilabelEER(num_labels=3, thresholds=None)
  258. >>> ml_eer(preds, target)
  259. tensor([0.5000, 0.5000, 0.1667])
  260. """
  261. def compute(self) -> Tensor: # type: ignore[override]
  262. """Compute metric."""
  263. fpr, tpr, _ = super().compute()
  264. return _eer_compute(fpr, tpr)
  265. def plot( # type: ignore[override]
  266. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  267. ) -> _PLOT_OUT_TYPE:
  268. """Plot a single or multiple values from the metric.
  269. Args:
  270. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  271. If no value is provided, will automatically call `metric.compute` and plot that result.
  272. ax: An matplotlib axis object. If provided will add plot to that axis
  273. Returns:
  274. Figure and Axes object
  275. Raises:
  276. ModuleNotFoundError:
  277. If `matplotlib` is not installed
  278. .. plot::
  279. :scale: 75
  280. >>> # Example plotting a single
  281. >>> import torch
  282. >>> from torchmetrics.classification import MultilabelEER
  283. >>> metric = MultilabelEER(num_labels=3)
  284. >>> metric.update(torch.rand(20,3), torch.randint(2, (20,3)))
  285. >>> fig_, ax_ = metric.plot()
  286. .. plot::
  287. :scale: 75
  288. >>> # Example plotting multiple values
  289. >>> import torch
  290. >>> from torchmetrics.classification import MultilabelEER
  291. >>> metric = MultilabelEER(num_labels=3)
  292. >>> values = [ ]
  293. >>> for _ in range(10):
  294. ... values.append(metric(torch.rand(20,3), torch.randint(2, (20,3))))
  295. >>> fig_, ax_ = metric.plot(values)
  296. """
  297. return self._plot(val, ax)
  298. class EER(_ClassificationTaskWrapper):
  299. r"""Compute Equal Error Rate (EER) for multiclass classification task.
  300. .. math::
  301. \text{EER} = \frac{\text{FAR} + (1 - \text{FRR})}{2}, \text{where} \min_t abs(FAR_t-FRR_t)
  302. The Equal Error Rate (EER) is the point where the False Positive Rate (FPR) and True Positive Rate (TPR) are
  303. equal, or in practise minimized. A lower EER value signifies higher system accuracy.
  304. This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  305. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  306. :class:`~torchmetrics.classification.BinaryEER`, :class:`~torchmetrics.classification.MulticlassEER` and
  307. :class:`~torchmetrics.classification.MultilabelEER` for the specific details of each argument influence and
  308. examples.
  309. Legacy Example:
  310. >>> from torch import tensor
  311. >>> preds = tensor([0.13, 0.26, 0.08, 0.19, 0.34])
  312. >>> target = tensor([0, 0, 1, 1, 1])
  313. >>> eer = EER(task="binary")
  314. >>> eer(preds, target)
  315. tensor(0.5833)
  316. >>> preds = tensor([[0.90, 0.05, 0.05],
  317. ... [0.05, 0.90, 0.05],
  318. ... [0.05, 0.05, 0.90],
  319. ... [0.85, 0.05, 0.10],
  320. ... [0.10, 0.10, 0.80]])
  321. >>> target = tensor([0, 1, 1, 2, 2])
  322. >>> eer = EER(task="multiclass", num_classes=3)
  323. >>> eer(preds, target)
  324. tensor([0.0000, 0.4167, 0.4167])
  325. """
  326. def __new__( # type: ignore[misc]
  327. cls: type["EER"],
  328. task: Literal["binary", "multiclass", "multilabel"],
  329. thresholds: Optional[Union[int, list[float], Tensor]] = None,
  330. num_classes: Optional[int] = None,
  331. num_labels: Optional[int] = None,
  332. average: Optional[Literal["macro", "micro"]] = None,
  333. ignore_index: Optional[int] = None,
  334. validate_args: bool = True,
  335. **kwargs: Any,
  336. ) -> Metric:
  337. """Initialize task metric."""
  338. task = ClassificationTask.from_str(task)
  339. kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
  340. if task == ClassificationTask.BINARY:
  341. return BinaryEER(**kwargs)
  342. if task == ClassificationTask.MULTICLASS:
  343. if not isinstance(num_classes, int):
  344. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  345. return MulticlassEER(num_classes, average=average, **kwargs)
  346. if task == ClassificationTask.MULTILABEL:
  347. if not isinstance(num_labels, int):
  348. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  349. return MultilabelEER(num_labels, **kwargs)
  350. raise ValueError(f"Task {task} not supported!")
  351. def update(self, *args: Any, **kwargs: Any) -> None:
  352. """Update metric state."""
  353. raise NotImplementedError(
  354. f"{self.__class__.__name__} metric does not have a global `update` method. Use the task specific metric."
  355. )
  356. def compute(self) -> None:
  357. """Compute metric."""
  358. raise NotImplementedError(
  359. f"{self.__class__.__name__} metric does not have a global `compute` method. Use the task specific metric."
  360. )