cohen_kappa.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.classification.confusion_matrix import (
  19. _binary_confusion_matrix_arg_validation,
  20. _binary_confusion_matrix_format,
  21. _binary_confusion_matrix_tensor_validation,
  22. _binary_confusion_matrix_update,
  23. _multiclass_confusion_matrix_arg_validation,
  24. _multiclass_confusion_matrix_format,
  25. _multiclass_confusion_matrix_tensor_validation,
  26. _multiclass_confusion_matrix_update,
  27. )
  28. from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
  29. def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor:
  30. """Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the cohen kappa score."""
  31. confmat = confmat.float() if not confmat.is_floating_point() else confmat
  32. num_classes = confmat.shape[0]
  33. sum0 = confmat.sum(dim=0, keepdim=True)
  34. sum1 = confmat.sum(dim=1, keepdim=True)
  35. expected = sum1 @ sum0 / sum0.sum() # outer product
  36. if weights is None or weights == "none":
  37. w_mat = torch.ones_like(confmat).flatten()
  38. w_mat[:: num_classes + 1] = 0
  39. w_mat = w_mat.reshape(num_classes, num_classes)
  40. elif weights in ("linear", "quadratic"):
  41. w_mat = torch.zeros_like(confmat)
  42. w_mat += torch.arange(num_classes, dtype=w_mat.dtype, device=w_mat.device)
  43. w_mat = torch.abs(w_mat - w_mat.T) if weights == "linear" else torch.pow(w_mat - w_mat.T, 2.0)
  44. else:
  45. raise ValueError(
  46. f"Received {weights} for argument ``weights`` but should be either None, 'linear' or 'quadratic'"
  47. )
  48. k = torch.sum(w_mat * confmat) / torch.sum(w_mat * expected)
  49. return 1 - k
  50. def _binary_cohen_kappa_arg_validation(
  51. threshold: float = 0.5,
  52. ignore_index: Optional[int] = None,
  53. weights: Optional[Literal["linear", "quadratic", "none"]] = None,
  54. ) -> None:
  55. """Validate non tensor input.
  56. - ``threshold`` has to be a float in the [0,1] range
  57. - ``ignore_index`` has to be None or int
  58. - ``weights`` has to be "linear" | "quadratic" | "none" | None
  59. """
  60. _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize=None)
  61. allowed_weights = ("linear", "quadratic", "none", None)
  62. if weights not in allowed_weights:
  63. raise ValueError(f"Expected argument `weight` to be one of {allowed_weights}, but got {weights}.")
  64. def binary_cohen_kappa(
  65. preds: Tensor,
  66. target: Tensor,
  67. threshold: float = 0.5,
  68. weights: Optional[Literal["linear", "quadratic", "none"]] = None,
  69. ignore_index: Optional[int] = None,
  70. validate_args: bool = True,
  71. ) -> Tensor:
  72. r"""Calculate `Cohen's kappa score`_ that measures inter-annotator agreement for binary tasks.
  73. .. math::
  74. \kappa = (p_o - p_e) / (1 - p_e)
  75. where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is
  76. the expected agreement when both annotators assign labels randomly. Note that
  77. :math:`p_e` is estimated using a per-annotator empirical prior over the
  78. class labels.
  79. Accepts the following input tensors:
  80. - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
  81. [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
  82. we convert to int tensor with thresholding using the value in ``threshold``.
  83. - ``target`` (int tensor): ``(N, ...)``
  84. Additional dimension ``...`` will be flattened into the batch dimension.
  85. Args:
  86. preds: Tensor with predictions
  87. target: Tensor with true labels
  88. threshold: Threshold for transforming probability to binary (0,1) predictions
  89. weights: Weighting type to calculate the score. Choose from:
  90. - ``None`` or ``'none'``: no weighting
  91. - ``'linear'``: linear weighting
  92. - ``'quadratic'``: quadratic weighting
  93. ignore_index:
  94. Specifies a target value that is ignored and does not contribute to the metric calculation
  95. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  96. Set to ``False`` for faster computations.
  97. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  98. Example (preds is int tensor):
  99. >>> from torch import tensor
  100. >>> from torchmetrics.functional.classification import binary_cohen_kappa
  101. >>> target = tensor([1, 1, 0, 0])
  102. >>> preds = tensor([0, 1, 0, 0])
  103. >>> binary_cohen_kappa(preds, target)
  104. tensor(0.5000)
  105. Example (preds is float tensor):
  106. >>> from torchmetrics.functional.classification import binary_cohen_kappa
  107. >>> target = tensor([1, 1, 0, 0])
  108. >>> preds = tensor([0.35, 0.85, 0.48, 0.01])
  109. >>> binary_cohen_kappa(preds, target)
  110. tensor(0.5000)
  111. """
  112. if validate_args:
  113. _binary_cohen_kappa_arg_validation(threshold, ignore_index, weights)
  114. _binary_confusion_matrix_tensor_validation(preds, target, ignore_index)
  115. preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index)
  116. confmat = _binary_confusion_matrix_update(preds, target)
  117. return _cohen_kappa_reduce(confmat, weights)
  118. def _multiclass_cohen_kappa_arg_validation(
  119. num_classes: int,
  120. ignore_index: Optional[int] = None,
  121. weights: Optional[Literal["linear", "quadratic", "none"]] = None,
  122. ) -> None:
  123. """Validate non tensor input.
  124. - ``num_classes`` has to be a int larger than 1
  125. - ``ignore_index`` has to be None or int
  126. - ``weights`` has to be "linear" | "quadratic" | "none" | None
  127. """
  128. _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize=None)
  129. allowed_weights = ("linear", "quadratic", "none", None)
  130. if weights not in allowed_weights:
  131. raise ValueError(f"Expected argument `weight` to be one of {allowed_weights}, but got {weights}.")
  132. def multiclass_cohen_kappa(
  133. preds: Tensor,
  134. target: Tensor,
  135. num_classes: int,
  136. weights: Optional[Literal["linear", "quadratic", "none"]] = None,
  137. ignore_index: Optional[int] = None,
  138. validate_args: bool = True,
  139. ) -> Tensor:
  140. r"""Calculate `Cohen's kappa score`_ that measures inter-annotator agreement for multiclass tasks.
  141. .. math::
  142. \kappa = (p_o - p_e) / (1 - p_e)
  143. where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is
  144. the expected agreement when both annotators assign labels randomly. Note that
  145. :math:`p_e` is estimated using a per-annotator empirical prior over the
  146. class labels.
  147. Accepts the following input tensors:
  148. - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
  149. we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
  150. an int tensor.
  151. - ``target`` (int tensor): ``(N, ...)``
  152. Additional dimension ``...`` will be flattened into the batch dimension.
  153. Args:
  154. preds: Tensor with predictions
  155. target: Tensor with true labels
  156. num_classes: Integer specifying the number of classes
  157. weights: Weighting type to calculate the score. Choose from:
  158. - ``None`` or ``'none'``: no weighting
  159. - ``'linear'``: linear weighting
  160. - ``'quadratic'``: quadratic weighting
  161. ignore_index:
  162. Specifies a target value that is ignored and does not contribute to the metric calculation
  163. validate_args: bool indicating if input arguments and tensors should be validated for correctness.
  164. Set to ``False`` for faster computations.
  165. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  166. Example (pred is integer tensor):
  167. >>> from torch import tensor
  168. >>> from torchmetrics.functional.classification import multiclass_cohen_kappa
  169. >>> target = tensor([2, 1, 0, 0])
  170. >>> preds = tensor([2, 1, 0, 1])
  171. >>> multiclass_cohen_kappa(preds, target, num_classes=3)
  172. tensor(0.6364)
  173. Example (pred is float tensor):
  174. >>> from torchmetrics.functional.classification import multiclass_cohen_kappa
  175. >>> target = tensor([2, 1, 0, 0])
  176. >>> preds = tensor([[0.16, 0.26, 0.58],
  177. ... [0.22, 0.61, 0.17],
  178. ... [0.71, 0.09, 0.20],
  179. ... [0.05, 0.82, 0.13]])
  180. >>> multiclass_cohen_kappa(preds, target, num_classes=3)
  181. tensor(0.6364)
  182. """
  183. if validate_args:
  184. _multiclass_cohen_kappa_arg_validation(num_classes, ignore_index, weights)
  185. _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
  186. preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index)
  187. confmat = _multiclass_confusion_matrix_update(preds, target, num_classes)
  188. return _cohen_kappa_reduce(confmat, weights)
  189. def cohen_kappa(
  190. preds: Tensor,
  191. target: Tensor,
  192. task: Literal["binary", "multiclass"],
  193. threshold: float = 0.5,
  194. num_classes: Optional[int] = None,
  195. weights: Optional[Literal["linear", "quadratic", "none"]] = None,
  196. ignore_index: Optional[int] = None,
  197. validate_args: bool = True,
  198. ) -> Tensor:
  199. r"""Calculate `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as.
  200. .. math::
  201. \kappa = (p_o - p_e) / (1 - p_e)
  202. where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is
  203. the expected agreement when both annotators assign labels randomly. Note that
  204. :math:`p_e` is estimated using a per-annotator empirical prior over the
  205. class labels.
  206. This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
  207. ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of
  208. :func:`~torchmetrics.functional.classification.binary_cohen_kappa` and
  209. :func:`~torchmetrics.functional.classification.multiclass_cohen_kappa` for the specific details of
  210. each argument influence and examples.
  211. Legacy Example:
  212. >>> from torch import tensor
  213. >>> target = tensor([1, 1, 0, 0])
  214. >>> preds = tensor([0, 1, 0, 0])
  215. >>> cohen_kappa(preds, target, task="multiclass", num_classes=2)
  216. tensor(0.5000)
  217. """
  218. task = ClassificationTaskNoMultilabel.from_str(task)
  219. if task == ClassificationTaskNoMultilabel.BINARY:
  220. return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args)
  221. if task == ClassificationTaskNoMultilabel.MULTICLASS:
  222. if not isinstance(num_classes, int):
  223. raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
  224. return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args)
  225. raise ValueError(f"Not handled value: {task}")