pearson.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, List, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update
  19. from torchmetrics.metric import Metric
  20. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  21. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
# Without matplotlib the examples in ``PearsonCorrCoef.plot`` cannot run, so list that
# docstring in ``__doctest_skip__`` (presumably read by the doctest collector — confirm
# against the project's test configuration).
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["PearsonCorrCoef.plot"]
  24. def _final_aggregation(
  25. means_x: torch.Tensor,
  26. means_y: torch.Tensor,
  27. maxs_abs_x: torch.Tensor,
  28. maxs_abs_y: torch.Tensor,
  29. vars_x: torch.Tensor,
  30. vars_y: torch.Tensor,
  31. corrs_xy: torch.Tensor,
  32. nbs: torch.Tensor,
  33. eps: float = 1e-10,
  34. ) -> tuple[
  35. torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
  36. ]:
  37. """Aggregate the statistics from multiple devices.
  38. Formula taken from here: `Parallel algorithm for calculating variance
  39. <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_
  40. We use `eps` to avoid division by zero when `n1` and `n2` are both zero. Generally, the value of `eps` should not
  41. matter, as if `n1` and `n2` are both zero, all the states will also be zero.
  42. """
  43. if len(means_x) == 1:
  44. return means_x[0], means_y[0], maxs_abs_x[0], maxs_abs_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
  45. mx1 = means_x[0]
  46. my1 = means_y[0]
  47. max1 = maxs_abs_x[0]
  48. may1 = maxs_abs_y[0]
  49. vx1 = vars_x[0]
  50. vy1 = vars_y[0]
  51. cxy1 = corrs_xy[0]
  52. n1 = nbs[0]
  53. for i in range(1, len(means_x)):
  54. mx2 = means_x[i]
  55. my2 = means_y[i]
  56. max2 = maxs_abs_x[i]
  57. may2 = maxs_abs_y[i]
  58. vx2 = vars_x[i]
  59. vy2 = vars_y[i]
  60. cxy2 = corrs_xy[i]
  61. n2 = nbs[i]
  62. # count
  63. nb = torch.where(torch.logical_or(n1, n2), n1 + n2, eps)
  64. # mean_x
  65. mean_x = (n1 * mx1 + n2 * mx2) / nb
  66. # mean_y
  67. mean_y = (n1 * my1 + n2 * my2) / nb
  68. # intermediates for running variances
  69. n12_b = n1 * n2 / nb
  70. delta_x = mx2 - mx1
  71. delta_y = my2 - my1
  72. # var_x
  73. var_x = vx1 + vx2 + n12_b * delta_x**2
  74. # var_y
  75. var_y = vy1 + vy2 + n12_b * delta_y**2
  76. # corr_xy
  77. corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y
  78. max_abs_dev_x = torch.maximum(max1, max2)
  79. max_abs_dev_y = torch.maximum(may1, may2)
  80. mx1 = mean_x
  81. my1 = mean_y
  82. max1 = max_abs_dev_x
  83. may1 = max_abs_dev_y
  84. vx1 = var_x
  85. vy1 = var_y
  86. cxy1 = corr_xy
  87. n1 = nb
  88. return mean_x, mean_y, max_abs_dev_x, max_abs_dev_y, var_x, var_y, corr_xy, nb
  89. class PearsonCorrCoef(Metric):
  90. r"""Compute `Pearson Correlation Coefficient`_.
  91. .. math::
  92. P_{corr}(x,y) = \frac{cov(x,y)}{\sigma_x \sigma_y}
  93. Where :math:`y` is a tensor of target values, and :math:`x` is a tensor of predictions.
  94. As input to ``forward`` and ``update`` the metric accepts the following input:
  95. - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)``
  96. or multioutput float tensor of shape ``(N,d)``
  97. - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)``
  98. or multioutput tensor of shape ``(N,d)``
  99. As output of ``forward`` and ``compute`` the metric returns the following output:
  100. - ``pearson`` (:class:`~torch.Tensor`): A tensor with the Pearson Correlation Coefficient
  101. Args:
  102. num_outputs: Number of outputs in multioutput setting
  103. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  104. Example (single output regression):
  105. >>> from torchmetrics.regression import PearsonCorrCoef
  106. >>> target = torch.tensor([3, -0.5, 2, 7])
  107. >>> preds = torch.tensor([2.5, 0.0, 2, 8])
  108. >>> pearson = PearsonCorrCoef()
  109. >>> pearson(preds, target)
  110. tensor(0.9849)
  111. Example (multi output regression):
  112. >>> from torchmetrics.regression import PearsonCorrCoef
  113. >>> target = torch.tensor([[3, -0.5], [2, 7]])
  114. >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
  115. >>> pearson = PearsonCorrCoef(num_outputs=2)
  116. >>> pearson(preds, target)
  117. tensor([1., 1.])
  118. """
  119. is_differentiable: bool = True
  120. higher_is_better: Optional[bool] = None # both -1 and 1 are optimal
  121. full_state_update: bool = True
  122. plot_lower_bound: float = -1.0
  123. plot_upper_bound: float = 1.0
  124. preds: List[Tensor]
  125. target: List[Tensor]
  126. mean_x: Tensor
  127. mean_y: Tensor
  128. max_abs_dev_x: Tensor
  129. max_abs_dev_y: Tensor
  130. var_x: Tensor
  131. var_y: Tensor
  132. corr_xy: Tensor
  133. n_total: Tensor
  134. def __init__(
  135. self,
  136. num_outputs: int = 1,
  137. **kwargs: Any,
  138. ) -> None:
  139. super().__init__(**kwargs)
  140. if not isinstance(num_outputs, int) and num_outputs < 1:
  141. raise ValueError("Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}")
  142. self.num_outputs = num_outputs
  143. self.add_state("mean_x", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  144. self.add_state("mean_y", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  145. self.add_state("max_abs_dev_x", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  146. self.add_state("max_abs_dev_y", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  147. self.add_state("var_x", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  148. self.add_state("var_y", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  149. self.add_state("corr_xy", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  150. self.add_state("n_total", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
  151. def update(self, preds: Tensor, target: Tensor) -> None:
  152. """Update state with predictions and targets."""
  153. (
  154. self.mean_x,
  155. self.mean_y,
  156. self.max_abs_dev_x,
  157. self.max_abs_dev_y,
  158. self.var_x,
  159. self.var_y,
  160. self.corr_xy,
  161. self.n_total,
  162. ) = _pearson_corrcoef_update(
  163. preds=preds,
  164. target=target,
  165. mean_x=self.mean_x,
  166. mean_y=self.mean_y,
  167. max_abs_dev_x=self.max_abs_dev_x,
  168. max_abs_dev_y=self.max_abs_dev_y,
  169. var_x=self.var_x,
  170. var_y=self.var_y,
  171. corr_xy=self.corr_xy,
  172. num_prior=self.n_total,
  173. num_outputs=self.num_outputs,
  174. )
  175. def compute(self) -> Tensor:
  176. """Compute pearson correlation coefficient over state."""
  177. if (self.num_outputs == 1 and self.mean_x.numel() > 1) or (self.num_outputs > 1 and self.mean_x.ndim > 1):
  178. # multiple devices, need further reduction
  179. _, _, max_abs_dev_x, max_abs_dev_y, var_x, var_y, corr_xy, n_total = _final_aggregation(
  180. means_x=self.mean_x,
  181. means_y=self.mean_y,
  182. maxs_abs_x=self.max_abs_dev_x,
  183. maxs_abs_y=self.max_abs_dev_y,
  184. vars_x=self.var_x,
  185. vars_y=self.var_y,
  186. corrs_xy=self.corr_xy,
  187. nbs=self.n_total,
  188. )
  189. else:
  190. max_abs_dev_x = self.max_abs_dev_x
  191. max_abs_dev_y = self.max_abs_dev_y
  192. var_x = self.var_x
  193. var_y = self.var_y
  194. corr_xy = self.corr_xy
  195. n_total = self.n_total
  196. return _pearson_corrcoef_compute(max_abs_dev_x, max_abs_dev_y, var_x, var_y, corr_xy, n_total)
  197. def plot(
  198. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  199. ) -> _PLOT_OUT_TYPE:
  200. """Plot a single or multiple values from the metric.
  201. Args:
  202. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  203. If no value is provided, will automatically call `metric.compute` and plot that result.
  204. ax: An matplotlib axis object. If provided will add plot to that axis
  205. Returns:
  206. Figure and Axes object
  207. Raises:
  208. ModuleNotFoundError:
  209. If `matplotlib` is not installed
  210. .. plot::
  211. :scale: 75
  212. >>> from torch import randn
  213. >>> # Example plotting a single value
  214. >>> from torchmetrics.regression import PearsonCorrCoef
  215. >>> metric = PearsonCorrCoef()
  216. >>> metric.update(randn(10,), randn(10,))
  217. >>> fig_, ax_ = metric.plot()
  218. .. plot::
  219. :scale: 75
  220. >>> from torch import randn
  221. >>> # Example plotting multiple values
  222. >>> from torchmetrics.regression import PearsonCorrCoef
  223. >>> metric = PearsonCorrCoef()
  224. >>> values = []
  225. >>> for _ in range(10):
  226. ... values.append(metric(randn(10,), randn(10,)))
  227. >>> fig, ax = metric.plot(values)
  228. """
  229. return self._plot(val, ax)