data.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. from collections.abc import Sequence
  16. from typing import Any, List, Optional, Union
  17. import torch
  18. from lightning_utilities import apply_to_collection
  19. from torch import Tensor
  20. from torchmetrics.utilities.exceptions import TorchMetricsUserWarning
  21. from torchmetrics.utilities.imports import _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
  22. from torchmetrics.utilities.prints import rank_zero_warn
# Small module-level epsilon. NOTE(review): presumably used by metric code as a numerical-stability
# tolerance (e.g. to avoid division by zero) — it is not referenced in this part of the file; confirm at callers.
METRIC_EPS = 1e-6
  24. def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor:
  25. """Concatenation along the zero dimension."""
  26. if isinstance(x, torch.Tensor):
  27. return x
  28. x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x]
  29. if not x: # empty list
  30. raise ValueError("No samples to concatenate")
  31. return torch.cat(x, dim=0)
  32. def dim_zero_sum(x: Tensor) -> Tensor:
  33. """Summation along the zero dimension."""
  34. return torch.sum(x, dim=0)
  35. def dim_zero_mean(x: Tensor) -> Tensor:
  36. """Average along the zero dimension."""
  37. return torch.mean(x, dim=0)
  38. def dim_zero_max(x: Tensor) -> Tensor:
  39. """Max along the zero dimension."""
  40. return torch.max(x, dim=0).values
  41. def dim_zero_min(x: Tensor) -> Tensor:
  42. """Min along the zero dimension."""
  43. return torch.min(x, dim=0).values
  44. def _flatten(x: Sequence) -> list:
  45. """Flatten list of list into single list."""
  46. return [item for sublist in x for item in sublist]
  47. def _flatten_dict(x: dict) -> tuple[dict, bool]:
  48. """Flatten dict of dicts into single dict and checking for duplicates in keys along the way."""
  49. new_dict = {}
  50. duplicates = False
  51. for key, value in x.items():
  52. if isinstance(value, dict):
  53. for k, v in value.items():
  54. if k in new_dict:
  55. duplicates = True
  56. new_dict[k] = v
  57. else:
  58. if key in new_dict:
  59. duplicates = True
  60. new_dict[key] = value
  61. return new_dict, duplicates
  62. def to_onehot(
  63. label_tensor: Tensor,
  64. num_classes: Optional[int] = None,
  65. ) -> Tensor:
  66. """Convert a dense label tensor to one-hot format.
  67. Args:
  68. label_tensor: dense label tensor, with shape [N, d1, d2, ...]
  69. num_classes: number of classes C
  70. Returns:
  71. A sparse label tensor with shape [N, C, d1, d2, ...]
  72. Example:
  73. >>> x = torch.tensor([1, 2, 3])
  74. >>> to_onehot(x)
  75. tensor([[0, 1, 0, 0],
  76. [0, 0, 1, 0],
  77. [0, 0, 0, 1]])
  78. """
  79. if num_classes is None:
  80. num_classes = int(label_tensor.max().detach().item() + 1)
  81. tensor_onehot = torch.zeros(
  82. label_tensor.shape[0],
  83. num_classes,
  84. *label_tensor.shape[1:],
  85. dtype=label_tensor.dtype,
  86. device=label_tensor.device,
  87. )
  88. index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot)
  89. return tensor_onehot.scatter_(1, index, 1.0)
  90. def _top_k_with_half_precision_support(x: Tensor, k: int = 1, dim: int = 1) -> Tensor:
  91. """torch.top_k does not support half precision on CPU."""
  92. if x.dtype == torch.half and not x.is_cuda:
  93. idx = torch.argsort(x, dim=dim, stable=True).flip(dim)
  94. return idx.narrow(dim, 0, k)
  95. return x.topk(k=k, dim=dim).indices
  96. def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
  97. """Convert a probability tensor to binary by selecting top-k the highest entries.
  98. Args:
  99. prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the
  100. position defined by the ``dim`` argument
  101. topk: number of the highest entries to turn into 1s
  102. dim: dimension on which to compare entries
  103. Returns:
  104. A binary tensor of the same shape as the input tensor of type ``torch.int32``
  105. Example:
  106. >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
  107. >>> select_topk(x, topk=2)
  108. tensor([[0, 1, 1],
  109. [1, 1, 0]], dtype=torch.int32)
  110. """
  111. topk_tensor = torch.zeros_like(prob_tensor, dtype=torch.int)
  112. if topk == 1: # argmax has better performance than topk
  113. topk_tensor.scatter_(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0)
  114. else:
  115. topk_tensor.scatter_(dim, _top_k_with_half_precision_support(prob_tensor, k=topk, dim=dim), 1.0)
  116. return topk_tensor.int()
  117. def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor:
  118. """Convert a tensor of probabilities to a dense label tensor.
  119. Args:
  120. x: probabilities to get the categorical label [N, d1, d2, ...]
  121. argmax_dim: dimension to apply
  122. Return:
  123. A tensor with categorical labels [N, d2, ...]
  124. Example:
  125. >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
  126. >>> to_categorical(x)
  127. tensor([1, 0])
  128. """
  129. return torch.argmax(x, dim=argmax_dim)
  130. def _squeeze_scalar_element_tensor(x: Tensor) -> Tensor:
  131. return x.squeeze() if x.numel() == 1 else x
  132. def _squeeze_if_scalar(data: Any) -> Any:
  133. return apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor)
  134. def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
  135. """Implement custom bincount.
  136. PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU or when running
  137. MPS devices or when running on XLA device. This implementation therefore falls back to using a combination of
  138. `torch.arange` and `torch.eq` in these scenarios. A small performance hit can expected and higher memory consumption
  139. as `[batch_size, mincount]` tensor needs to be initialized compared to native ``torch.bincount``.
  140. Args:
  141. x: tensor to count
  142. minlength: minimum length to count
  143. Returns:
  144. Number of occurrences for each unique element in x
  145. Example:
  146. >>> x = torch.tensor([0,0,0,1,1,2,2,2,2])
  147. >>> _bincount(x, minlength=3)
  148. tensor([3, 2, 4])
  149. """
  150. if minlength is None:
  151. minlength = len(torch.unique(x))
  152. if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
  153. mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
  154. return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)
  155. return torch.bincount(x, minlength=minlength)
  156. def _cumsum(x: Tensor, dim: Optional[int] = 0, dtype: Optional[torch.dtype] = None) -> Tensor:
  157. """Implement custom cumulative summation for Torch versions which does not support it natively."""
  158. is_cuda_fp_deterministic = torch.are_deterministic_algorithms_enabled() and x.is_cuda and x.is_floating_point()
  159. if _TORCH_LESS_THAN_2_6 and is_cuda_fp_deterministic and sys.platform != "win32":
  160. rank_zero_warn(
  161. "You are trying to use a metric in deterministic mode on GPU that uses `torch.cumsum`, which is currently"
  162. " not supported. The tensor will be copied to the CPU memory to compute it and then copied back to GPU."
  163. " Expect some slowdowns.",
  164. TorchMetricsUserWarning,
  165. )
  166. return x.cpu().cumsum(dim=dim, dtype=dtype).to(x.device)
  167. return torch.cumsum(x, dim=dim, dtype=dtype)
  168. def _flexible_bincount(x: Tensor) -> Tensor:
  169. """Similar to `_bincount`, but works also with tensor that do not contain continuous values.
  170. Args:
  171. x: tensor to count
  172. Returns:
  173. Number of occurrences for each unique element in x
  174. """
  175. unique_x, inverse_indices = torch.unique(x, return_inverse=True)
  176. return _bincount(inverse_indices, minlength=len(unique_x))
  177. def allclose(tensor1: Tensor, tensor2: Tensor) -> bool:
  178. """Wrap torch.allclose to be robust towards dtype difference."""
  179. if tensor1.dtype != tensor2.dtype:
  180. tensor2 = tensor2.to(dtype=tensor1.dtype)
  181. return torch.allclose(tensor1, tensor2)
  182. def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
  183. """Interpolation function comparable to numpy.interp.
  184. Args:
  185. x: x-coordinates where to evaluate the interpolated values
  186. xp: x-coordinates of the data points
  187. fp: y-coordinates of the data points
  188. """
  189. # Sort xp and fp based on xp for compatibility with np.interp
  190. sorted_indices = torch.argsort(xp)
  191. xp = xp[sorted_indices]
  192. fp = fp[sorted_indices]
  193. # Calculate slopes for each interval
  194. slopes = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
  195. # Identify where x falls relative to xp
  196. indices = torch.searchsorted(xp, x) - 1
  197. indices = torch.clamp(indices, 0, len(slopes) - 1)
  198. # Compute interpolated values
  199. return fp[indices] + slopes[indices] * (x - xp[indices])