nrmse.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Union
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.regression.mse import _mean_squared_error_update
  19. def _normalized_root_mean_squared_error_update(
  20. preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std", "l2"] = "mean"
  21. ) -> tuple[Tensor, int, Tensor]:
  22. """Updates and returns the sum of squared errors and the number of observations for NRMSE computation.
  23. Args:
  24. preds: Predicted tensor
  25. target: Ground truth tensor
  26. num_outputs: Number of outputs in multioutput setting
  27. normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2"
  28. """
  29. sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs)
  30. target = target.view(-1) if num_outputs == 1 else target
  31. if normalization == "mean":
  32. denom = torch.mean(target, dim=0)
  33. elif normalization == "range":
  34. denom = torch.max(target, dim=0).values - torch.min(target, dim=0).values
  35. elif normalization == "std":
  36. denom = torch.std(target, correction=0, dim=0)
  37. elif normalization == "l2":
  38. denom = torch.norm(target, p=2, dim=0)
  39. else:
  40. raise ValueError(
  41. f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2' but got {normalization}"
  42. )
  43. return sum_squared_error, num_obs, denom
  44. def _normalized_root_mean_squared_error_compute(
  45. sum_squared_error: Tensor, num_obs: Union[int, Tensor], denom: Tensor
  46. ) -> Tensor:
  47. """Calculates RMSE and normalizes it."""
  48. rmse = torch.sqrt(sum_squared_error / num_obs)
  49. return rmse / denom
  50. def normalized_root_mean_squared_error(
  51. preds: Tensor,
  52. target: Tensor,
  53. normalization: Literal["mean", "range", "std", "l2"] = "mean",
  54. num_outputs: int = 1,
  55. ) -> Tensor:
  56. """Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index.
  57. Args:
  58. preds: estimated labels
  59. target: ground truth labels
  60. normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" which corresponds
  61. to normalizing the RMSE by the mean of the target, the range of the target, the standard deviation of the
  62. target or the L2 norm of the target.
  63. num_outputs: Number of outputs in multioutput setting
  64. Return:
  65. Tensor with the NRMSE score
  66. Example:
  67. >>> import torch
  68. >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error
  69. >>> preds = torch.tensor([0., 1, 2, 3])
  70. >>> target = torch.tensor([0., 1, 2, 2])
  71. >>> normalized_root_mean_squared_error(preds, target, normalization="mean")
  72. tensor(0.4000)
  73. >>> normalized_root_mean_squared_error(preds, target, normalization="range")
  74. tensor(0.2500)
  75. >>> normalized_root_mean_squared_error(preds, target, normalization="std")
  76. tensor(0.6030)
  77. >>> normalized_root_mean_squared_error(preds, target, normalization="l2")
  78. tensor(0.1667)
  79. Example (multioutput):
  80. >>> import torch
  81. >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error
  82. >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]])
  83. >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]])
  84. >>> normalized_root_mean_squared_error(preds, target, normalization="mean", num_outputs=2)
  85. tensor([0.2981, 0.2222])
  86. """
  87. sum_squared_error, num_obs, denom = _normalized_root_mean_squared_error_update(
  88. preds, target, num_outputs=num_outputs, normalization=normalization
  89. )
  90. return _normalized_root_mean_squared_error_compute(sum_squared_error, num_obs, denom)