__init__.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """The testing package contains testing-specific utilities."""
  18. import importlib.util
  19. import math
  20. import warnings
  21. from abc import ABC, abstractmethod
  22. from copy import deepcopy
  23. from itertools import product
  24. from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, TypeVar, Union
  25. import torch
  26. from torch.autograd import gradcheck
  27. from torch.testing import assert_close as _assert_close
  28. from kornia.core import Device, Dtype, Tensor, eye, tensor
  29. from kornia.utils.helpers import deprecated
  30. warnings.simplefilter("always", DeprecationWarning)
  31. warnings.warn(
  32. (
  33. "Since kornia 0.7.2 the `kornia.testing` module is deprecated and will be removed in kornia 0.8.0 (dec 2024)."
  34. " Most of these functionalities will be removed from kornia package and will be part of the tests of kornia."
  35. " Some functionalities which we think is important to keep will be moved to other kornia module."
  36. ),
  37. category=DeprecationWarning,
  38. stacklevel=2,
  39. )
  40. warnings.simplefilter("default", DeprecationWarning)
  41. __all__ = ["assert_close", "create_eye_batch", "tensor_to_gradcheck_var", "xla_is_available"]
  42. @deprecated(
  43. "kornia.utils.xla_is_available",
  44. "0.7.2",
  45. extra_reason="The `kornia.testing` module is deprecated and will be removed in kornia 0.8.0 (dec 2024).",
  46. )
  47. def xla_is_available() -> bool:
  48. """Return whether `torch_xla` is available in the system."""
  49. if importlib.util.find_spec("torch_xla") is not None:
  50. return True
  51. return False
  52. @deprecated(
  53. "kornia.utils.is_mps_tensor_safe",
  54. "0.7.2",
  55. extra_reason="The `kornia.testing` module is deprecated and will be removed in kornia 0.8.0 (dec 2024).",
  56. )
  57. def is_mps_tensor_safe(x: Tensor) -> bool:
  58. """Return whether tensor is on MPS device."""
  59. return "mps" in str(x.device)
  60. @deprecated(
  61. "kornia.utils.misc.eye_like",
  62. "0.7.2",
  63. extra_reason="The `kornia.testing` module is deprecated and will be removed in kornia 0.8.0 (dec 2024).",
  64. )
  65. def create_eye_batch(batch_size: int, eye_size: int, device: Device = None, dtype: Dtype = None) -> Tensor:
  66. """Create a batch of identity matrices of shape Bx3x3."""
  67. return eye(eye_size, device=device, dtype=dtype).view(1, eye_size, eye_size).expand(batch_size, -1, -1)
  68. def create_random_homography(batch_size: int, eye_size: int, std_val: float = 1e-3) -> Tensor:
  69. """Create a batch of random homographies of shape Bx3x3."""
  70. std = torch.FloatTensor(batch_size, eye_size, eye_size)
  71. eye = create_eye_batch(batch_size, eye_size)
  72. return eye + std.uniform_(-std_val, std_val)
  73. def tensor_to_gradcheck_var(
  74. tensor: Tensor, dtype: Dtype = torch.float64, requires_grad: bool = True
  75. ) -> Union[Tensor, str]:
  76. """Convert the input tensor to a valid variable to check the gradient.
  77. `gradcheck` needs 64-bit floating point and requires gradient.
  78. """
  79. if not torch.is_tensor(tensor):
  80. raise AssertionError(type(tensor))
  81. return tensor.requires_grad_(requires_grad).type(dtype)
  82. T = TypeVar("T")
  83. def dict_to(data: Dict[T, Any], device: Device, dtype: Dtype) -> Dict[T, Any]:
  84. out: Dict[T, Any] = {}
  85. for key, val in data.items():
  86. out[key] = val.to(device, dtype) if isinstance(val, Tensor) else val
  87. return out
  88. def compute_patch_error(x: Tensor, y: Tensor, h: int, w: int) -> Tensor:
  89. """Compute the absolute error between patches."""
  90. return torch.abs(x - y)[..., h // 4 : -h // 4, w // 4 : -w // 4].mean()
  91. def create_rectified_fundamental_matrix(batch_size: int) -> Tensor:
  92. """Create a batch of rectified fundamental matrices of shape Bx3x3."""
  93. F_rect = tensor([[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]]).view(1, 3, 3)
  94. F_repeat = F_rect.expand(batch_size, 3, 3)
  95. return F_repeat
  96. def create_random_fundamental_matrix(batch_size: int, std_val: float = 1e-3) -> Tensor:
  97. """Create a batch of random fundamental matrices of shape Bx3x3."""
  98. F_rect = create_rectified_fundamental_matrix(batch_size)
  99. H_left = create_random_homography(batch_size, 3, std_val)
  100. H_right = create_random_homography(batch_size, 3, std_val)
  101. return H_left.permute(0, 2, 1) @ F_rect @ H_right
  102. # {dtype: (rtol, atol)}
  103. _DTYPE_PRECISIONS = {
  104. torch.bfloat16: (7.8e-3, 7.8e-3),
  105. torch.float16: (9.7e-4, 9.7e-4),
  106. torch.float32: (1e-4, 1e-5), # TODO: Update to ~1.2e-7
  107. # TODO: Update to ~2.3e-16 for fp64
  108. torch.float64: (1e-5, 1e-5), # TODO: BaseTester used (1.3e-6, 1e-5), but it fails for general cases
  109. }
  110. class BaseTester(ABC):
  111. @abstractmethod
  112. def test_smoke(self, device: Device, dtype: Dtype) -> None:
  113. raise NotImplementedError("Implement a stupid routine.")
  114. @abstractmethod
  115. def test_exception(self, device: Device, dtype: Dtype) -> None:
  116. raise NotImplementedError("Implement a stupid routine.")
  117. @abstractmethod
  118. def test_cardinality(self, device: Device, dtype: Dtype) -> None:
  119. raise NotImplementedError("Implement a stupid routine.")
  120. @abstractmethod
  121. def test_dynamo(self, device: Device, dtype: Dtype, torch_optimizer: Callable[..., Any]) -> None:
  122. raise NotImplementedError("Implement a stupid routine.")
  123. @abstractmethod
  124. def test_gradcheck(self, device: Device) -> None:
  125. raise NotImplementedError("Implement a stupid routine.")
  126. @abstractmethod
  127. def test_module(self, device: Device, dtype: Dtype) -> None:
  128. pass
  129. @staticmethod
  130. def assert_close(
  131. actual: Tensor,
  132. expected: Tensor,
  133. rtol: Optional[float] = None,
  134. atol: Optional[float] = None,
  135. low_tolerance: bool = False,
  136. ) -> None:
  137. """Assert that `actual` and `expected` are close.
  138. Args:
  139. actual: Actual input.
  140. expected: Expected input.
  141. rtol: Relative tolerance.
  142. atol: Absolute tolerance.
  143. low_tolerance:
  144. This parameter allows to reduce tolerance. Half the decimal places.
  145. Example, 1e-4 -> 1e-2 or 1e-6 -> 1e-3
  146. """
  147. if hasattr(actual, "data"):
  148. actual = actual.data
  149. if hasattr(expected, "data"):
  150. expected = expected.data
  151. if "xla" in actual.device.type or "xla" in expected.device.type:
  152. rtol, atol = 1e-2, 1e-2
  153. if rtol is None and atol is None:
  154. actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
  155. expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
  156. rtol, atol = max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)
  157. # halve the tolerance if `low_tolerance` is true
  158. rtol = math.sqrt(rtol) if low_tolerance else rtol
  159. atol = math.sqrt(atol) if low_tolerance else atol
  160. return assert_close(actual, expected, rtol=rtol, atol=atol)
  161. @staticmethod
  162. def gradcheck(
  163. func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
  164. inputs: Union[torch.Tensor, Sequence[torch.Tensor]],
  165. *,
  166. raise_exception: bool = True,
  167. fast_mode: bool = True,
  168. **kwargs: Any,
  169. ) -> bool:
  170. return gradcheck(func, inputs, raise_exception=raise_exception, fast_mode=fast_mode, **kwargs)
  171. def generate_two_view_random_scene(device: Optional[Device] = None, dtype: Dtype = torch.float32) -> Dict[str, Tensor]:
  172. from kornia.geometry import epipolar as epi
  173. if device is None:
  174. device = torch.device("cpu")
  175. num_views: int = 2
  176. num_points: int = 30
  177. scene: Dict[str, Tensor] = epi.generate_scene(num_views, num_points)
  178. # internal parameters (same K)
  179. K1 = scene["K"].to(device, dtype)
  180. K2 = K1.clone()
  181. # rotation
  182. R1 = scene["R"][0:1].to(device, dtype)
  183. R2 = scene["R"][1:2].to(device, dtype)
  184. # translation
  185. t1 = scene["t"][0:1].to(device, dtype)
  186. t2 = scene["t"][1:2].to(device, dtype)
  187. # projection matrix, P = K(R|t)
  188. P1 = scene["P"][0:1].to(device, dtype)
  189. P2 = scene["P"][1:2].to(device, dtype)
  190. # fundamental matrix
  191. F_mat = epi.fundamental_from_projections(P1[..., :3, :], P2[..., :3, :])
  192. F_mat = epi.normalize_transformation(F_mat)
  193. # points 3d
  194. X = scene["points3d"].to(device, dtype)
  195. # projected points
  196. x1 = scene["points2d"][0:1].to(device, dtype)
  197. x2 = scene["points2d"][1:2].to(device, dtype)
  198. return {
  199. "K1": K1,
  200. "K2": K2,
  201. "R1": R1,
  202. "R2": R2,
  203. "t1": t1,
  204. "t2": t2,
  205. "P1": P1,
  206. "P2": P2,
  207. "F": F_mat,
  208. "X": X,
  209. "x1": x1,
  210. "x2": x2,
  211. }
  212. def cartesian_product_of_parameters(**possible_parameters: Sequence[Any]) -> Iterator[Dict[str, Any]]:
  213. """Create cartesian product of given parameters."""
  214. parameter_names = possible_parameters.keys()
  215. possible_values = [possible_parameters[parameter_name] for parameter_name in parameter_names]
  216. for param_combination in product(*possible_values):
  217. yield dict(zip(parameter_names, param_combination))
  218. def default_with_one_parameter_changed(*, default: Optional[Dict[str, Any]] = None, **possible_parameters: Any) -> Any:
  219. if default is None:
  220. default = {}
  221. if not isinstance(default, dict):
  222. raise AssertionError(f"default should be a dict not a {type(default)}")
  223. for parameter_name, possible_values in possible_parameters.items():
  224. for v in possible_values:
  225. param_set = deepcopy(default)
  226. param_set[parameter_name] = v
  227. yield param_set
  228. def _get_precision(device: torch.device, dtype: Dtype) -> float:
  229. if "xla" in device.type:
  230. return 1e-2
  231. if dtype == torch.float16:
  232. return 1e-3
  233. return 1e-4
  234. def _get_precision_by_name(
  235. device: torch.device, device_target: str, tol_val: float, tol_val_default: float = 1e-4
  236. ) -> float:
  237. if device_target not in ["cpu", "cuda", "xla", "mps"]:
  238. raise ValueError(f"Invalid device name: {device_target}.")
  239. if device_target in device.type:
  240. return tol_val
  241. return tol_val_default
  242. def _default_tolerances(*inputs: Any) -> Tuple[float, float]:
  243. rtols, atols = zip(*[_DTYPE_PRECISIONS.get(torch.as_tensor(input_).dtype, (0.0, 0.0)) for input_ in inputs])
  244. return max(rtols), max(atols)
  245. def assert_close(
  246. actual: Tensor, expected: Tensor, *, rtol: Optional[float] = None, atol: Optional[float] = None, **kwargs: Any
  247. ) -> None:
  248. """Assert two tensors are similar within provided tolerance."""
  249. if rtol is None and atol is None:
  250. # `torch.testing.assert_close` used different default tolerances than `torch.testing.assert_allclose`.
  251. # TODO: remove this special handling as soon as https://github.com/kornia/kornia/issues/1134 is resolved
  252. # Basically, this whole wrapper function can be removed and `torch.testing.assert_close` can be used
  253. # directly.
  254. rtol, atol = _default_tolerances(actual, expected)
  255. return _assert_close(
  256. actual,
  257. expected,
  258. rtol=rtol,
  259. atol=atol,
  260. # this is the default value for torch>=1.10, but not for torch==1.9
  261. # TODO: remove this if kornia relies on torch>=1.10
  262. check_stride=False,
  263. equal_nan=False,
  264. **kwargs,
  265. )