| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Tuple
- import torch
- from torch import Tensor
- from torchmetrics.utilities.checks import _check_same_shape
- def _crps_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor]:
- """Compute intermediate CRPS values before aggregation.
- Args:
- preds: Tensor of shape (batch_size, ensemble_members)
- target: Tensor of shape (batch_size,)
- Returns:
- batch_size: int
- diff: Tensor (batch-wise absolute error term)
- ensemble_sum: Tensor (pairwise ensemble term)
- """
- # Only second dimension should deviate in shape (the ensemble members)
- _check_same_shape(preds[:, 0], target)
- batch_size, n_ensemble_members = preds.shape
- if n_ensemble_members < 2:
- raise ValueError(f"CRPS requires at least 2 ensemble members, but you provided {preds.shape}.")
- # sort forecasts
- preds = torch.sort(preds, dim=1)[0]
- # inflate observations:
- observation_inflated = target.unsqueeze(1).expand_as(preds)
- # Compute mean absolute difference between predictions and target
- diff = torch.sum(torch.abs(preds - observation_inflated), dim=1) / n_ensemble_members
- # Compute ensemble term using the reference implementation formula
- ensemble_diffs = torch.abs(preds.unsqueeze(2) - preds.unsqueeze(1))
- ensemble_sum = torch.sum(ensemble_diffs, dim=(1, 2)) / (2 * n_ensemble_members * n_ensemble_members)
- return batch_size, diff, ensemble_sum
- def _crps_compute(batch_size: int, diff: Tensor, ensemble_sum: Tensor) -> Tensor:
- """Final CRPS computation."""
- return torch.mean(diff - ensemble_sum) # Changed from sum to mean
- def continuous_ranked_probability_score(preds: Tensor, target: Tensor) -> Tensor:
- r"""Computes continuous ranked probability score.
- .. math::
- CRPS(F, y) = \int_{-\infty}^{\infty} (F(x) - 1_{x \geq y})^2 dx
- where :math:`F` is the predicted cumulative distribution function and :math:`y` is the true target. The metric is
- usually used to evaluate probabilistic regression models, such as forecasting models. A lower CRPS indicates a
- better forecast, meaning that forecasted probabilities are closer to the true observed values. CRPS can also be
- seen as a generalization of the brier score for non binary classification problems.
- Args:
- preds: a 2d tensor of shape (batch_size, ensemble_members) with predictions. The second dimension represents
- the ensemble members.
- target: a 1d tensor of shape (batch_size) with the target values.
- Return:
- Tensor with CRPS
- Raises:
- ValueError:
- If the number of ensemble members is less than 2.
- ValueError:
- If the first dimension of preds and target do not match.
- Example::
- >>> from torchmetrics.functional.regression import continuous_ranked_probability_score
- >>> from torch import randn
- >>> preds = randn(10, 5)
- >>> target = randn(10)
- >>> continuous_ranked_probability_score(preds, target)
- tensor(0.7731)
- """
- batch_size, diff, ensemble_sum = _crps_update(preds, target)
- return _crps_compute(batch_size, diff, ensemble_sum)
|