crps.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Tuple
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.utilities.checks import _check_same_shape
  18. def _crps_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor]:
  19. """Compute intermediate CRPS values before aggregation.
  20. Args:
  21. preds: Tensor of shape (batch_size, ensemble_members)
  22. target: Tensor of shape (batch_size,)
  23. Returns:
  24. batch_size: int
  25. diff: Tensor (batch-wise absolute error term)
  26. ensemble_sum: Tensor (pairwise ensemble term)
  27. """
  28. # Only second dimension should deviate in shape (the ensemble members)
  29. _check_same_shape(preds[:, 0], target)
  30. batch_size, n_ensemble_members = preds.shape
  31. if n_ensemble_members < 2:
  32. raise ValueError(f"CRPS requires at least 2 ensemble members, but you provided {preds.shape}.")
  33. # sort forecasts
  34. preds = torch.sort(preds, dim=1)[0]
  35. # inflate observations:
  36. observation_inflated = target.unsqueeze(1).expand_as(preds)
  37. # Compute mean absolute difference between predictions and target
  38. diff = torch.sum(torch.abs(preds - observation_inflated), dim=1) / n_ensemble_members
  39. # Compute ensemble term using the reference implementation formula
  40. ensemble_diffs = torch.abs(preds.unsqueeze(2) - preds.unsqueeze(1))
  41. ensemble_sum = torch.sum(ensemble_diffs, dim=(1, 2)) / (2 * n_ensemble_members * n_ensemble_members)
  42. return batch_size, diff, ensemble_sum
  43. def _crps_compute(batch_size: int, diff: Tensor, ensemble_sum: Tensor) -> Tensor:
  44. """Final CRPS computation."""
  45. return torch.mean(diff - ensemble_sum) # Changed from sum to mean
  46. def continuous_ranked_probability_score(preds: Tensor, target: Tensor) -> Tensor:
  47. r"""Computes continuous ranked probability score.
  48. .. math::
  49. CRPS(F, y) = \int_{-\infty}^{\infty} (F(x) - 1_{x \geq y})^2 dx
  50. where :math:`F` is the predicted cumulative distribution function and :math:`y` is the true target. The metric is
  51. usually used to evaluate probabilistic regression models, such as forecasting models. A lower CRPS indicates a
  52. better forecast, meaning that forecasted probabilities are closer to the true observed values. CRPS can also be
  53. seen as a generalization of the brier score for non binary classification problems.
  54. Args:
  55. preds: a 2d tensor of shape (batch_size, ensemble_members) with predictions. The second dimension represents
  56. the ensemble members.
  57. target: a 1d tensor of shape (batch_size) with the target values.
  58. Return:
  59. Tensor with CRPS
  60. Raises:
  61. ValueError:
  62. If the number of ensemble members is less than 2.
  63. ValueError:
  64. If the first dimension of preds and target do not match.
  65. Example::
  66. >>> from torchmetrics.functional.regression import continuous_ranked_probability_score
  67. >>> from torch import randn
  68. >>> preds = randn(10, 5)
  69. >>> target = randn(10)
  70. >>> continuous_ranked_probability_score(preds, target)
  71. tensor(0.7731)
  72. """
  73. batch_size, diff, ensemble_sum = _crps_update(preds, target)
  74. return _crps_compute(batch_size, diff, ensemble_sum)