theils_u.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import itertools
  15. from typing import Optional
  16. import torch
  17. from torch import Tensor
  18. from typing_extensions import Literal
  19. from torchmetrics.functional.classification.confusion_matrix import _multiclass_confusion_matrix_update
  20. from torchmetrics.functional.nominal.utils import (
  21. _drop_empty_rows_and_cols,
  22. _handle_nan_in_data,
  23. _nominal_input_validation,
  24. )
  25. def _conditional_entropy_compute(confmat: Tensor) -> Tensor:
  26. r"""Compute Conditional Entropy Statistic based on a pre-computed confusion matrix.
  27. .. math::
  28. H(X|Y) = \sum_{x, y ~ (X, Y)} p(x, y)\frac{p(y)}{p(x, y)}
  29. Args:
  30. confmat: Confusion matrix for observed data
  31. Returns:
  32. Conditional Entropy Value
  33. """
  34. confmat = _drop_empty_rows_and_cols(confmat)
  35. total_occurrences = confmat.sum()
  36. # iterate over all i, j combinations
  37. p_xy_m = confmat / total_occurrences
  38. # get p_y by summing over x dim (=1)
  39. p_y = confmat.sum(1) / total_occurrences
  40. # repeat over rows (shape = p_xy_m.shape[1]) for tensor multiplication
  41. p_y_m = p_y.unsqueeze(1).repeat(1, p_xy_m.shape[1])
  42. # entropy calculated as p_xy * log (p_xy / p_y)
  43. return torch.nansum(p_xy_m * torch.log(p_y_m / p_xy_m))
  44. def _theils_u_update(
  45. preds: Tensor,
  46. target: Tensor,
  47. num_classes: int,
  48. nan_strategy: Literal["replace", "drop"] = "replace",
  49. nan_replace_value: Optional[float] = 0.0,
  50. ) -> Tensor:
  51. """Compute the bins to update the confusion matrix with for Theil's U calculation.
  52. Args:
  53. preds: 1D or 2D tensor of categorical (nominal) data
  54. target: 1D or 2D tensor of categorical (nominal) data
  55. num_classes: Integer specifying the number of classes
  56. nan_strategy: Indication of whether to replace or drop ``NaN`` values
  57. nan_replace_value: Value to replace ``NaN`s when ``nan_strategy = 'replace```
  58. Returns:
  59. Non-reduced confusion matrix
  60. """
  61. preds = preds.argmax(1) if preds.ndim == 2 else preds
  62. target = target.argmax(1) if target.ndim == 2 else target
  63. preds, target = _handle_nan_in_data(preds, target, nan_strategy, nan_replace_value)
  64. return _multiclass_confusion_matrix_update(preds, target, num_classes)
  65. def _theils_u_compute(confmat: Tensor) -> Tensor:
  66. """Compute Theil's U statistic based on a pre-computed confusion matrix.
  67. Args:
  68. confmat: Confusion matrix for observed data
  69. Returns:
  70. Theil's U statistic
  71. """
  72. confmat = _drop_empty_rows_and_cols(confmat)
  73. # compute conditional entropy
  74. s_xy = _conditional_entropy_compute(confmat)
  75. # compute H(x)
  76. total_occurrences = confmat.sum()
  77. p_x = confmat.sum(0) / total_occurrences
  78. s_x = -torch.sum(p_x * torch.log(p_x))
  79. # compute u statistic
  80. if s_x == 0:
  81. return torch.tensor(0, device=confmat.device)
  82. return (s_x - s_xy) / s_x
  83. def theils_u(
  84. preds: Tensor,
  85. target: Tensor,
  86. nan_strategy: Literal["replace", "drop"] = "replace",
  87. nan_replace_value: Optional[float] = 0.0,
  88. ) -> Tensor:
  89. r"""Compute `Theils Uncertainty coefficient`_ statistic measuring the association between two nominal data series.
  90. .. math::
  91. U(X|Y) = \frac{H(X) - H(X|Y)}{H(X)}
  92. where :math:`H(X)` is entropy of variable :math:`X` while :math:`H(X|Y)` is the conditional entropy of :math:`X`
  93. given :math:`Y`.
  94. Theils's U is an asymmetric coefficient, i.e. :math:`TheilsU(preds, target) \neq TheilsU(target, preds)`.
  95. The output values lies in [0, 1]. 0 means y has no information about x while value 1 means y has complete
  96. information about x.
  97. Args:
  98. preds: 1D or 2D tensor of categorical (nominal) data
  99. - 1D shape: (batch_size,)
  100. - 2D shape: (batch_size, num_classes)
  101. target: 1D or 2D tensor of categorical (nominal) data
  102. - 1D shape: (batch_size,)
  103. - 2D shape: (batch_size, num_classes)
  104. nan_strategy: Indication of whether to replace or drop ``NaN`` values
  105. nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``
  106. Returns:
  107. Tensor containing Theil's U statistic
  108. Example:
  109. >>> from torch import randint
  110. >>> from torchmetrics.functional.nominal import theils_u
  111. >>> preds = randint(10, (10,))
  112. >>> target = randint(10, (10,))
  113. >>> theils_u(preds, target)
  114. tensor(0.8530)
  115. """
  116. num_classes = len(torch.cat([preds, target]).unique())
  117. confmat = _theils_u_update(preds, target, num_classes, nan_strategy, nan_replace_value)
  118. return _theils_u_compute(confmat)
  119. def theils_u_matrix(
  120. matrix: Tensor,
  121. nan_strategy: Literal["replace", "drop"] = "replace",
  122. nan_replace_value: Optional[float] = 0.0,
  123. ) -> Tensor:
  124. r"""Compute `Theil's U`_ statistic between a set of multiple variables.
  125. This can serve as a convenient tool to compute Theil's U statistic for analyses of correlation between categorical
  126. variables in your dataset.
  127. Args:
  128. matrix: A tensor of categorical (nominal) data, where:
  129. - rows represent a number of data points
  130. - columns represent a number of categorical (nominal) features
  131. nan_strategy: Indication of whether to replace or drop ``NaN`` values
  132. nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``
  133. Returns:
  134. Theil's U statistic for a dataset of categorical variables
  135. Example:
  136. >>> from torch import randint
  137. >>> from torchmetrics.functional.nominal import theils_u_matrix
  138. >>> matrix = randint(0, 4, (200, 5))
  139. >>> theils_u_matrix(matrix)
  140. tensor([[1.0000, 0.0202, 0.0142, 0.0196, 0.0353],
  141. [0.0202, 1.0000, 0.0070, 0.0136, 0.0065],
  142. [0.0143, 0.0070, 1.0000, 0.0125, 0.0206],
  143. [0.0198, 0.0137, 0.0125, 1.0000, 0.0312],
  144. [0.0352, 0.0065, 0.0204, 0.0308, 1.0000]])
  145. """
  146. _nominal_input_validation(nan_strategy, nan_replace_value)
  147. num_variables = matrix.shape[1]
  148. theils_u_matrix_value = torch.ones(num_variables, num_variables, device=matrix.device)
  149. for i, j in itertools.combinations(range(num_variables), 2):
  150. x, y = matrix[:, i], matrix[:, j]
  151. num_classes = len(torch.cat([x, y]).unique())
  152. confmat = _theils_u_update(x, y, num_classes, nan_strategy, nan_replace_value)
  153. theils_u_matrix_value[i, j] = _theils_u_compute(confmat)
  154. theils_u_matrix_value[j, i] = _theils_u_compute(confmat.T)
  155. return theils_u_matrix_value