helpers.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import importlib.util
  18. import platform
  19. import sys
  20. import warnings
  21. from dataclasses import asdict, fields, is_dataclass
  22. from functools import wraps
  23. from inspect import isclass, isfunction
  24. from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, overload
  25. import torch
  26. from torch.linalg import inv_ex
  27. from kornia.core import Tensor
  28. from kornia.utils._compat import torch_version_ge
  29. def xla_is_available() -> bool:
  30. """Return whether `torch_xla` is available in the system."""
  31. if importlib.util.find_spec("torch_xla") is not None:
  32. return True
  33. return False
  34. def is_mps_tensor_safe(x: Tensor) -> bool:
  35. """Return whether tensor is on MPS device."""
  36. return "mps" in str(x.device)
  37. def get_cuda_device_if_available(index: int = 0) -> torch.device:
  38. """Try to get cuda device, if fail, return cpu.
  39. Args:
  40. index: cuda device index
  41. Returns:
  42. torch.device
  43. """
  44. if torch.cuda.is_available():
  45. return torch.device(f"cuda:{index}")
  46. return torch.device("cpu")
  47. def get_mps_device_if_available() -> torch.device:
  48. """Try to get mps device, if fail, return cpu.
  49. Returns:
  50. torch.device
  51. """
  52. dev = "cpu"
  53. if hasattr(torch.backends, "mps"):
  54. if torch.backends.mps.is_available():
  55. dev = "mps"
  56. return torch.device(dev)
  57. def get_cuda_or_mps_device_if_available() -> torch.device:
  58. """Check OS and platform and run get_cuda_device_if_available or get_mps_device_if_available.
  59. Returns:
  60. torch.device
  61. """
  62. if sys.platform == "darwin" and platform.machine() == "arm64":
  63. return get_mps_device_if_available()
  64. else:
  65. return get_cuda_device_if_available()
  66. @overload
  67. def map_location_to_cpu(storage: Tensor, location: str) -> Tensor: ...
  68. @overload
  69. def map_location_to_cpu(storage: str) -> str: ...
  70. def map_location_to_cpu(storage: Union[str, Tensor], *args: Any, **kwargs: Any) -> Union[str, Tensor]:
  71. """Map location of device to CPU, util for loading things from HUB."""
  72. return storage
  73. def deprecated(
  74. replace_with: Optional[str] = None, version: Optional[str] = None, extra_reason: Optional[str] = None
  75. ) -> Any:
  76. """Mark methods as deprecated."""
  77. def _deprecated(func: Callable[..., Any]) -> Any:
  78. @wraps(func)
  79. def wrapper(*args: Any, **kwargs: Any) -> Any:
  80. name = ""
  81. beginning = f"Since kornia {version} the " if version is not None else ""
  82. if isclass(func):
  83. name = func.__class__.__name__
  84. if isfunction(func):
  85. name = func.__name__
  86. warnings.simplefilter("always", DeprecationWarning)
  87. if replace_with is not None:
  88. warnings.warn(
  89. f"{beginning}`{name}` is deprecated in favor of `{replace_with}`.{extra_reason}",
  90. category=DeprecationWarning,
  91. stacklevel=2,
  92. )
  93. else:
  94. warnings.warn(
  95. f"{beginning}`{name}` is deprecated and will be removed in the future versions.{extra_reason}",
  96. category=DeprecationWarning,
  97. stacklevel=2,
  98. )
  99. warnings.simplefilter("default", DeprecationWarning)
  100. return func(*args, **kwargs)
  101. return wrapper
  102. return _deprecated
  103. def _extract_device_dtype(tensor_list: List[Optional[Any]]) -> Tuple[torch.device, torch.dtype]:
  104. """Check if all the input are in the same device (only if when they are Tensor).
  105. If so, it would return a tuple of (device, dtype). Default: (cpu, ``get_default_dtype()``).
  106. Returns:
  107. [torch.device, torch.dtype]
  108. """
  109. device, dtype = None, None
  110. for tensor in tensor_list:
  111. if tensor is not None:
  112. if not isinstance(tensor, (Tensor,)):
  113. continue
  114. _device = tensor.device
  115. _dtype = tensor.dtype
  116. if device is None and dtype is None:
  117. device = _device
  118. dtype = _dtype
  119. elif device != _device or dtype != _dtype:
  120. raise ValueError(
  121. "Passed values are not in the same device and dtype."
  122. f"Got ({device}, {dtype}) and ({_device}, {_dtype})."
  123. )
  124. if device is None:
  125. # TODO: update this when having torch.get_default_device()
  126. device = torch.device("cpu")
  127. if dtype is None:
  128. dtype = torch.get_default_dtype()
  129. return (device, dtype)
  130. def _torch_inverse_cast(input: Tensor) -> Tensor:
  131. """Make torch.inverse work with other than fp32/64.
  132. The function torch.inverse is only implemented for fp32/64 which makes impossible to be used by fp16 or others. What
  133. this function does, is cast input data type to fp32, apply torch.inverse, and cast back to the input dtype.
  134. """
  135. if not isinstance(input, Tensor):
  136. raise AssertionError(f"Input must be Tensor. Got: {type(input)}.")
  137. dtype: torch.dtype = input.dtype
  138. if dtype not in (torch.float32, torch.float64):
  139. dtype = torch.float32
  140. return torch.linalg.inv(input.to(dtype)).to(input.dtype)
  141. def _torch_histc_cast(input: Tensor, bins: int, min: Union[float, bool], max: Union[float, bool]) -> Tensor:
  142. """Make torch.histc work with other than fp32/64.
  143. The function torch.histc is only implemented for fp32/64 which makes impossible to be used by fp16 or others. What
  144. this function does, is cast input data type to fp32, apply torch.inverse, and cast back to the input dtype.
  145. """
  146. if not isinstance(input, Tensor):
  147. raise AssertionError(f"Input must be Tensor. Got: {type(input)}.")
  148. dtype: torch.dtype = input.dtype
  149. if dtype not in (torch.float32, torch.float64):
  150. dtype = torch.float32
  151. return torch.histc(input.to(dtype), bins, min, max).to(input.dtype)
  152. def _torch_svd_cast(input: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
  153. """Make torch.svd work with other than fp32/64.
  154. The function torch.svd is only implemented for fp32/64 which makes
  155. impossible to be used by fp16 or others. What this function does, is cast
  156. input data type to fp32, apply torch.svd, and cast back to the input dtype.
  157. NOTE: in torch 1.8.1 this function is recommended to use as torch.linalg.svd
  158. """
  159. # if not isinstance(input, torch.Tensor):
  160. # raise AssertionError(f"Input must be torch.Tensor. Got: {type(input)}.")
  161. dtype = input.dtype
  162. if dtype not in (torch.float32, torch.float64):
  163. dtype = torch.float32
  164. out1, out2, out3H = torch.linalg.svd(input.to(dtype))
  165. if torch_version_ge(1, 11):
  166. out3 = out3H.mH
  167. else:
  168. out3 = out3H.transpose(-1, -2)
  169. return (out1.to(input.dtype), out2.to(input.dtype), out3.to(input.dtype))
  170. def _torch_linalg_svdvals(input: Tensor) -> Tensor:
  171. """Make torch.linalg.svdvals work with other than fp32/64.
  172. The function torch.svd is only implemented for fp32/64 which makes
  173. impossible to be used by fp16 or others. What this function does, is cast
  174. input data type to fp32, apply torch.svd, and cast back to the input dtype.
  175. NOTE: in torch 1.8.1 this function is recommended to use as torch.linalg.svd
  176. """
  177. if not isinstance(input, Tensor):
  178. raise AssertionError(f"Input must be Tensor. Got: {type(input)}.")
  179. dtype: torch.dtype = input.dtype
  180. if dtype not in (torch.float32, torch.float64):
  181. dtype = torch.float32
  182. if TYPE_CHECKING:
  183. # TODO: remove this branch when kornia relies on torch >= 1.10
  184. out: Tensor
  185. elif torch_version_ge(1, 10):
  186. out = torch.linalg.svdvals(input.to(dtype))
  187. else:
  188. # TODO: remove this branch when kornia relies on torch >= 1.10
  189. _, out, _ = torch.linalg.svd(input.to(dtype))
  190. return out.to(input.dtype)
  191. # TODO: return only `Tensor` and review all the calls to adjust
  192. def _torch_solve_cast(A: Tensor, B: Tensor) -> Tensor:
  193. """Make torch.solve work with other than fp32/64.
  194. For stable operation, the input matrices should be cast to fp64, and the output will
  195. be cast back to the input dtype. However, fp64 is not yet supported on MPS.
  196. """
  197. if is_mps_tensor_safe(A):
  198. dtype = torch.float32
  199. else:
  200. dtype = torch.float64
  201. out = torch.linalg.solve(A.to(dtype), B.to(dtype))
  202. # cast back to the input dtype
  203. return out.to(A.dtype)
  204. def safe_solve_with_mask(B: Tensor, A: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
  205. r"""Solves the system of equations.
  206. Avoids crashing because of singular matrix input and outputs the mask of valid solution.
  207. """
  208. if not torch_version_ge(1, 10):
  209. sol = _torch_solve_cast(A, B)
  210. warnings.warn("PyTorch version < 1.10, solve validness mask maybe not correct", RuntimeWarning, stacklevel=1)
  211. return sol, sol, torch.ones(len(A), dtype=torch.bool, device=A.device)
  212. # Based on https://github.com/pytorch/pytorch/issues/31546#issuecomment-694135622
  213. if not isinstance(B, Tensor):
  214. raise AssertionError(f"B must be Tensor. Got: {type(B)}.")
  215. dtype: torch.dtype = B.dtype
  216. if dtype not in (torch.float32, torch.float64):
  217. dtype = torch.float32
  218. if TYPE_CHECKING:
  219. # TODO: remove this branch when kornia relies on torch >= 1.13
  220. A_LU: Tensor
  221. pivots: Tensor
  222. info: Tensor
  223. elif torch_version_ge(1, 13):
  224. A_LU, pivots, info = torch.linalg.lu_factor_ex(A.to(dtype))
  225. else:
  226. # TODO: remove this branch when kornia relies on torch >= 1.13
  227. A_LU, pivots, info = torch.lu(A.to(dtype), True, get_infos=True)
  228. valid_mask: Tensor = info == 0
  229. n_dim_B = len(B.shape)
  230. n_dim_A = len(A.shape)
  231. if n_dim_A - n_dim_B == 1:
  232. B = B.unsqueeze(-1)
  233. if TYPE_CHECKING:
  234. # TODO: remove this branch when kornia relies on torch >= 1.13
  235. X: Tensor
  236. elif torch_version_ge(1, 13):
  237. X = torch.linalg.lu_solve(A_LU, pivots, B.to(dtype))
  238. else:
  239. # TODO: remove this branch when kornia relies on torch >= 1.13
  240. X = torch.lu_solve(B.to(dtype), A_LU, pivots)
  241. return X.to(B.dtype), A_LU.to(A.dtype), valid_mask
  242. def safe_inverse_with_mask(A: Tensor) -> Tuple[Tensor, Tensor]:
  243. r"""Perform inverse.
  244. Avoids crashing because of non-invertable matrix input and outputs the mask of valid solution.
  245. """
  246. if not isinstance(A, Tensor):
  247. raise AssertionError(f"A must be Tensor. Got: {type(A)}.")
  248. dtype_original = A.dtype
  249. if dtype_original not in (torch.float32, torch.float64):
  250. dtype = torch.float32
  251. else:
  252. dtype = dtype_original
  253. inverse, info = inv_ex(A.to(dtype))
  254. mask = info == 0
  255. return inverse.to(dtype_original), mask
  256. def is_autocast_enabled(both: bool = True) -> bool:
  257. """Check if torch autocast is enabled.
  258. Args:
  259. both: if True will consider autocast region for both types of devices
  260. Returns:
  261. Return a Bool,
  262. will always return False for a torch without support, otherwise will be: if both is True
  263. `torch.is_autocast_enabled() or torch.is_autocast_enabled('cpu')`. If both is False will return just
  264. `torch.is_autocast_enabled()`.
  265. """
  266. if TYPE_CHECKING:
  267. # TODO: remove this branch when kornia relies on torch >= 1.10.2
  268. return False
  269. if not torch_version_ge(1, 10, 2):
  270. return False
  271. if both:
  272. if torch_version_ge(2, 4):
  273. return torch.is_autocast_enabled() or torch.is_autocast_enabled("cpu")
  274. else:
  275. return torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled()
  276. return torch.is_autocast_enabled()
  277. def dataclass_to_dict(obj: Any) -> Any:
  278. """Recursively convert dataclass instances to dictionaries."""
  279. if is_dataclass(obj) and not isinstance(obj, type):
  280. return {key: dataclass_to_dict(value) for key, value in asdict(obj).items()}
  281. elif isinstance(obj, (list, tuple)):
  282. return type(obj)(dataclass_to_dict(item) for item in obj)
  283. elif isinstance(obj, dict):
  284. return {key: dataclass_to_dict(value) for key, value in obj.items()}
  285. else:
  286. return obj
  287. T = TypeVar("T")
  288. def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: Type[T]) -> T:
  289. """Recursively convert dictionaries to dataclass instances."""
  290. if not isinstance(dict_obj, dict):
  291. raise TypeError("Input conf must be dict")
  292. if not is_dataclass(dataclass_type):
  293. raise TypeError("dataclass_type must be a dataclass")
  294. field_types: dict[str, Any] = {f.name: f.type for f in fields(dataclass_type)}
  295. constructor_args = {}
  296. for key, value in dict_obj.items():
  297. if key in field_types and is_dataclass(field_types[key]):
  298. constructor_args[key] = dict_to_dataclass(value, field_types[key])
  299. else:
  300. constructor_args[key] = value
  301. # TODO: remove type ignore when https://github.com/python/mypy/issues/14941 be andressed
  302. return dataclass_type(**constructor_args)