nrmse.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from typing_extensions import Literal
  19. from torchmetrics.functional.regression.nrmse import (
  20. _mean_squared_error_update,
  21. _normalized_root_mean_squared_error_compute,
  22. )
  23. from torchmetrics.metric import Metric
  24. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  25. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
if not _MATPLOTLIB_AVAILABLE:
    # The ``plot`` docstring examples require matplotlib. Listing them here lets the doctest runner skip them
    # when matplotlib is not installed (presumably consumed by a doctest plugin such as pytest-doctestplus).
    __doctest_skip__ = ["NormalizedRootMeanSquaredError.plot"]
  28. def _final_aggregation(
  29. min_val: Tensor,
  30. max_val: Tensor,
  31. mean_val: Tensor,
  32. var_val: Tensor,
  33. target_squared: Tensor,
  34. total: Tensor,
  35. normalization: Literal["mean", "range", "std", "l2"] = "mean",
  36. ) -> Tensor:
  37. """In the case of multiple devices we need to aggregate the statistics from the different devices."""
  38. if len(min_val) == 1:
  39. if normalization == "mean":
  40. return mean_val[0]
  41. if normalization == "range":
  42. return max_val[0] - min_val[0]
  43. if normalization == "std":
  44. return var_val[0]
  45. if normalization == "l2":
  46. return target_squared[0]
  47. min_val_1, max_val_1, mean_val_1, var_val_1, target_squared_1, total_1 = (
  48. min_val[0],
  49. max_val[0],
  50. mean_val[0],
  51. var_val[0],
  52. target_squared[0],
  53. total[0],
  54. )
  55. for i in range(1, len(min_val)):
  56. min_val_2, max_val_2, mean_val_2, var_val_2, target_squared_2, total_2 = (
  57. min_val[i],
  58. max_val[i],
  59. mean_val[i],
  60. var_val[i],
  61. target_squared[i],
  62. total[i],
  63. )
  64. # update total and mean
  65. total = total_1 + total_2
  66. mean = (total_1 * mean_val_1 + total_2 * mean_val_2) / total
  67. # update variance
  68. _temp = (total_1 + 1) * mean - total_1 * mean_val_1
  69. var_val_1 += (_temp - mean_val_1) * (_temp - mean) - (_temp - mean) ** 2
  70. _temp = (total_2 + 1) * mean - total_2 * mean_val_2
  71. var_val_2 += (_temp - mean_val_2) * (_temp - mean) - (_temp - mean) ** 2
  72. var = var_val_1 + var_val_2
  73. # update min and max and target squared
  74. min_val = torch.min(min_val_1, min_val_2)
  75. max_val = torch.max(max_val_1, max_val_2)
  76. target_squared = target_squared_1 + target_squared_2
  77. if normalization == "mean":
  78. return mean
  79. if normalization == "range":
  80. return max_val - min_val
  81. if normalization == "std":
  82. return (var / total).sqrt()
  83. return target_squared.sqrt()
  84. class NormalizedRootMeanSquaredError(Metric):
  85. r"""Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index.
  86. The metric is defined as:
  87. .. math::
  88. \text{NRMSE} = \frac{\text{RMSE}}{\text{denom}}
  89. where RMSE is the root mean squared error and `denom` is the normalization factor. The normalization factor can be
  90. either be the mean, range, standard deviation or L2 norm of the target, which can be set using the `normalization`
  91. argument.
  92. As input to ``forward`` and ``update`` the metric accepts the following input:
  93. - ``preds`` (:class:`~torch.Tensor`): Predictions from model
  94. - ``target`` (:class:`~torch.Tensor`): Ground truth values
  95. As output of ``forward`` and ``compute`` the metric returns the following output:
  96. - ``nrmse`` (:class:`~torch.Tensor`): A tensor with the mean squared error
  97. Args:
  98. normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" which corresponds
  99. to normalizing the RMSE by the mean of the target, the range of the target, the standard deviation of the
  100. target or the L2 norm of the target.
  101. num_outputs: Number of outputs in multioutput setting
  102. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  103. Example::
  104. Single output normalized root mean squared error computation:
  105. >>> import torch
  106. >>> from torchmetrics import NormalizedRootMeanSquaredError
  107. >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
  108. >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
  109. >>> nrmse = NormalizedRootMeanSquaredError(normalization="mean")
  110. >>> nrmse(preds, target)
  111. tensor(0.1919)
  112. >>> nrmse = NormalizedRootMeanSquaredError(normalization="range")
  113. >>> nrmse(preds, target)
  114. tensor(0.1701)
  115. Example::
  116. Multioutput normalized root mean squared error computation:
  117. >>> import torch
  118. >>> from torchmetrics import NormalizedRootMeanSquaredError
  119. >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]])
  120. >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]])
  121. >>> nrmse = NormalizedRootMeanSquaredError(num_outputs=2)
  122. >>> nrmse(preds, target)
  123. tensor([0.2981, 0.2222])
  124. """
  125. is_differentiable: bool = True
  126. higher_is_better: bool = False
  127. full_state_update: bool = True
  128. plot_lower_bound: float = 0.0
  129. sum_squared_error: Tensor
  130. total: Tensor
  131. min_val: Tensor
  132. max_val: Tensor
  133. target_squared: Tensor
  134. mean_val: Tensor
  135. var_val: Tensor
  136. def __init__(
  137. self,
  138. normalization: Literal["mean", "range", "std", "l2"] = "mean",
  139. num_outputs: int = 1,
  140. **kwargs: Any,
  141. ) -> None:
  142. super().__init__(**kwargs)
  143. if normalization not in ("mean", "range", "std", "l2"):
  144. raise ValueError(
  145. f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2', but got {normalization}"
  146. )
  147. self.normalization = normalization
  148. if not (isinstance(num_outputs, int) and num_outputs > 0):
  149. raise ValueError(f"Expected num_outputs to be a positive integer but got {num_outputs}")
  150. self.num_outputs = num_outputs
  151. self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum")
  152. self.add_state("total", default=torch.zeros(num_outputs), dist_reduce_fx=None)
  153. self.add_state("min_val", default=float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None)
  154. self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None)
  155. self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  156. self.add_state("var_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  157. self.add_state("target_squared", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  158. def update(self, preds: Tensor, target: Tensor) -> None:
  159. """Update state with predictions and targets.
  160. See `mean_squared_error_update` for details.
  161. """
  162. sum_squared_error, num_obs = _mean_squared_error_update(preds, target, self.num_outputs)
  163. self.sum_squared_error += sum_squared_error
  164. target = target.view(-1) if self.num_outputs == 1 else target
  165. # Update min and max and target squared
  166. self.min_val = torch.minimum(target.min(dim=0).values, self.min_val)
  167. self.max_val = torch.maximum(target.max(dim=0).values, self.max_val)
  168. self.target_squared += (target**2).sum(dim=0)
  169. # Update mean and variance
  170. new_mean = (self.total * self.mean_val + target.sum(dim=0)) / (self.total + num_obs)
  171. self.total += num_obs
  172. new_var = ((target - new_mean) * (target - self.mean_val)).sum(dim=0)
  173. self.mean_val = new_mean
  174. self.var_val += new_var
  175. def compute(self) -> Tensor:
  176. """Computes NRMSE over state.
  177. See `mean_squared_error_compute` for details.
  178. """
  179. if (self.num_outputs == 1 and self.mean_val.numel() > 1) or (self.num_outputs > 1 and self.mean_val.ndim > 1):
  180. denom = _final_aggregation(
  181. min_val=self.min_val,
  182. max_val=self.max_val,
  183. mean_val=self.mean_val,
  184. var_val=self.var_val,
  185. target_squared=self.target_squared,
  186. total=self.total,
  187. normalization=self.normalization,
  188. )
  189. total = self.total.squeeze().sum(dim=0)
  190. else:
  191. if self.normalization == "mean":
  192. denom = self.mean_val
  193. elif self.normalization == "range":
  194. denom = self.max_val - self.min_val
  195. elif self.normalization == "std":
  196. denom = torch.sqrt(self.var_val / self.total)
  197. else:
  198. denom = torch.sqrt(self.target_squared)
  199. total = self.total
  200. return _normalized_root_mean_squared_error_compute(self.sum_squared_error, total, denom)
  201. def plot(
  202. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  203. ) -> _PLOT_OUT_TYPE:
  204. """Plot a single or multiple values from the metric.
  205. Args:
  206. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  207. If no value is provided, will automatically call `metric.compute` and plot that result.
  208. ax: An matplotlib axis object. If provided will add plot to that axis
  209. Returns:
  210. Figure and Axes object
  211. Raises:
  212. ModuleNotFoundError:
  213. If `matplotlib` is not installed
  214. .. plot::
  215. :scale: 75
  216. >>> from torch import randn
  217. >>> # Example plotting a single value
  218. >>> from torchmetrics.regression import NormalizedRootMeanSquaredError
  219. >>> metric = NormalizedRootMeanSquaredError()
  220. >>> metric.update(randn(10,), randn(10,))
  221. >>> fig_, ax_ = metric.plot()
  222. .. plot::
  223. :scale: 75
  224. >>> from torch import randn
  225. >>> # Example plotting multiple values
  226. >>> from torchmetrics.regression import NormalizedRootMeanSquaredError
  227. >>> metric = NormalizedRootMeanSquaredError()
  228. >>> values = []
  229. >>> for _ in range(10):
  230. ... values.append(metric(randn(10,), randn(10,)))
  231. >>> fig, ax = metric.plot(values)
  232. """
  233. return self._plot(val, ax)