checks.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import multiprocessing
  15. import os
  16. import sys
  17. from collections.abc import Mapping, Sequence
  18. from functools import partial
  19. from time import perf_counter
  20. from typing import Any, Callable, Optional, no_type_check
  21. from unittest.mock import Mock
  22. import torch
  23. from torch import Tensor
  24. from torchmetrics.metric import Metric
  25. _DOCTEST_DOWNLOAD_TIMEOUT = int(os.environ.get("DOCTEST_DOWNLOAD_TIMEOUT", 120))
  26. _SKIP_SLOW_DOCTEST = bool(os.environ.get("SKIP_SLOW_DOCTEST", 0))
  27. def _check_for_empty_tensors(preds: Tensor, target: Tensor) -> bool:
  28. return preds.numel() == target.numel() == 0
  29. def _check_same_shape(preds: Tensor, target: Tensor) -> None:
  30. """Check that predictions and target have the same shape, else raise error."""
  31. if preds.shape != target.shape:
  32. raise RuntimeError(
  33. f"Predictions and targets are expected to have the same shape, but got {preds.shape} and {target.shape}."
  34. )
  35. def _check_retrieval_functional_inputs(
  36. preds: Tensor,
  37. target: Tensor,
  38. allow_non_binary_target: bool = False,
  39. ) -> tuple[Tensor, Tensor]:
  40. """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
  41. Args:
  42. preds: either tensor with scores/logits
  43. target: tensor with ground true labels
  44. allow_non_binary_target: whether to allow target to contain non-binary values
  45. Raises:
  46. ValueError:
  47. If ``preds`` and ``target`` don't have the same shape, if they are empty
  48. or not of the correct ``dtypes``.
  49. Returns:
  50. preds: as torch.float32
  51. target: as torch.long if not floating point else torch.float32
  52. """
  53. if preds.shape != target.shape:
  54. raise ValueError("`preds` and `target` must be of the same shape")
  55. if not preds.numel() or not preds.size():
  56. raise ValueError("`preds` and `target` must be non-empty and non-scalar tensors")
  57. return _check_retrieval_target_and_prediction_types(preds, target, allow_non_binary_target=allow_non_binary_target)
  58. def _check_retrieval_inputs(
  59. indexes: Tensor,
  60. preds: Tensor,
  61. target: Tensor,
  62. allow_non_binary_target: bool = False,
  63. ignore_index: Optional[int] = None,
  64. ) -> tuple[Tensor, Tensor, Tensor]:
  65. """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
  66. Args:
  67. indexes: tensor with queries indexes
  68. preds: tensor with scores/logits
  69. target: tensor with ground true labels
  70. allow_non_binary_target: whether to allow target to contain non-binary values
  71. ignore_index: ignore predictions where targets are equal to this number
  72. Raises:
  73. ValueError:
  74. If ``preds`` and ``target`` don't have the same shape, if they are empty or not of the correct ``dtypes``.
  75. Returns:
  76. indexes: as ``torch.long``
  77. preds: as ``torch.float32``
  78. target: as ``torch.long``
  79. """
  80. if indexes.shape != preds.shape or preds.shape != target.shape:
  81. raise ValueError("`indexes`, `preds` and `target` must be of the same shape")
  82. if indexes.dtype is not torch.long:
  83. raise ValueError("`indexes` must be a tensor of long integers")
  84. # remove predictions where target is equal to `ignore_index`
  85. if ignore_index is not None:
  86. valid_positions = target != ignore_index
  87. indexes, preds, target = indexes[valid_positions], preds[valid_positions], target[valid_positions]
  88. if not indexes.numel() or not indexes.size():
  89. raise ValueError(
  90. "`indexes`, `preds` and `target` must be non-empty and non-scalar tensors",
  91. )
  92. preds, target = _check_retrieval_target_and_prediction_types(
  93. preds, target, allow_non_binary_target=allow_non_binary_target
  94. )
  95. return indexes.long().flatten(), preds, target
  96. def _check_retrieval_target_and_prediction_types(
  97. preds: Tensor,
  98. target: Tensor,
  99. allow_non_binary_target: bool = False,
  100. ) -> tuple[Tensor, Tensor]:
  101. """Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
  102. Args:
  103. preds: either tensor with scores/logits
  104. target: tensor with ground true labels
  105. allow_non_binary_target: whether to allow target to contain non-binary values
  106. Raises:
  107. ValueError:
  108. If ``preds`` and ``target`` don't have the same shape, if they are empty or not of the correct ``dtypes``.
  109. """
  110. if target.dtype not in (torch.bool, torch.long, torch.int) and not torch.is_floating_point(target):
  111. raise ValueError("`target` must be a tensor of booleans, integers or floats")
  112. if not preds.is_floating_point():
  113. raise ValueError("`preds` must be a tensor of floats")
  114. if not allow_non_binary_target and (target.max() > 1 or target.min() < 0):
  115. raise ValueError("`target` must contain `binary` values")
  116. target = target.float() if target.is_floating_point() else target.long()
  117. preds = preds.float()
  118. return preds.flatten(), target.flatten()
  119. def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-6) -> bool:
  120. """Recursively asserting that two results are within a certain tolerance."""
  121. # single output compare
  122. if isinstance(res1, Tensor):
  123. return torch.allclose(res1, res2, atol=atol)
  124. if isinstance(res1, str):
  125. return res1 == res2
  126. if isinstance(res1, Sequence):
  127. return all(_allclose_recursive(r1, r2) for r1, r2 in zip(res1, res2))
  128. if isinstance(res1, Mapping):
  129. return all(_allclose_recursive(res1[k], res2[k]) for k in res1)
  130. return res1 == res2
  131. @no_type_check
  132. def check_forward_full_state_property(
  133. metric_class: Metric,
  134. init_args: Optional[dict[str, Any]] = None,
  135. input_args: Optional[dict[str, Any]] = None,
  136. num_update_to_compare: Sequence[int] = [10, 100, 1000],
  137. reps: int = 5,
  138. ) -> None:
  139. """Check if the new ``full_state_update`` property works as intended.
  140. This function checks if the property can safely be set to ``False`` which will for most metrics results in a
  141. speedup when using ``forward``.
  142. Args:
  143. metric_class: metric class object that should be checked
  144. init_args: dict containing arguments for initializing the metric class
  145. input_args: dict containing arguments to pass to ``forward``
  146. num_update_to_compare: if we successfully detect that the flag is safe to set to ``False``
  147. we will run some speedup test. This arg should be a list of integers for how many
  148. steps to compare over.
  149. reps: number of repetitions of speedup test
  150. Example (states in ``update`` are independent, save to set ``full_state_update=False``)
  151. >>> from torchmetrics.classification import MulticlassConfusionMatrix
  152. >>> check_forward_full_state_property( # doctest: +SKIP
  153. ... MulticlassConfusionMatrix,
  154. ... init_args = {'num_classes': 3},
  155. ... input_args = {'preds': torch.randint(3, (100,)), 'target': torch.randint(3, (100,))},
  156. ... )
  157. Full state for 10 steps took: ...
  158. Partial state for 10 steps took: ...
  159. Full state for 100 steps took: ...
  160. Partial state for 100 steps took: ...
  161. Full state for 1000 steps took: ...
  162. Partial state for 1000 steps took: ...
  163. Recommended setting `full_state_update=False`
  164. Example (states in ``update`` are dependent meaning that ``full_state_update=True``):
  165. >>> from torchmetrics.classification import MulticlassConfusionMatrix
  166. >>> class MyMetric(MulticlassConfusionMatrix):
  167. ... def update(self, preds, target):
  168. ... super().update(preds, target)
  169. ... # by construction make future states dependent on prior states
  170. ... if self.confmat.sum() > 20:
  171. ... self.reset()
  172. >>> check_forward_full_state_property(
  173. ... MyMetric,
  174. ... init_args = {'num_classes': 3},
  175. ... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
  176. ... )
  177. Recommended setting `full_state_update=True`
  178. """
  179. init_args = init_args or {}
  180. input_args = input_args or {}
  181. class FullState(metric_class):
  182. full_state_update = True
  183. class PartState(metric_class):
  184. full_state_update = False
  185. fullstate = FullState(**init_args)
  186. partstate = PartState(**init_args)
  187. equal = True
  188. try: # if it fails, the code most likely need access to the full state
  189. for _ in range(num_update_to_compare[0]):
  190. equal = equal & _allclose_recursive(fullstate(**input_args), partstate(**input_args))
  191. except RuntimeError:
  192. equal = False
  193. res1 = fullstate.compute()
  194. try: # if it fails, the code most likely need access to the full state
  195. res2 = partstate.compute()
  196. except RuntimeError:
  197. equal = False
  198. equal = equal & _allclose_recursive(res1, res2)
  199. if not equal: # we can stop early because the results did not match
  200. print("Recommended setting `full_state_update=True`")
  201. return
  202. # Do timings
  203. res = torch.zeros(2, len(num_update_to_compare), reps)
  204. for i, metric in enumerate([fullstate, partstate]):
  205. for j, t in enumerate(num_update_to_compare):
  206. for r in range(reps):
  207. start = perf_counter()
  208. for _ in range(t):
  209. _ = metric(**input_args)
  210. end = perf_counter()
  211. res[i, j, r] = end - start
  212. metric.reset()
  213. mean = torch.mean(res, -1)
  214. std = torch.std(res, -1)
  215. for t in range(len(num_update_to_compare)):
  216. print(f"Full state for {num_update_to_compare[t]} steps took: {mean[0, t]}+-{std[0, t]:0.3f}")
  217. print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[1, t]:0.3f}+-{std[1, t]:0.3f}")
  218. faster = (mean[1, -1] < mean[0, -1]).item() # if faster on average, we recommend upgrading
  219. print(f"Recommended setting `full_state_update={not faster}`")
  220. return
  221. def is_overridden(method_name: str, instance: object, parent: object) -> bool:
  222. """Check if a method has been overridden by an instance compared to its parent class."""
  223. instance_attr = getattr(instance, method_name, None)
  224. if instance_attr is None:
  225. return False
  226. # `functools.wraps()` support
  227. if hasattr(instance_attr, "__wrapped__"):
  228. instance_attr = instance_attr.__wrapped__
  229. # `Mock(wraps=...)` support
  230. if isinstance(instance_attr, Mock):
  231. # access the wrapped function
  232. instance_attr = instance_attr._mock_wraps
  233. # `partial` support
  234. elif isinstance(instance_attr, partial):
  235. instance_attr = instance_attr.func
  236. if instance_attr is None:
  237. return False
  238. parent_attr = getattr(parent, method_name, None)
  239. if parent_attr is None:
  240. raise ValueError("The parent should define the method")
  241. return instance_attr.__code__ != parent_attr.__code__
  242. def _try_proceed_with_timeout(fn: Callable, timeout: int = _DOCTEST_DOWNLOAD_TIMEOUT) -> bool:
  243. """Check if a certain function is taking too long to execute.
  244. Function will only be executed if running inside a doctest context. Currently, does not support Windows.
  245. Args:
  246. fn: function to check
  247. timeout: timeout for function
  248. Returns:
  249. Bool indicating if the function finished within the specified timeout
  250. """
  251. # source: https://stackoverflow.com/a/14924210/4521646
  252. if multiprocessing.current_process().daemon:
  253. # Skip timeout check in daemon processes as they cannot spawn child processes.
  254. return True
  255. proc = multiprocessing.Process(target=fn)
  256. print(f"Trying to run `{fn.__name__}` for {timeout}s...", file=sys.stderr)
  257. proc.start()
  258. # Wait for N seconds or until process finishes
  259. proc.join(timeout)
  260. # If thread is still active
  261. if not proc.is_alive():
  262. return True
  263. print(f"`{fn.__name__}` did not complete with {timeout}, killing process and returning False", file=sys.stderr)
  264. # Terminate - may not work if process is stuck for good
  265. # proc.terminate()
  266. # proc.join()
  267. # OR Kill - will work for sure, no chance for process to finish nicely however
  268. proc.kill()
  269. return False