compute.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional, Union
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.utilities import rank_zero_warn
  19. def _safe_matmul(x: Tensor, y: Tensor) -> Tensor:
  20. """Safe calculation of matrix multiplication.
  21. If input is float16, will cast to float32 for computation and back again.
  22. """
  23. if x.dtype == torch.float16 or y.dtype == torch.float16:
  24. return (x.float() @ y.T.float()).half()
  25. return x @ y.T
  26. def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor:
  27. """Compute x * log(y). Returns 0 if x=0.
  28. Example:
  29. >>> import torch
  30. >>> x = torch.zeros(1)
  31. >>> _safe_xlogy(x, 1/x)
  32. tensor([0.])
  33. """
  34. res = x * torch.log(y)
  35. res[x == 0] = 0.0
  36. return res
  37. def _safe_divide(
  38. num: Tensor,
  39. denom: Tensor,
  40. zero_division: Union[float, Literal["warn", "nan"]] = 0.0,
  41. ) -> Tensor:
  42. """Safe division, by preventing division by zero.
  43. Function will cast to float if input is not already to secure backwards compatibility.
  44. Args:
  45. num: numerator tensor
  46. denom: denominator tensor, which may contain zeros
  47. zero_division: value to replace elements divided by zero
  48. Example:
  49. >>> import torch
  50. >>> num = torch.tensor([1.0, 2.0, 3.0])
  51. >>> denom = torch.tensor([0.0, 1.0, 2.0])
  52. >>> _safe_divide(num, denom)
  53. tensor([0.0000, 2.0000, 1.5000])
  54. """
  55. num = num if num.is_floating_point() else num.float()
  56. denom = denom if denom.is_floating_point() else denom.float()
  57. if isinstance(zero_division, (float, int)) or zero_division == "warn":
  58. if zero_division == "warn" and torch.any(denom == 0):
  59. rank_zero_warn("Detected zero division in _safe_divide. Setting 0/0 to 0.0")
  60. zero_division = 0.0 if zero_division == "warn" else zero_division
  61. zero_division_tensor = torch.full((), zero_division, dtype=num.dtype, device=num.device)
  62. return torch.where(denom != 0, num / denom, zero_division_tensor)
  63. return torch.true_divide(num, denom)
  64. def _adjust_weights_safe_divide(
  65. score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1
  66. ) -> Tensor:
  67. if average is None or average == "none":
  68. return score
  69. if average == "weighted":
  70. weights = tp + fn
  71. else:
  72. weights = torch.ones_like(score)
  73. if not multilabel:
  74. weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0
  75. return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
  76. def _auc_format_inputs(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
  77. """Check that auc input is correct."""
  78. x = x.squeeze() if x.ndim > 1 else x
  79. y = y.squeeze() if y.ndim > 1 else y
  80. if x.ndim > 1 or y.ndim > 1:
  81. raise ValueError(
  82. f"Expected both `x` and `y` tensor to be 1d, but got tensors with dimension {x.ndim} and {y.ndim}"
  83. )
  84. if x.numel() != y.numel():
  85. raise ValueError(
  86. f"Expected the same number of elements in `x` and `y` tensor but received {x.numel()} and {y.numel()}"
  87. )
  88. return x, y
  89. def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float, axis: int = -1) -> Tensor:
  90. """Compute area under the curve using the trapezoidal rule.
  91. Assumes increasing or decreasing order of `x`.
  92. """
  93. with torch.no_grad():
  94. auc_score: Tensor = torch.trapz(y, x, dim=axis) * direction
  95. return auc_score
  96. def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
  97. """Compute area under the curve using the trapezoidal rule.
  98. Example:
  99. >>> import torch
  100. >>> x = torch.tensor([1, 2, 3, 4])
  101. >>> y = torch.tensor([1, 2, 3, 4])
  102. >>> _auc_compute(x, y)
  103. tensor(7.5000)
  104. """
  105. with torch.no_grad():
  106. if reorder:
  107. x, x_idx = torch.sort(x, stable=True)
  108. y = y[x_idx]
  109. dx = x[1:] - x[:-1]
  110. if (dx < 0).any():
  111. if (dx <= 0).all():
  112. direction = -1.0
  113. else:
  114. raise ValueError(
  115. "The `x` tensor is neither increasing or decreasing. Try setting the reorder argument to `True`."
  116. )
  117. else:
  118. direction = 1.0
  119. return _auc_compute_without_check(x, y, direction)
  120. def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
  121. """Compute Area Under the Curve (AUC) using the trapezoidal rule.
  122. Args:
  123. x: x-coordinates, must be either increasing or decreasing
  124. y: y-coordinates
  125. reorder: if True, will reorder the arrays to make it either increasing or decreasing
  126. Return:
  127. Tensor containing AUC score
  128. """
  129. x, y = _auc_format_inputs(x, y)
  130. return _auc_compute(x, y, reorder=reorder)
  131. def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
  132. """One-dimensional linear interpolation for monotonically increasing sample points.
  133. Returns the one-dimensional piecewise linear interpolation to a function with
  134. given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.
  135. Adjusted version of this https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964
  136. Args:
  137. x: the :math:`x`-coordinates at which to evaluate the interpolated values.
  138. xp: the :math:`x`-coordinates of the data points, must be increasing.
  139. fp: the :math:`y`-coordinates of the data points, same length as `xp`.
  140. Returns:
  141. the interpolated values, same size as `x`.
  142. Example:
  143. >>> x = torch.tensor([0.5, 1.5, 2.5])
  144. >>> xp = torch.tensor([1, 2, 3])
  145. >>> fp = torch.tensor([1, 2, 3])
  146. >>> interp(x, xp, fp)
  147. tensor([0.5000, 1.5000, 2.5000])
  148. """
  149. m = _safe_divide(fp[1:] - fp[:-1], xp[1:] - xp[:-1])
  150. b = fp[:-1] - (m * xp[:-1])
  151. indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1
  152. indices = torch.clamp(indices, 0, len(m) - 1)
  153. return m[indices] * x + b[indices]
  154. def normalize_logits_if_needed(tensor: Tensor, normalization: Optional[Literal["sigmoid", "softmax"]]) -> Tensor:
  155. """Normalize logits if needed.
  156. If input tensor is outside the [0,1] we assume that logits are provided and apply the normalization.
  157. Use torch.where to prevent device-host sync.
  158. Args:
  159. tensor: input tensor that may be logits or probabilities
  160. normalization: normalization method, either 'sigmoid' or 'softmax'
  161. Returns:
  162. normalized tensor if needed
  163. Example:
  164. >>> import torch
  165. >>> tensor = torch.tensor([-1.0, 0.0, 1.0])
  166. >>> normalize_logits_if_needed(tensor, normalization="sigmoid")
  167. tensor([0.2689, 0.5000, 0.7311])
  168. >>> tensor = torch.tensor([[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]])
  169. >>> normalize_logits_if_needed(tensor, normalization="softmax")
  170. tensor([[0.0900, 0.2447, 0.6652],
  171. [0.6652, 0.2447, 0.0900]])
  172. >>> tensor = torch.tensor([0.0, 0.5, 1.0])
  173. >>> normalize_logits_if_needed(tensor, normalization="sigmoid")
  174. tensor([0.0000, 0.5000, 1.0000])
  175. """
  176. # if not specified, do nothing.
  177. if not normalization:
  178. return tensor
  179. # decrease sigmoid on cpu .
  180. if tensor.device == torch.device("cpu"):
  181. if not torch.all((tensor >= 0) * (tensor <= 1)):
  182. tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
  183. return tensor
  184. # decrease device-host sync on device .
  185. condition = ((tensor < 0) | (tensor > 1)).any()
  186. return torch.where(
  187. condition,
  188. torch.sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
  189. tensor,
  190. )