csi.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, List, Optional
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.functional.regression.csi import _critical_success_index_compute, _critical_success_index_update
  18. from torchmetrics.metric import Metric
  19. from torchmetrics.utilities import dim_zero_cat
  20. class CriticalSuccessIndex(Metric):
  21. r"""Calculate critical success index (CSI).
  22. Critical success index (also known as the threat score) is a statistic used weather forecasting that measures
  23. forecast performance over inputs binarized at a specified threshold. It is defined as:
  24. .. math:: \text{CSI} = \frac{\text{TP}}{\text{TP}+\text{FN}+\text{FP}}
  25. Where :math:`\text{TP}`, :math:`\text{FN}` and :math:`\text{FP}` represent the number of true positives, false
  26. negatives and false positives respectively after binarizing the input tensors.
  27. Args:
  28. threshold: Values above or equal to threshold are replaced with 1, below by 0
  29. keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
  30. the score will be calculated separately for each image in the sequence. If ``None``, the score will be
  31. calculated across all dimensions.
  32. Example:
  33. >>> import torch
  34. >>> from torchmetrics.regression import CriticalSuccessIndex
  35. >>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
  36. >>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
  37. >>> csi = CriticalSuccessIndex(0.5)
  38. >>> csi(x, y)
  39. tensor(0.3333)
  40. Example:
  41. >>> import torch
  42. >>> from torchmetrics.regression import CriticalSuccessIndex
  43. >>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
  44. >>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
  45. >>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0)
  46. >>> csi(x, y)
  47. tensor([0.3333, 0.3333])
  48. """
  49. is_differentiable: bool = False
  50. higher_is_better: bool = True
  51. hits: Tensor
  52. misses: Tensor
  53. false_alarms: Tensor
  54. hits_list: List[Tensor]
  55. misses_list: List[Tensor]
  56. false_alarms_list: List[Tensor]
  57. def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None:
  58. super().__init__(**kwargs)
  59. self.threshold = float(threshold)
  60. if keep_sequence_dim and (not isinstance(keep_sequence_dim, int) or keep_sequence_dim < 0):
  61. raise ValueError(f"Expected keep_sequence_dim to be a non-negative integer but got {keep_sequence_dim}")
  62. self.keep_sequence_dim = keep_sequence_dim
  63. if keep_sequence_dim is None:
  64. self.add_state("hits", default=torch.tensor(0), dist_reduce_fx="sum")
  65. self.add_state("misses", default=torch.tensor(0), dist_reduce_fx="sum")
  66. self.add_state("false_alarms", default=torch.tensor(0), dist_reduce_fx="sum")
  67. else:
  68. self.add_state("hits_list", default=[], dist_reduce_fx="cat")
  69. self.add_state("misses_list", default=[], dist_reduce_fx="cat")
  70. self.add_state("false_alarms_list", default=[], dist_reduce_fx="cat")
  71. def update(self, preds: Tensor, target: Tensor) -> None:
  72. """Update state with predictions and targets."""
  73. hits, misses, false_alarms = _critical_success_index_update(
  74. preds, target, self.threshold, self.keep_sequence_dim
  75. )
  76. if self.keep_sequence_dim is None:
  77. self.hits += hits
  78. self.misses += misses
  79. self.false_alarms += false_alarms
  80. else:
  81. self.hits_list.append(hits)
  82. self.misses_list.append(misses)
  83. self.false_alarms_list.append(false_alarms)
  84. def compute(self) -> Tensor:
  85. """Compute critical success index over state."""
  86. if self.keep_sequence_dim is None:
  87. hits = self.hits
  88. misses = self.misses
  89. false_alarms = self.false_alarms
  90. else:
  91. hits = dim_zero_cat(self.hits_list)
  92. misses = dim_zero_cat(self.misses_list)
  93. false_alarms = dim_zero_cat(self.false_alarms_list)
  94. return _critical_success_index_compute(hits, misses, false_alarms)