| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any, List, Optional
- import torch
- from torch import Tensor
- from torchmetrics.functional.regression.csi import _critical_success_index_compute, _critical_success_index_update
- from torchmetrics.metric import Metric
- from torchmetrics.utilities import dim_zero_cat
class CriticalSuccessIndex(Metric):
    r"""Calculate critical success index (CSI).

    Critical success index (also known as the threat score) is a statistic used in weather forecasting that
    measures forecast performance over inputs binarized at a specified threshold. It is defined as:

    .. math:: \text{CSI} = \frac{\text{TP}}{\text{TP}+\text{FN}+\text{FP}}

    Where :math:`\text{TP}`, :math:`\text{FN}` and :math:`\text{FP}` represent the number of true positives, false
    negatives and false positives respectively after binarizing the input tensors.

    Args:
        threshold: Values above or equal to threshold are replaced with 1, below by 0
        keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
            the score will be calculated separately for each image in the sequence. If ``None``, the score will be
            calculated across all dimensions. When specified, the per-image counts from every ``update`` call are
            concatenated along dimension 0, so ``compute`` returns one score per image accumulated so far.
        kwargs: Additional keyword arguments forwarded to the :class:`~torchmetrics.Metric` base class constructor.

    Raises:
        ValueError:
            If ``keep_sequence_dim`` is neither ``None`` nor a non-negative integer.

    Example:
        >>> import torch
        >>> from torchmetrics.regression import CriticalSuccessIndex
        >>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
        >>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
        >>> csi = CriticalSuccessIndex(0.5)
        >>> csi(x, y)
        tensor(0.3333)

    Example:
        >>> import torch
        >>> from torchmetrics.regression import CriticalSuccessIndex
        >>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
        >>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
        >>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0)
        >>> csi(x, y)
        tensor([0.3333, 0.3333])

    """

    is_differentiable: bool = False
    higher_is_better: bool = True

    # Running totals, registered only when ``keep_sequence_dim`` is ``None``.
    hits: Tensor
    misses: Tensor
    false_alarms: Tensor

    # Per-update count tensors, registered only when ``keep_sequence_dim`` is set;
    # they are concatenated along dimension 0 in ``compute``.
    hits_list: List[Tensor]
    misses_list: List[Tensor]
    false_alarms_list: List[Tensor]

    def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        self.threshold = float(threshold)

        # Compare against ``None`` explicitly: a plain truthiness test would let falsy non-integers
        # (e.g. ``0.0`` or ``""``) bypass validation. ``0`` is still accepted as a valid index.
        # The upper bound (relative to the input rank) is not checked here.
        if keep_sequence_dim is not None and (not isinstance(keep_sequence_dim, int) or keep_sequence_dim < 0):
            raise ValueError(f"Expected keep_sequence_dim to be a non-negative integer but got {keep_sequence_dim}")
        self.keep_sequence_dim = keep_sequence_dim

        if keep_sequence_dim is None:
            # Scalar counts, summed across updates and across processes on sync.
            self.add_state("hits", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("misses", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("false_alarms", default=torch.tensor(0), dist_reduce_fx="sum")
        else:
            # One tensor per update; concatenated across processes on sync.
            self.add_state("hits_list", default=[], dist_reduce_fx="cat")
            self.add_state("misses_list", default=[], dist_reduce_fx="cat")
            self.add_state("false_alarms_list", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Args:
            preds: Predicted values; binarized at ``self.threshold`` by the functional helper.
            target: Ground-truth values; binarized at ``self.threshold`` by the functional helper.

        """
        hits, misses, false_alarms = _critical_success_index_update(
            preds, target, self.threshold, self.keep_sequence_dim
        )
        if self.keep_sequence_dim is None:
            self.hits += hits
            self.misses += misses
            self.false_alarms += false_alarms
        else:
            self.hits_list.append(hits)
            self.misses_list.append(misses)
            self.false_alarms_list.append(false_alarms)

    def compute(self) -> Tensor:
        """Compute critical success index over state.

        Returns:
            A scalar score when ``keep_sequence_dim`` is ``None``; otherwise the per-update counts are
            concatenated along dimension 0 and one score is returned per accumulated entry.

        """
        if self.keep_sequence_dim is None:
            hits = self.hits
            misses = self.misses
            false_alarms = self.false_alarms
        else:
            hits = dim_zero_cat(self.hits_list)
            misses = dim_zero_cat(self.misses_list)
            false_alarms = dim_zero_cat(self.false_alarms_list)
        return _critical_success_index_compute(hits, misses, false_alarms)
|