r2.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Union
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.utilities import rank_zero_warn
  18. from torchmetrics.utilities.checks import _check_same_shape
  19. def _r2_score_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor, Tensor, int]:
  20. """Update and returns variables required to compute R2 score.
  21. Check for same shape and 1D/2D input tensors.
  22. Args:
  23. preds: Predicted tensor
  24. target: Ground truth tensor
  25. """
  26. _check_same_shape(preds, target)
  27. if preds.ndim > 2:
  28. raise ValueError(
  29. "Expected both prediction and target to be 1D or 2D tensors,"
  30. f" but received tensors with dimension {preds.shape}"
  31. )
  32. sum_obs = torch.sum(target, dim=0)
  33. sum_squared_obs = torch.sum(target * target, dim=0)
  34. residual = target - preds
  35. rss = torch.sum(residual * residual, dim=0)
  36. return sum_squared_obs, sum_obs, rss, target.size(0)
  37. def _r2_score_compute(
  38. sum_squared_obs: Tensor,
  39. sum_obs: Tensor,
  40. rss: Tensor,
  41. num_obs: Union[int, Tensor],
  42. adjusted: int = 0,
  43. multioutput: str = "uniform_average",
  44. ) -> Tensor:
  45. """Compute R2 score.
  46. Args:
  47. sum_squared_obs: Sum of square of all observations
  48. sum_obs: Sum of all observations
  49. rss: Residual sum of squares
  50. num_obs: Number of predictions or observations
  51. adjusted: number of independent regressors for calculating adjusted r2 score.
  52. multioutput: Defines aggregation in the case of multiple output scores. Can be one of the following strings:
  53. * `'raw_values'` returns full set of scores
  54. * `'uniform_average'` scores are uniformly averaged
  55. * `'variance_weighted'` scores are weighted by their individual variances
  56. Example:
  57. >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
  58. >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
  59. >>> sum_squared_obs, sum_obs, rss, num_obs = _r2_score_update(preds, target)
  60. >>> _r2_score_compute(sum_squared_obs, sum_obs, rss, num_obs, multioutput="raw_values")
  61. tensor([0.9654, 0.9082])
  62. """
  63. if num_obs < 2:
  64. raise ValueError("Needs at least two samples to calculate r2 score.")
  65. mean_obs = sum_obs / num_obs
  66. tss = sum_squared_obs - sum_obs * mean_obs
  67. # Account for near constant targets
  68. cond_rss = ~torch.isclose(rss, torch.zeros_like(rss), atol=1e-4)
  69. cond_tss = ~torch.isclose(tss, torch.zeros_like(tss), atol=1e-4)
  70. cond = cond_rss & cond_tss
  71. raw_scores = torch.ones_like(rss)
  72. raw_scores[cond] = 1 - (rss[cond] / tss[cond])
  73. raw_scores[cond_rss & ~cond_tss] = 0.0
  74. if multioutput == "raw_values":
  75. r2 = raw_scores
  76. elif multioutput == "uniform_average":
  77. r2 = torch.mean(raw_scores)
  78. elif multioutput == "variance_weighted":
  79. tss_sum = torch.sum(tss)
  80. r2 = torch.sum(tss / tss_sum * raw_scores)
  81. else:
  82. raise ValueError(
  83. "Argument `multioutput` must be either `raw_values`,"
  84. f" `uniform_average` or `variance_weighted`. Received {multioutput}."
  85. )
  86. if adjusted < 0 or not isinstance(adjusted, int):
  87. raise ValueError("`adjusted` parameter should be an integer larger or equal to 0.")
  88. if adjusted != 0:
  89. if adjusted > num_obs - 1:
  90. rank_zero_warn(
  91. "More independent regressions than data points in adjusted r2 score. Falls back to standard r2 score.",
  92. UserWarning,
  93. )
  94. elif adjusted == num_obs - 1:
  95. rank_zero_warn("Division by zero in adjusted r2 score. Falls back to standard r2 score.", UserWarning)
  96. else:
  97. return 1 - (1 - r2) * (num_obs - 1) / (num_obs - adjusted - 1)
  98. return r2
  99. def r2_score(
  100. preds: Tensor,
  101. target: Tensor,
  102. adjusted: int = 0,
  103. multioutput: str = "uniform_average",
  104. ) -> Tensor:
  105. r"""Compute r2 score also known as `R2 Score_Coefficient Determination`_.
  106. .. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}
  107. where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
  108. :math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate
  109. adjusted r2 score given by
  110. .. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}
  111. where the parameter :math:`k` (the number of independent regressors) should
  112. be provided as the ``adjusted`` argument.
  113. Args:
  114. preds: estimated labels
  115. target: ground truth labels
  116. adjusted: number of independent regressors for calculating adjusted r2 score.
  117. multioutput: Defines aggregation in the case of multiple output scores. Can be one of the following strings:
  118. * ``'raw_values'`` returns full set of scores
  119. * ``'uniform_average'`` scores are uniformly averaged
  120. * ``'variance_weighted'`` scores are weighted by their individual variances
  121. Raises:
  122. ValueError:
  123. If both ``preds`` and ``targets`` are not ``1D`` or ``2D`` tensors.
  124. ValueError:
  125. If ``len(preds)`` is less than ``2`` since at least ``2`` samples are needed to calculate r2 score.
  126. ValueError:
  127. If ``multioutput`` is not one of ``raw_values``, ``uniform_average`` or ``variance_weighted``.
  128. ValueError:
  129. If ``adjusted`` is not an ``integer`` greater than ``0``.
  130. Example:
  131. >>> from torchmetrics.functional.regression import r2_score
  132. >>> target = torch.tensor([3, -0.5, 2, 7])
  133. >>> preds = torch.tensor([2.5, 0.0, 2, 8])
  134. >>> r2_score(preds, target)
  135. tensor(0.9486)
  136. >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
  137. >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
  138. >>> r2_score(preds, target, multioutput='raw_values')
  139. tensor([0.9654, 0.9082])
  140. """
  141. sum_squared_obs, sum_obs, rss, num_obs = _r2_score_update(preds, target)
  142. return _r2_score_compute(sum_squared_obs, sum_obs, rss, num_obs, adjusted, multioutput)