kendall.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional, Union
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
  19. from torchmetrics.utilities.checks import _check_same_shape
  20. from torchmetrics.utilities.data import _bincount, _cumsum, dim_zero_cat
  21. from torchmetrics.utilities.enums import EnumStr
  22. class _MetricVariant(EnumStr):
  23. """Enumerate for metric variants."""
  24. A = "a"
  25. B = "b"
  26. C = "c"
  27. @staticmethod
  28. def _name() -> str:
  29. return "variant"
  30. class _TestAlternative(EnumStr):
  31. """Enumerate for test alternative options."""
  32. TWO_SIDED = "two-sided"
  33. LESS = "less"
  34. GREATER = "greater"
  35. @staticmethod
  36. def _name() -> str:
  37. return "alternative"
  38. def _sort_on_first_sequence(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
  39. """Sort sequences in an ascent order according to the sequence ``x``."""
  40. # We need to clone `y` tensor not to change an object in memory
  41. y = torch.clone(y)
  42. x, y = x.T, y.T
  43. x, perm = x.sort()
  44. for i in range(x.shape[0]):
  45. y[i] = y[i][perm[i]]
  46. return x.T, y.T
  47. def _concordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor:
  48. """Count a total number of concordant pairs in a single sequence."""
  49. return torch.logical_and(x[i] < x[(i + 1) :], y[i] < y[(i + 1) :]).sum(0).unsqueeze(0)
  50. def _count_concordant_pairs(preds: Tensor, target: Tensor) -> Tensor:
  51. """Count a total number of concordant pairs in given sequences."""
  52. return torch.cat([_concordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0)
  53. def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor:
  54. """Count a total number of discordant pairs in a single sequences."""
  55. return (
  56. torch.logical_or(
  57. torch.logical_and(x[i] > x[(i + 1) :], y[i] < y[(i + 1) :]),
  58. torch.logical_and(x[i] < x[(i + 1) :], y[i] > y[(i + 1) :]),
  59. )
  60. .sum(0)
  61. .unsqueeze(0)
  62. )
  63. def _count_discordant_pairs(preds: Tensor, target: Tensor) -> Tensor:
  64. """Count a total number of discordant pairs in given sequences."""
  65. return torch.cat([_discordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0)
  66. def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor:
  67. """Convert a sequence to the rank tensor."""
  68. # Sort if a sequence has not been sorted before
  69. if sort:
  70. x = x.sort(dim=0).values
  71. _ones = torch.zeros(1, x.shape[1], dtype=torch.int32, device=x.device)
  72. return _cumsum(torch.cat([_ones, (x[1:] != x[:-1]).int()], dim=0), dim=0)
  73. def _get_ties(x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
  74. """Get a total number of ties and staistics for p-value calculation for a given sequence."""
  75. ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device)
  76. ties_p1 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device)
  77. ties_p2 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device)
  78. for dim in range(x.shape[1]):
  79. n_ties = _bincount(x[:, dim])
  80. n_ties = n_ties[n_ties > 1]
  81. ties[dim] = (n_ties * (n_ties - 1) // 2).sum()
  82. ties_p1[dim] = (n_ties * (n_ties - 1.0) * (n_ties - 2)).sum()
  83. ties_p2[dim] = (n_ties * (n_ties - 1.0) * (2 * n_ties + 5)).sum()
  84. return ties, ties_p1, ties_p2
  85. def _get_metric_metadata(
  86. preds: Tensor, target: Tensor, variant: _MetricVariant
  87. ) -> tuple[
  88. Tensor,
  89. Tensor,
  90. Optional[Tensor],
  91. Optional[Tensor],
  92. Optional[Tensor],
  93. Optional[Tensor],
  94. Optional[Tensor],
  95. Optional[Tensor],
  96. Tensor,
  97. ]:
  98. """Obtain statistics to calculate metric value."""
  99. preds, target = _sort_on_first_sequence(preds, target)
  100. concordant_pairs = _count_concordant_pairs(preds, target)
  101. discordant_pairs = _count_discordant_pairs(preds, target)
  102. n_total = torch.tensor(preds.shape[0], device=preds.device)
  103. preds_ties = target_ties = None
  104. preds_ties_p1 = preds_ties_p2 = target_ties_p1 = target_ties_p2 = None
  105. if variant != _MetricVariant.A:
  106. preds = _convert_sequence_to_dense_rank(preds)
  107. target = _convert_sequence_to_dense_rank(target, sort=True)
  108. preds_ties, preds_ties_p1, preds_ties_p2 = _get_ties(preds)
  109. target_ties, target_ties_p1, target_ties_p2 = _get_ties(target)
  110. return (
  111. concordant_pairs,
  112. discordant_pairs,
  113. preds_ties,
  114. preds_ties_p1,
  115. preds_ties_p2,
  116. target_ties,
  117. target_ties_p1,
  118. target_ties_p2,
  119. n_total,
  120. )
  121. def _calculate_tau(
  122. preds: Tensor,
  123. target: Tensor,
  124. concordant_pairs: Tensor,
  125. discordant_pairs: Tensor,
  126. con_min_dis_pairs: Tensor,
  127. n_total: Tensor,
  128. preds_ties: Optional[Tensor],
  129. target_ties: Optional[Tensor],
  130. variant: _MetricVariant,
  131. ) -> Tensor:
  132. """Calculate Kendall's tau from metric metadata."""
  133. if variant == _MetricVariant.A:
  134. return con_min_dis_pairs / (concordant_pairs + discordant_pairs)
  135. if variant == _MetricVariant.B:
  136. total_combinations: Tensor = n_total * (n_total - 1) // 2
  137. if preds_ties is None:
  138. preds_ties = torch.tensor(0.0, dtype=total_combinations.dtype, device=total_combinations.device)
  139. if target_ties is None:
  140. target_ties = torch.tensor(0.0, dtype=total_combinations.dtype, device=total_combinations.device)
  141. denominator = (total_combinations - preds_ties) * (total_combinations - target_ties)
  142. return con_min_dis_pairs / torch.sqrt(denominator)
  143. preds_unique = torch.tensor([len(p.unique()) for p in preds.T], dtype=preds.dtype, device=preds.device)
  144. target_unique = torch.tensor([len(t.unique()) for t in target.T], dtype=target.dtype, device=target.device)
  145. min_classes = torch.minimum(preds_unique, target_unique)
  146. return 2 * con_min_dis_pairs / ((min_classes - 1) / min_classes * n_total**2)
  147. def _get_p_value_for_t_value_from_dist(t_value: Tensor) -> Tensor:
  148. """Obtain p-value for a given Tensor of t-values. Handle ``nan`` which cannot be passed into torch distributions.
  149. When t-value is ``nan``, a resulted p-value should be alson ``nan``.
  150. """
  151. device = t_value
  152. normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]).to(device), torch.tensor([1.0]).to(device))
  153. is_nan = t_value.isnan()
  154. t_value = t_value.nan_to_num()
  155. p_value = normal_dist.cdf(t_value)
  156. return p_value.where(~is_nan, torch.tensor(float("nan"), dtype=p_value.dtype, device=p_value.device))
  157. def _calculate_p_value(
  158. con_min_dis_pairs: Tensor,
  159. n_total: Tensor,
  160. preds_ties: Optional[Tensor],
  161. preds_ties_p1: Optional[Tensor],
  162. preds_ties_p2: Optional[Tensor],
  163. target_ties: Optional[Tensor],
  164. target_ties_p1: Optional[Tensor],
  165. target_ties_p2: Optional[Tensor],
  166. variant: _MetricVariant,
  167. alternative: Optional[_TestAlternative],
  168. ) -> Tensor:
  169. """Calculate p-value for Kendall's tau from metric metadata."""
  170. t_value_denominator_base = n_total * (n_total - 1) * (2 * n_total + 5)
  171. if variant == _MetricVariant.A:
  172. t_value = 3 * con_min_dis_pairs / torch.sqrt(t_value_denominator_base / 2)
  173. else:
  174. m = n_total * (n_total - 1)
  175. t_value_denominator: Tensor = (
  176. t_value_denominator_base
  177. - (preds_ties_p2 if preds_ties_p2 is not None else 0)
  178. - (target_ties_p2 if target_ties_p2 is not None else 0)
  179. ) / 18
  180. t_value_denominator += (
  181. 2 * (preds_ties if preds_ties is not None else 0) * (target_ties if target_ties is not None else 0)
  182. ) / m
  183. t_value_denominator += (
  184. (preds_ties_p1 if preds_ties_p1 is not None else 0)
  185. * (target_ties_p1 if target_ties_p1 is not None else 0)
  186. / (9 * m * (n_total - 2))
  187. )
  188. t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator)
  189. if alternative == _TestAlternative.TWO_SIDED:
  190. t_value = torch.abs(t_value)
  191. if alternative in [_TestAlternative.TWO_SIDED, _TestAlternative.GREATER]:
  192. t_value *= -1
  193. p_value = _get_p_value_for_t_value_from_dist(t_value)
  194. if alternative == _TestAlternative.TWO_SIDED:
  195. p_value *= 2
  196. return p_value
  197. def _kendall_corrcoef_update(
  198. preds: Tensor,
  199. target: Tensor,
  200. concat_preds: Optional[List[Tensor]] = None,
  201. concat_target: Optional[List[Tensor]] = None,
  202. num_outputs: int = 1,
  203. ) -> tuple[List[Tensor], List[Tensor]]:
  204. """Update variables required to compute Kendall rank correlation coefficient.
  205. Args:
  206. preds: Sequence of data
  207. target: Sequence of data
  208. concat_preds: List of batches of preds sequence to be concatenated
  209. concat_target: List of batches of target sequence to be concatenated
  210. num_outputs: Number of outputs in multioutput setting
  211. Raises:
  212. RuntimeError: If ``preds`` and ``target`` do not have the same shape
  213. """
  214. concat_preds = concat_preds or []
  215. concat_target = concat_target or []
  216. # Data checking
  217. _check_same_shape(preds, target)
  218. _check_data_shape_to_num_outputs(preds, target, num_outputs)
  219. if num_outputs == 1:
  220. preds = preds.unsqueeze(1)
  221. target = target.unsqueeze(1)
  222. concat_preds.append(preds)
  223. concat_target.append(target)
  224. return concat_preds, concat_target
  225. def _kendall_corrcoef_compute(
  226. preds: Tensor,
  227. target: Tensor,
  228. variant: _MetricVariant,
  229. alternative: Optional[_TestAlternative] = None,
  230. ) -> tuple[Tensor, Optional[Tensor]]:
  231. """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test.
  232. Args:
  233. Args:
  234. preds: Sequence of data
  235. target: Sequence of data
  236. variant: Indication of which variant of Kendall's tau to be used
  237. alternative: Alternative hypothesis for for t-test. Possible values:
  238. - 'two-sided': the rank correlation is nonzero
  239. - 'less': the rank correlation is negative (less than zero)
  240. - 'greater': the rank correlation is positive (greater than zero)
  241. """
  242. (
  243. concordant_pairs,
  244. discordant_pairs,
  245. preds_ties,
  246. preds_ties_p1,
  247. preds_ties_p2,
  248. target_ties,
  249. target_ties_p1,
  250. target_ties_p2,
  251. n_total,
  252. ) = _get_metric_metadata(preds, target, variant)
  253. con_min_dis_pairs = concordant_pairs - discordant_pairs
  254. tau = _calculate_tau(
  255. preds, target, concordant_pairs, discordant_pairs, con_min_dis_pairs, n_total, preds_ties, target_ties, variant
  256. )
  257. p_value = (
  258. _calculate_p_value(
  259. con_min_dis_pairs,
  260. n_total,
  261. preds_ties,
  262. preds_ties_p1,
  263. preds_ties_p2,
  264. target_ties,
  265. target_ties_p1,
  266. target_ties_p2,
  267. variant,
  268. alternative,
  269. )
  270. if alternative
  271. else None
  272. )
  273. # Squeeze tensor if num_outputs=1
  274. if tau.shape[0] == 1:
  275. tau = tau.squeeze()
  276. p_value = p_value.squeeze() if p_value is not None else None
  277. return tau.clamp(-1, 1), p_value
  278. def kendall_rank_corrcoef(
  279. preds: Tensor,
  280. target: Tensor,
  281. variant: Literal["a", "b", "c"] = "b",
  282. t_test: bool = False,
  283. alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided",
  284. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  285. r"""Compute `Kendall Rank Correlation Coefficient`_.
  286. .. math::
  287. tau_a = \frac{C - D}{C + D}
  288. where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs.
  289. .. math::
  290. tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}}
  291. where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents
  292. a total number of ties.
  293. .. math::
  294. tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}}
  295. where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a total number
  296. of observations and :math:`m` is a ``min`` of unique values in ``preds`` and ``target`` sequence.
  297. Definitions according to Definition according to `The Treatment of Ties in Ranking Problems`_.
  298. Args:
  299. preds: Sequence of data of either shape ``(N,)`` or ``(N,d)``
  300. target: Sequence of data of either shape ``(N,)`` or ``(N,d)``
  301. variant: Indication of which variant of Kendall's tau to be used
  302. t_test: Indication whether to run t-test
  303. alternative: Alternative hypothesis for t-test. Possible values:
  304. - 'two-sided': the rank correlation is nonzero
  305. - 'less': the rank correlation is negative (less than zero)
  306. - 'greater': the rank correlation is positive (greater than zero)
  307. Return:
  308. Correlation tau statistic
  309. (Optional) p-value of corresponding statistical test (asymptotic)
  310. Raises:
  311. ValueError: If ``t_test`` is not of a type bool
  312. ValueError: If ``t_test=True`` and ``alternative=None``
  313. Example (single output regression):
  314. >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
  315. >>> preds = torch.tensor([2.5, 0.0, 2, 8])
  316. >>> target = torch.tensor([3, -0.5, 2, 1])
  317. >>> kendall_rank_corrcoef(preds, target)
  318. tensor(0.3333)
  319. Example (multi output regression):
  320. >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
  321. >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
  322. >>> target = torch.tensor([[3, -0.5], [2, 1]])
  323. >>> kendall_rank_corrcoef(preds, target)
  324. tensor([1., 1.])
  325. Example (single output regression with t-test)
  326. >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
  327. >>> preds = torch.tensor([2.5, 0.0, 2, 8])
  328. >>> target = torch.tensor([3, -0.5, 2, 1])
  329. >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
  330. (tensor(0.3333), tensor(0.4969))
  331. Example (multi output regression with t-test):
  332. >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
  333. >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
  334. >>> target = torch.tensor([[3, -0.5], [2, 1]])
  335. >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
  336. (tensor([1., 1.]), tensor([nan, nan]))
  337. """
  338. if not isinstance(t_test, bool):
  339. raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.")
  340. if t_test and alternative is None:
  341. raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.")
  342. _variant = _MetricVariant.from_str(str(variant))
  343. _alternative = _TestAlternative.from_str(str(alternative)) if t_test else None
  344. _preds, _target = _kendall_corrcoef_update(
  345. preds, target, [], [], num_outputs=1 if preds.ndim == 1 else preds.shape[-1]
  346. )
  347. tau, p_value = _kendall_corrcoef_compute(
  348. dim_zero_cat(_preds),
  349. dim_zero_cat(_target),
  350. _variant, # type: ignore[arg-type] # todo
  351. _alternative, # type: ignore[arg-type] # todo
  352. )
  353. if p_value is not None:
  354. return tau, p_value
  355. return tau