pearson.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
  18. from torchmetrics.utilities import rank_zero_warn
  19. from torchmetrics.utilities.checks import _check_same_shape
  20. def _pearson_corrcoef_update(
  21. preds: Tensor,
  22. target: Tensor,
  23. mean_x: Tensor,
  24. mean_y: Tensor,
  25. max_abs_dev_x: Tensor,
  26. max_abs_dev_y: Tensor,
  27. var_x: Tensor,
  28. var_y: Tensor,
  29. corr_xy: Tensor,
  30. num_prior: Tensor,
  31. num_outputs: int,
  32. ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  33. """Update and returns variables required to compute Pearson Correlation Coefficient.
  34. Check for same shape of input tensors.
  35. Args:
  36. preds: estimated scores
  37. target: ground truth scores
  38. mean_x: current mean estimate of x tensor
  39. mean_y: current mean estimate of y tensor
  40. max_abs_dev_x: current maximum absolute value of x tensor
  41. max_abs_dev_y: current maximum absolute value of y tensor
  42. var_x: current variance estimate of x tensor
  43. var_y: current variance estimate of y tensor
  44. corr_xy: current covariance estimate between x and y tensor
  45. num_prior: current number of observed observations
  46. num_outputs: Number of outputs in multioutput setting
  47. """
  48. # Data checking
  49. _check_same_shape(preds, target)
  50. _check_data_shape_to_num_outputs(preds, target, num_outputs)
  51. num_obs = preds.shape[0]
  52. batch_mean_x = preds.mean(0)
  53. batch_mean_y = target.mean(0)
  54. delta_x = batch_mean_x - mean_x
  55. delta_y = batch_mean_y - mean_y
  56. n_total = num_prior + num_obs
  57. mx_new = mean_x + delta_x * num_obs / n_total
  58. my_new = mean_y + delta_y * num_obs / n_total
  59. if num_obs == 1:
  60. delta2_x = batch_mean_x - mx_new
  61. delta2_y = batch_mean_y - my_new
  62. var_x = var_x + delta2_x * delta_x
  63. var_y = var_y + delta2_y * delta_y
  64. corr_xy = corr_xy + delta_x * delta2_y
  65. else:
  66. preds_centered = preds - batch_mean_x
  67. target_centered = target - batch_mean_y
  68. batch_var_x = (preds_centered**2).sum(0)
  69. batch_var_y = (target_centered**2).sum(0)
  70. batch_cov_xy = (preds_centered * target_centered).sum(0)
  71. correction = num_prior * num_obs / n_total
  72. var_x = var_x + batch_var_x + delta_x**2 * correction
  73. var_y = var_y + batch_var_y + delta_y**2 * correction
  74. corr_xy = corr_xy + batch_cov_xy + delta_x * delta_y * correction
  75. max_abs_dev_x = torch.maximum(max_abs_dev_x, torch.max((preds - mx_new).abs(), dim=0)[0])
  76. max_abs_dev_y = torch.maximum(max_abs_dev_y, torch.max((target - my_new).abs(), dim=0)[0])
  77. return mx_new, my_new, max_abs_dev_x, max_abs_dev_y, var_x, var_y, corr_xy, n_total
  78. def _pearson_corrcoef_compute(
  79. max_abs_dev_x: Tensor,
  80. max_abs_dev_y: Tensor,
  81. var_x: Tensor,
  82. var_y: Tensor,
  83. corr_xy: Tensor,
  84. nb: Tensor,
  85. ) -> Tensor:
  86. """Compute the final pearson correlation based on accumulated statistics.
  87. Args:
  88. max_abs_dev_x: maximum absolute value of x tensor
  89. max_abs_dev_y: maximum absolute value of y tensor
  90. var_x: variance estimate of x tensor
  91. var_y: variance estimate of y tensor
  92. corr_xy: covariance estimate between x and y tensor
  93. nb: number of observations
  94. """
  95. # prevent overwrite the inputs
  96. var_x = var_x / (nb - 1)
  97. var_y = var_y / (nb - 1)
  98. corr_xy = corr_xy / (nb - 1)
  99. # if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16
  100. # on cpu, remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed
  101. if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"):
  102. var_x = var_x.bfloat16()
  103. var_y = var_y.bfloat16()
  104. var_x = var_x * torch.pow(max_abs_dev_x, -2)
  105. var_y = var_y * torch.pow(max_abs_dev_y, -2)
  106. corr_xy = corr_xy / (max_abs_dev_x * max_abs_dev_y)
  107. bound = math.sqrt(torch.finfo(var_x.dtype).eps)
  108. if (
  109. (var_x < bound).any()
  110. or (var_y < bound).any()
  111. or ~torch.isfinite(var_x).any()
  112. or ~torch.isfinite(var_y).any()
  113. or ~torch.isfinite(corr_xy).any()
  114. ):
  115. rank_zero_warn(
  116. "The variance of predictions or target is close to zero. This can cause instability in Pearson correlation"
  117. "coefficient, leading to wrong results. Consider re-scaling the input if possible or computing using a"
  118. f"larger dtype (currently using {var_x.dtype}). Setting the correlation coefficient to nan.",
  119. UserWarning,
  120. )
  121. zero_var_mask = (
  122. (var_x < bound) | (var_y < bound) | ~torch.isfinite(var_x) | ~torch.isfinite(var_y) | ~torch.isfinite(corr_xy)
  123. )
  124. corrcoef = torch.full_like(corr_xy, float("nan"), device=corr_xy.device, dtype=corr_xy.dtype)
  125. valid_mask = ~zero_var_mask
  126. if valid_mask.any():
  127. corrcoef[valid_mask] = (
  128. (corr_xy[valid_mask] / (var_x[valid_mask] * var_y[valid_mask]).sqrt()).squeeze().to(corrcoef.dtype)
  129. )
  130. corrcoef = torch.clamp(corrcoef, -1.0, 1.0)
  131. return corrcoef.squeeze()
  132. def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
  133. """Compute pearson correlation coefficient.
  134. Args:
  135. preds: estimated scores
  136. target: ground truth scores
  137. Example (single output regression):
  138. >>> from torchmetrics.functional.regression import pearson_corrcoef
  139. >>> target = torch.tensor([3, -0.5, 2, 7])
  140. >>> preds = torch.tensor([2.5, 0.0, 2, 8])
  141. >>> pearson_corrcoef(preds, target)
  142. tensor(0.9849)
  143. Example (multi output regression):
  144. >>> from torchmetrics.functional.regression import pearson_corrcoef
  145. >>> target = torch.tensor([[3, -0.5], [2, 7]])
  146. >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
  147. >>> pearson_corrcoef(preds, target)
  148. tensor([1., 1.])
  149. """
  150. d = preds.shape[1] if preds.ndim == 2 else 1
  151. _temp = torch.zeros(d, dtype=preds.dtype, device=preds.device)
  152. mean_x, mean_y, var_x = _temp.clone(), _temp.clone(), _temp.clone()
  153. var_y, corr_xy, nb = _temp.clone(), _temp.clone(), _temp.clone()
  154. max_abs_dev_x, max_abs_dev_y = _temp.clone(), _temp.clone()
  155. _, _, max_abs_dev_x, max_abs_dev_y, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(
  156. preds=preds,
  157. target=target,
  158. mean_x=mean_x,
  159. mean_y=mean_y,
  160. max_abs_dev_x=max_abs_dev_x,
  161. max_abs_dev_y=max_abs_dev_y,
  162. var_x=var_x,
  163. var_y=var_y,
  164. corr_xy=corr_xy,
  165. num_prior=nb,
  166. num_outputs=1 if preds.ndim == 1 else preds.shape[-1],
  167. )
  168. return _pearson_corrcoef_compute(max_abs_dev_x, max_abs_dev_y, var_x, var_y, corr_xy, nb)