utils.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.utilities.prints import rank_zero_warn
  19. def _nominal_input_validation(nan_strategy: str, nan_replace_value: Optional[float]) -> None:
  20. if nan_strategy not in ["replace", "drop"]:
  21. raise ValueError(
  22. f"Argument `nan_strategy` is expected to be one of `['replace', 'drop']`, but got {nan_strategy}"
  23. )
  24. if nan_strategy == "replace" and not isinstance(nan_replace_value, (float, int)):
  25. raise ValueError(
  26. "Argument `nan_replace` is expected to be of a type `int` or `float` when `nan_strategy = 'replace`, "
  27. f"but got {nan_replace_value}"
  28. )
  29. def _compute_expected_freqs(confmat: Tensor) -> Tensor:
  30. """Compute the expected frequenceis from the provided confusion matrix."""
  31. margin_sum_rows, margin_sum_cols = confmat.sum(1), confmat.sum(0)
  32. return torch.einsum("r, c -> rc", margin_sum_rows, margin_sum_cols) / confmat.sum()
  33. def _compute_chi_squared(confmat: Tensor, bias_correction: bool) -> Tensor:
  34. """Chi-square test of independenc of variables in a confusion matrix table.
  35. Adapted from: https://github.com/scipy/scipy/blob/v1.9.2/scipy/stats/contingency.py.
  36. """
  37. expected_freqs = _compute_expected_freqs(confmat)
  38. # Get degrees of freedom
  39. df = expected_freqs.numel() - sum(expected_freqs.shape) + expected_freqs.ndim - 1
  40. if df == 0:
  41. return torch.tensor(0.0, device=confmat.device)
  42. if df == 1 and bias_correction:
  43. diff = expected_freqs - confmat
  44. direction = diff.sign()
  45. confmat += direction * torch.minimum(0.5 * torch.ones_like(direction), direction.abs())
  46. return torch.sum((confmat - expected_freqs) ** 2 / expected_freqs)
  47. def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor:
  48. """Drop all rows and columns containing only zeros.
  49. Example:
  50. >>> from torch import randint
  51. >>> from torchmetrics.functional.nominal.utils import _drop_empty_rows_and_cols
  52. >>> matrix = randint(10, size=(4, 3))
  53. >>> matrix[1, :] = matrix[:, 1] = 0
  54. >>> matrix
  55. tensor([[2, 0, 6],
  56. [0, 0, 0],
  57. [0, 0, 0],
  58. [3, 0, 4]])
  59. >>> _drop_empty_rows_and_cols(matrix)
  60. tensor([[2, 6],
  61. [3, 4]])
  62. """
  63. confmat = confmat[confmat.sum(1) != 0]
  64. return confmat[:, confmat.sum(0) != 0]
  65. def _compute_phi_squared_corrected(
  66. phi_squared: Tensor,
  67. num_rows: int,
  68. num_cols: int,
  69. confmat_sum: Tensor,
  70. ) -> Tensor:
  71. """Compute bias-corrected Phi Squared."""
  72. return torch.max(
  73. torch.tensor(0.0, device=phi_squared.device),
  74. phi_squared - ((num_rows - 1) * (num_cols - 1)) / (confmat_sum - 1),
  75. )
  76. def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: Tensor) -> tuple[Tensor, Tensor]:
  77. """Compute bias-corrected number of rows and columns."""
  78. rows_corrected = num_rows - (num_rows - 1) ** 2 / (confmat_sum - 1)
  79. cols_corrected = num_cols - (num_cols - 1) ** 2 / (confmat_sum - 1)
  80. return rows_corrected, cols_corrected
  81. def _compute_bias_corrected_values(
  82. phi_squared: Tensor, num_rows: int, num_cols: int, confmat_sum: Tensor
  83. ) -> tuple[Tensor, Tensor, Tensor]:
  84. """Compute bias-corrected Phi Squared and number of rows and columns."""
  85. phi_squared_corrected = _compute_phi_squared_corrected(phi_squared, num_rows, num_cols, confmat_sum)
  86. rows_corrected, cols_corrected = _compute_rows_and_cols_corrected(num_rows, num_cols, confmat_sum)
  87. return phi_squared_corrected, rows_corrected, cols_corrected
  88. def _handle_nan_in_data(
  89. preds: Tensor,
  90. target: Tensor,
  91. nan_strategy: Literal["replace", "drop"] = "replace",
  92. nan_replace_value: Optional[float] = 0.0,
  93. ) -> tuple[Tensor, Tensor]:
  94. """Handle ``NaN`` values in input data.
  95. If ``nan_strategy = 'replace'``, all ``NaN`` values are replaced with ``nan_replace_value``.
  96. If ``nan_strategy = 'drop'``, all rows containing ``NaN`` in any of two vectors are dropped.
  97. Args:
  98. preds: 1D tensor of categorical (nominal) data
  99. target: 1D tensor of categorical (nominal) data
  100. nan_strategy: Indication of whether to replace or drop ``NaN`` values
  101. nan_replace_value: Value to replace ``NaN`s when ``nan_strategy = 'replace```
  102. Returns:
  103. Updated ``preds`` and ``target`` tensors which contain no ``Nan``
  104. Raises:
  105. ValueError: If ``nan_strategy`` is not from ``['replace', 'drop']``.
  106. ValueError: If ``nan_strategy = replace`` and ``nan_replace_value`` is not of a type ``int`` or ``float``.
  107. """
  108. if nan_strategy == "replace":
  109. return preds.nan_to_num(nan_replace_value), target.nan_to_num(nan_replace_value)
  110. rows_contain_nan = torch.logical_or(preds.isnan(), target.isnan())
  111. return preds[~rows_contain_nan], target[~rows_contain_nan]
  112. def _unable_to_use_bias_correction_warning(metric_name: str) -> None:
  113. rank_zero_warn(
  114. f"Unable to compute {metric_name} using bias correction. Please consider to set `bias_correction=False`."
  115. )