eer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional, Union
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.classification.roc import (
  19. binary_roc,
  20. multiclass_roc,
  21. multilabel_roc,
  22. )
  23. from torchmetrics.utilities.enums import ClassificationTask
  24. def _binary_eer_compute(fpr: Tensor, tpr: Tensor) -> Tensor:
  25. """Compute Equal Error Rate (EER) for binary classification task."""
  26. diff = fpr - (1 - tpr)
  27. idx = torch.argmin(torch.abs(diff))
  28. return (fpr[idx] + (1 - tpr[idx])) / 2
  29. def _eer_compute(
  30. fpr: Union[Tensor, List[Tensor]],
  31. tpr: Union[Tensor, List[Tensor]],
  32. ) -> Tensor:
  33. """Compute Equal Error Rate (EER)."""
  34. if isinstance(fpr, Tensor) and isinstance(tpr, Tensor) and fpr.ndim == 1:
  35. return _binary_eer_compute(fpr, tpr)
  36. return torch.stack([_binary_eer_compute(f, t) for f, t in zip(fpr, tpr)])
  37. def binary_eer(
  38. preds: Tensor,
  39. target: Tensor,
  40. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  41. ignore_index: Optional[int] = None,
  42. validate_args: bool = True,
  43. ) -> Tensor:
  44. r"""Compute Equal Error Rate (EER) for binary classification task.
  45. .. math::
  46. \text{EER} = \frac{\text{FAR} + \text{FRR}}{2}, \text{where} \min_t abs(FAR_t-FRR_t)
  47. The Equal Error Rate (EER) is the point where the False Positive Rate (FPR) and True Positive Rate (TPR) are
  48. equal, or in practise minimized. A lower EER value signifies higher system accuracy.
  49. Args:
  50. preds: Tensor with predictions
  51. target: Tensor with true labels
  52. thresholds:
  53. Can be one of:
  54. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  55. all the data. Most accurate but also most memory consuming approach.
  56. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  57. 0 to 1 as bins for the calculation.
  58. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  59. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  60. bins for the calculation.
  61. ignore_index:
  62. Specifies a target value that is ignored and does not contribute to the metric calculation
  63. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  64. Set to ``False`` for faster computations
  65. Returns:
  66. A single scalar with the eer score
  67. Example:
  68. >>> from torchmetrics.functional.classification import binary_eer
  69. >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
  70. >>> target = torch.tensor([0, 1, 1, 0])
  71. >>> binary_eer(preds, target, thresholds=None)
  72. tensor(0.5000)
  73. >>> binary_eer(preds, target, thresholds=5)
  74. tensor(0.7500)
  75. """
  76. fpr, tpr, _ = binary_roc(preds, target, thresholds, ignore_index, validate_args)
  77. return _eer_compute(fpr, tpr)
  78. def multiclass_eer(
  79. preds: Tensor,
  80. target: Tensor,
  81. num_classes: int,
  82. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  83. average: Optional[Literal["micro", "macro"]] = None,
  84. ignore_index: Optional[int] = None,
  85. validate_args: bool = True,
  86. ) -> Tensor:
  87. r"""Compute Equal Error Rate (EER) for multiclass classification task.
  88. .. math::
  89. \text{EER} = \frac{\text{FAR} + (1 - \text{FRR})}{2}, \text{where} \min_t abs(FAR_t-FRR_t)
  90. The Equal Error Rate (EER) is the point where the False Positive Rate (FPR) and True Positive Rate (TPR) are
  91. equal, or in practise minimized. A lower EER value signifies higher system accuracy.
  92. Args:
  93. preds: Tensor with predictions
  94. target: Tensor with true labels
  95. num_classes: Integer specifying the number of classes
  96. thresholds:
  97. Can be one of:
  98. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  99. all the data. Most accurate but also most memory consuming approach.
  100. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  101. 0 to 1 as bins for the calculation.
  102. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  103. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  104. bins for the calculation.
  105. average:
  106. If aggregation of should be applied. The aggregation is applied to underlying ROC curves.
  107. By default, eer is not aggregated and a score for each class is returned. If `average` is set to ``"micro"``
  108. , the metric will aggregate the curves by one hot encoding the targets and flattening the predictions,
  109. considering all classes jointly as a binary problem. If `average` is set to ``"macro"``, the metric will
  110. aggregate the curves by first interpolating the curves from each class at a combined set of thresholds and
  111. then average over the classwise interpolated curves. See `averaging curve objects`_ for more info on the
  112. different averaging methods.
  113. ignore_index:
  114. Specifies a target value that is ignored and does not contribute to the metric calculation
  115. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  116. Set to ``False`` for faster computations.
  117. Returns:
  118. If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with eer score per class.
  119. If `average="macro"|"micro"` then a single scalar is returned.
  120. Example:
  121. >>> from torchmetrics.functional.classification import multiclass_eer
  122. >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
  123. ... [0.05, 0.75, 0.05, 0.05, 0.05],
  124. ... [0.05, 0.05, 0.75, 0.05, 0.05],
  125. ... [0.05, 0.05, 0.05, 0.75, 0.05]])
  126. >>> target = torch.tensor([0, 1, 3, 2])
  127. >>> multiclass_eer(preds, target, num_classes=5, average="macro", thresholds=None)
  128. tensor(0.4667)
  129. >>> multiclass_eer(preds, target, num_classes=5, average=None, thresholds=None)
  130. tensor([0.0000, 0.0000, 0.6667, 0.6667, 1.0000])
  131. >>> multiclass_eer(preds, target, num_classes=5, average="macro", thresholds=5)
  132. tensor(0.4667)
  133. >>> multiclass_eer(preds, target, num_classes=5, average=None, thresholds=5)
  134. tensor([0.0000, 0.0000, 0.6667, 0.6667, 1.0000])
  135. """
  136. fpr, tpr, _ = multiclass_roc(preds, target, num_classes, thresholds, average, ignore_index, validate_args)
  137. return _eer_compute(fpr, tpr)
  138. def multilabel_eer(
  139. preds: Tensor,
  140. target: Tensor,
  141. num_labels: int,
  142. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  143. ignore_index: Optional[int] = None,
  144. validate_args: bool = True,
  145. ) -> Tensor:
  146. r"""Compute Equal Error Rate (EER) for multilabel classification task.
  147. .. math::
  148. \text{EER} = \frac{\text{FAR} + (1 - \text{FRR})}{2}, \text{where} \min_t abs(FAR_t-FRR_t)
  149. The Equal Error Rate (EER) is the point where the False Positive Rate (FPR) and True Positive Rate (TPR) are
  150. equal, or in practise minimized. A lower EER value signifies higher system accuracy.
  151. Args:
  152. preds: Tensor with predictions
  153. target: Tensor with true labels
  154. num_labels: Integer specifying the number of labels
  155. thresholds:
  156. Can be one of:
  157. - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
  158. all the data. Most accurate but also most memory consuming approach.
  159. - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
  160. 0 to 1 as bins for the calculation.
  161. - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
  162. - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
  163. bins for the calculation.
  164. ignore_index:
  165. Specifies a target value that is ignored and does not contribute to the metric calculation
  166. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  167. Set to ``False`` for faster computations.
  168. Returns:
  169. A 1d tensor of shape (n_classes, ) will be returned with eer score per label.
  170. Example:
  171. >>> from torchmetrics.functional.classification import multilabel_eer
  172. >>> preds = torch.tensor([[0.75, 0.05, 0.35],
  173. ... [0.45, 0.75, 0.05],
  174. ... [0.05, 0.55, 0.75],
  175. ... [0.05, 0.65, 0.05]])
  176. >>> target = torch.tensor([[1, 0, 1],
  177. ... [0, 0, 0],
  178. ... [0, 1, 1],
  179. ... [1, 1, 1]])
  180. >>> multilabel_eer(preds, target, num_labels=3, thresholds=None)
  181. tensor([0.5000, 0.5000, 0.1667])
  182. >>> multilabel_eer(preds, target, num_labels=3, thresholds=5)
  183. tensor([0.5000, 0.7500, 0.1667])
  184. """
  185. fpr, tpr, _ = multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args)
  186. return _eer_compute(fpr, tpr)
  187. def eer(
  188. preds: Tensor,
  189. target: Tensor,
  190. task: Literal["binary", "multiclass", "multilabel"],
  191. thresholds: Optional[Union[int, List[float], Tensor]] = None,
  192. num_classes: Optional[int] = None,
  193. num_labels: Optional[int] = None,
  194. average: Optional[Literal["micro", "macro"]] = None,
  195. ignore_index: Optional[int] = None,
  196. validate_args: bool = True,
  197. ) -> Union[Tensor, List[Tensor]]:
  198. """Compute Equal Error Rate (EER) metric.
  199. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  200. ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
  201. :func:`~torchmetrics.functional.classification.binary_eer`,
  202. :func:`~torchmetrics.functional.classification.multiclass_eer` and
  203. :func:`~torchmetrics.functional.classification.multilabel_eer` for the specific details of
  204. each argument influence and examples.
  205. Args:
  206. preds: Predictions from model (logits or probabilities)
  207. target: Ground truth labels
  208. task: Type of task, either 'binary', 'multiclass' or 'multilabel'
  209. thresholds: Thresholds used for computing the ROC curve
  210. num_classes: Number of classes (for multiclass task)
  211. num_labels: Number of labels (for multilabel task)
  212. average: Method to average EER over multiple classes/labels
  213. ignore_index: Specify a target value that is ignored
  214. validate_args: Bool indicating whether to validate input arguments
  215. Legacy Example:
  216. >>> from torchmetrics.functional.classification import eer
  217. >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
  218. >>> target = torch.tensor([0, 0, 1, 1, 1])
  219. >>> eer(preds, target, task='binary')
  220. tensor(0.5833)
  221. >>> preds = torch.tensor([[0.90, 0.05, 0.05],
  222. ... [0.05, 0.90, 0.05],
  223. ... [0.05, 0.05, 0.90],
  224. ... [0.85, 0.05, 0.10],
  225. ... [0.10, 0.10, 0.80]])
  226. >>> target = torch.tensor([0, 1, 1, 2, 2])
  227. >>> eer(preds, target, task='multiclass', num_classes=3, )
  228. tensor([0.0000, 0.4167, 0.4167])
  229. """
  230. task = ClassificationTask.from_str(task)
  231. if task == ClassificationTask.BINARY:
  232. return binary_eer(preds, target, thresholds, ignore_index, validate_args)
  233. if task == ClassificationTask.MULTICLASS:
  234. if not isinstance(num_classes, int):
  235. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  236. return multiclass_eer(preds, target, num_classes, thresholds, average, ignore_index, validate_args)
  237. if task == ClassificationTask.MULTILABEL:
  238. if not isinstance(num_labels, int):
  239. raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
  240. return multilabel_eer(preds, target, num_labels, thresholds, ignore_index, validate_args)
  241. raise ValueError(f"Task {task} not supported.")