_compat.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import TYPE_CHECKING, Any, Callable, ContextManager, List, Optional, Tuple, TypeVar
  18. import torch
  19. from packaging import version
  20. from torch import Tensor
  21. def torch_version() -> str:
  22. """Parse the `torch.__version__` variable and removes +cu*/cpu."""
  23. return torch.__version__.partition("+")[0]
  24. def torch_version_lt(major: int, minor: int, patch: int) -> bool:
  25. _version = version.parse(torch_version())
  26. return _version < version.parse(f"{major}.{minor}.{patch}")
  27. def torch_version_le(major: int, minor: int, patch: int) -> bool:
  28. _version = version.parse(torch_version())
  29. return _version <= version.parse(f"{major}.{minor}.{patch}")
  30. def torch_version_ge(major: int, minor: int, patch: Optional[int] = None) -> bool:
  31. _version = version.parse(torch_version())
  32. if patch is None:
  33. return _version >= version.parse(f"{major}.{minor}")
  34. else:
  35. return _version >= version.parse(f"{major}.{minor}.{patch}")
  36. if TYPE_CHECKING:
  37. # TODO: remove this branch when kornia relies on torch >= 1.10.0
  38. def torch_meshgrid(tensors: List[Tensor], indexing: Optional[str] = None) -> Tuple[Tensor, ...]: ...
  39. elif torch_version_ge(1, 10, 0):
  40. def torch_meshgrid(tensors: List[Tensor], indexing: str):
  41. return torch.meshgrid(tensors, indexing=indexing)
  42. else:
  43. # TODO: remove this branch when kornia relies on torch >= 1.10.0
  44. def torch_meshgrid(tensors: List[Tensor], indexing: str):
  45. return torch.meshgrid(tensors)
  46. if TYPE_CHECKING:
  47. # TODO: remove this branch when kornia relies on torch >= 1.10.0
  48. _T = TypeVar("_T")
  49. torch_inference_mode: Callable[..., ContextManager[_T]]
  50. elif torch_version_ge(1, 10, 0):
  51. torch_inference_mode = torch.inference_mode
  52. else:
  53. # TODO: remove this branch when kornia relies on torch >= 1.10.0
  54. torch_inference_mode = torch.no_grad
  55. if TYPE_CHECKING: # TODO (@johnnv1): remove this branch when bump the pytorch CI to support torch 2.4
  56. custom_fwd: Callable[..., Any]
  57. autocast: Callable[..., Any]
  58. elif torch_version_ge(2, 4):
  59. from functools import partial
  60. from torch.amp import autocast as _autocast
  61. from torch.amp import custom_fwd as _custom_fwd
  62. custom_fwd = partial(_custom_fwd, device_type="cuda")
  63. autocast = partial(_autocast, "cuda")
  64. else:
  65. custom_fwd = torch.cuda.amp.custom_fwd
  66. autocast = torch.cuda.amp.autocast