csi.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.utilities.checks import _check_same_shape
  18. from torchmetrics.utilities.compute import _safe_divide
  19. def _critical_success_index_update(
  20. preds: Tensor, target: Tensor, threshold: float, keep_sequence_dim: Optional[int] = None
  21. ) -> tuple[Tensor, Tensor, Tensor]:
  22. """Update and return variables required to compute Critical Success Index. Checks for same shape of tensors.
  23. Args:
  24. preds: Predicted tensor
  25. target: Ground truth tensor
  26. threshold: Values above or equal to threshold are replaced with 1, below by 0
  27. keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
  28. the score will be calculated separately for each image in the sequence. If ``None``, the score will be
  29. calculated across all dimensions.
  30. """
  31. _check_same_shape(preds, target)
  32. if keep_sequence_dim is None:
  33. sum_dims = None
  34. elif not 0 <= keep_sequence_dim < preds.ndim:
  35. raise ValueError(f"Expected keep_sequence dim to be in range [0, {preds.ndim}] but got {keep_sequence_dim}")
  36. else:
  37. sum_dims = tuple(i for i in range(preds.ndim) if i != keep_sequence_dim)
  38. # binarize the tensors with the threshold
  39. preds_bin = (preds >= threshold).bool()
  40. target_bin = (target >= threshold).bool()
  41. if keep_sequence_dim is None:
  42. hits = torch.sum(preds_bin & target_bin).int()
  43. misses = torch.sum((preds_bin ^ target_bin) & target_bin).int()
  44. false_alarms = torch.sum((preds_bin ^ target_bin) & preds_bin).int()
  45. else:
  46. hits = torch.sum(preds_bin & target_bin, dim=sum_dims).int()
  47. misses = torch.sum((preds_bin ^ target_bin) & target_bin, dim=sum_dims).int()
  48. false_alarms = torch.sum((preds_bin ^ target_bin) & preds_bin, dim=sum_dims).int()
  49. return hits, misses, false_alarms
  50. def _critical_success_index_compute(hits: Tensor, misses: Tensor, false_alarms: Tensor) -> Tensor:
  51. """Compute critical success index.
  52. Args:
  53. hits: Number of true positives after binarization
  54. misses: Number of false negatives after binarization
  55. false_alarms: Number of false positives after binarization
  56. Returns:
  57. If input tensors are 5-dimensional and ``keep_sequence_dim=True``, the metric returns a ``(S,)`` vector
  58. with CSI scores for each image in the sequence. Otherwise, it returns a scalar tensor with the CSI score.
  59. """
  60. return _safe_divide(hits, hits + misses + false_alarms)
  61. def critical_success_index(
  62. preds: Tensor, target: Tensor, threshold: float, keep_sequence_dim: Optional[int] = None
  63. ) -> Tensor:
  64. """Compute critical success index.
  65. Args:
  66. preds: Predicted tensor
  67. target: Ground truth tensor
  68. threshold: Values above or equal to threshold are replaced with 1, below by 0
  69. keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
  70. the score will be calculated separately for each image in the sequence. If ``None``, the score will be
  71. calculated across all dimensions.
  72. Returns:
  73. If ``keep_sequence_dim`` is specified, the metric returns a vector of with CSI scores for each image
  74. in the sequence. Otherwise, it returns a scalar tensor with the CSI score.
  75. Example:
  76. >>> import torch
  77. >>> from torchmetrics.functional.regression import critical_success_index
  78. >>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
  79. >>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
  80. >>> critical_success_index(x, y, 0.5)
  81. tensor(0.3333)
  82. Example:
  83. >>> import torch
  84. >>> from torchmetrics.functional.regression import critical_success_index
  85. >>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
  86. >>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
  87. >>> critical_success_index(x, y, 0.5, keep_sequence_dim=0)
  88. tensor([0.3333, 0.3333])
  89. """
  90. hits, misses, false_alarms = _critical_success_index_update(preds, target, threshold, keep_sequence_dim)
  91. return _critical_success_index_compute(hits, misses, false_alarms)