spearman.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from torch import Tensor
  16. from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
  17. from torchmetrics.utilities.checks import _check_same_shape
  18. def _rank_data(data: Tensor) -> Tensor:
  19. """Calculate the rank for each element of a tensor.
  20. The rank refers to the indices of an element in the corresponding sorted tensor (starting from 1). Duplicates of the
  21. same value will be assigned the mean of their rank.
  22. Adopted from `Rank of element tensor`_
  23. """
  24. n = data.numel()
  25. rank = torch.empty_like(data, dtype=torch.int32)
  26. idx = data.argsort()
  27. rank[idx[:n]] = torch.arange(1, n + 1, dtype=torch.int32, device=data.device)
  28. uniq, inv, counts = torch.unique(data, sorted=True, return_inverse=True, return_counts=True)
  29. sum_ranks = torch.zeros_like(uniq, dtype=torch.int32)
  30. sum_ranks.scatter_add_(0, inv, rank.to(torch.int32))
  31. mean_ranks = sum_ranks / counts
  32. return mean_ranks[inv]
  33. def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -> tuple[Tensor, Tensor]:
  34. """Update and returns variables required to compute Spearman Correlation Coefficient.
  35. Check for same shape and type of input tensors.
  36. Args:
  37. preds: Predicted tensor
  38. target: Ground truth tensor
  39. num_outputs: Number of outputs in multioutput setting
  40. """
  41. if not (preds.is_floating_point() and target.is_floating_point()):
  42. raise TypeError(
  43. "Expected `preds` and `target` both to be floating point tensors, but got {pred.dtype} and {target.dtype}"
  44. )
  45. _check_same_shape(preds, target)
  46. _check_data_shape_to_num_outputs(preds, target, num_outputs)
  47. return preds, target
  48. def _spearman_corrcoef_compute(preds: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
  49. """Compute Spearman Correlation Coefficient.
  50. Args:
  51. preds: Predicted tensor
  52. target: Ground truth tensor
  53. eps: Avoids ``ZeroDivisionError``.
  54. Example:
  55. >>> target = torch.tensor([3, -0.5, 2, 7])
  56. >>> preds = torch.tensor([2.5, 0.0, 2, 8])
  57. >>> preds, target = _spearman_corrcoef_update(preds, target, num_outputs=1)
  58. >>> _spearman_corrcoef_compute(preds, target)
  59. tensor(1.0000)
  60. """
  61. if preds.ndim == 1:
  62. preds = _rank_data(preds)
  63. target = _rank_data(target)
  64. else:
  65. preds = torch.stack([_rank_data(p) for p in preds.T]).T
  66. target = torch.stack([_rank_data(t) for t in target.T]).T
  67. preds_diff = preds - preds.mean(0)
  68. target_diff = target - target.mean(0)
  69. cov = (preds_diff * target_diff).mean(0)
  70. preds_std = torch.sqrt((preds_diff * preds_diff).mean(0))
  71. target_std = torch.sqrt((target_diff * target_diff).mean(0))
  72. corrcoef = cov / (preds_std * target_std + eps)
  73. return torch.clamp(corrcoef, -1.0, 1.0)
  74. def spearman_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
  75. r"""Compute `spearmans rank correlation coefficient`_.
  76. .. math:
  77. r_s = = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} * \sigma_{rg_y}}
  78. where :math:`rg_x` and :math:`rg_y` are the rank associated to the variables x and y. Spearmans correlations
  79. coefficient corresponds to the standard pearsons correlation coefficient calculated on the rank variables.
  80. Args:
  81. preds: estimated scores
  82. target: ground truth scores
  83. Example (single output regression):
  84. >>> from torchmetrics.functional.regression import spearman_corrcoef
  85. >>> target = torch.tensor([3, -0.5, 2, 7])
  86. >>> preds = torch.tensor([2.5, 0.0, 2, 8])
  87. >>> spearman_corrcoef(preds, target)
  88. tensor(1.0000)
  89. Example (multi output regression):
  90. >>> from torchmetrics.functional.regression import spearman_corrcoef
  91. >>> target = torch.tensor([[3, -0.5], [2, 7]])
  92. >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
  93. >>> spearman_corrcoef(preds, target)
  94. tensor([1.0000, 1.0000])
  95. """
  96. preds, target = _spearman_corrcoef_update(preds, target, num_outputs=1 if preds.ndim == 1 else preds.shape[-1])
  97. return _spearman_corrcoef_compute(preds, target)