# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import TYPE_CHECKING, Any, Callable, ContextManager, List, Optional, Tuple, TypeVar

import torch
from packaging import version
from torch import Tensor


def torch_version() -> str:
    """Parse the `torch.__version__` variable and removes +cu*/cpu."""
    return torch.__version__.partition("+")[0]


def _installed_version() -> version.Version:
    """Return the installed torch version (local ``+...`` suffix stripped) as a comparable object."""
    return version.parse(torch_version())


def _target_version(major: int, minor: int, patch: Optional[int] = None) -> version.Version:
    """Build the version to compare against, leaving out the patch component when it is ``None``."""
    if patch is None:
        return version.parse(f"{major}.{minor}")
    return version.parse(f"{major}.{minor}.{patch}")


def torch_version_lt(major: int, minor: int, patch: Optional[int] = None) -> bool:
    """Check whether the installed torch version is strictly lower than the given one.

    Args:
        major: major component of the version to compare against.
        minor: minor component of the version to compare against.
        patch: patch component of the version to compare against. When ``None`` the
            comparison is done against ``"{major}.{minor}"``.

    Returns:
        ``True`` if ``torch_version() < major.minor[.patch]``.
    """
    return _installed_version() < _target_version(major, minor, patch)


def torch_version_le(major: int, minor: int, patch: Optional[int] = None) -> bool:
    """Check whether the installed torch version is lower than or equal to the given one.

    Args:
        major: major component of the version to compare against.
        minor: minor component of the version to compare against.
        patch: patch component of the version to compare against. When ``None`` the
            comparison is done against ``"{major}.{minor}"``.

    Returns:
        ``True`` if ``torch_version() <= major.minor[.patch]``.
    """
    return _installed_version() <= _target_version(major, minor, patch)


def torch_version_ge(major: int, minor: int, patch: Optional[int] = None) -> bool:
    """Check whether the installed torch version is greater than or equal to the given one.

    Args:
        major: major component of the version to compare against.
        minor: minor component of the version to compare against.
        patch: patch component of the version to compare against. When ``None`` the
            comparison is done against ``"{major}.{minor}"``.

    Returns:
        ``True`` if ``torch_version() >= major.minor[.patch]``.
    """
    return _installed_version() >= _target_version(major, minor, patch)


# NOTE(review): the runtime definitions of `torch_meshgrid` are left without a return
# annotation, as in the original; presumably this keeps them usable from TorchScript --
# confirm before adding one. Their parameters mirror the type-checking stub so that a call
# accepted by the type checker (e.g. without `indexing`) is also accepted at runtime.
if TYPE_CHECKING:
    # TODO: remove this branch when kornia relies on torch >= 1.10.0
    def torch_meshgrid(tensors: List[Tensor], indexing: Optional[str] = None) -> Tuple[Tensor, ...]: ...

elif torch_version_ge(1, 10, 0):

    def torch_meshgrid(tensors: List[Tensor], indexing: Optional[str] = None):
        """Call `torch.meshgrid`, forwarding `indexing` only when it is given.

        Args:
            tensors: list of tensors handed to `torch.meshgrid`.
            indexing: indexing mode forwarded to `torch.meshgrid`. When ``None`` the
                keyword is not passed and torch applies its own default.

        Returns:
            The result of `torch.meshgrid`.
        """
        if indexing is None:
            return torch.meshgrid(tensors)
        return torch.meshgrid(tensors, indexing=indexing)

else:
    # TODO: remove this branch when kornia relies on torch >= 1.10.0
    def torch_meshgrid(tensors: List[Tensor], indexing: Optional[str] = None):
        """Call `torch.meshgrid` without the `indexing` keyword.

        Args:
            tensors: list of tensors handed to `torch.meshgrid`.
            indexing: accepted for signature compatibility only; it is NOT forwarded
                in this branch, so the requested mode has no effect.

        Returns:
            The result of `torch.meshgrid`.
        """
        return torch.meshgrid(tensors)


# Context manager used to disable gradient tracking: `torch.inference_mode` when the
# installed torch is >= 1.10.0, `torch.no_grad` otherwise.
if TYPE_CHECKING:
    # TODO: remove this branch when kornia relies on torch >= 1.10.0
    _T = TypeVar("_T")
    torch_inference_mode: Callable[..., ContextManager[_T]]
elif torch_version_ge(1, 10, 0):
    torch_inference_mode = torch.inference_mode
else:
    # TODO: remove this branch when kornia relies on torch >= 1.10.0
    torch_inference_mode = torch.no_grad


# CUDA-bound `custom_fwd` / `autocast`: from torch 2.4 on they are taken from `torch.amp`
# with the device type pre-filled to "cuda"; older versions use the `torch.cuda.amp` ones.
if TYPE_CHECKING:
    # TODO (@johnnv1): remove this branch when bump the pytorch CI to support torch 2.4
    custom_fwd: Callable[..., Any]
    autocast: Callable[..., Any]
elif torch_version_ge(2, 4):
    from functools import partial

    from torch.amp import autocast as _autocast
    from torch.amp import custom_fwd as _custom_fwd

    custom_fwd = partial(_custom_fwd, device_type="cuda")
    autocast = partial(_autocast, "cuda")
else:
    custom_fwd = torch.cuda.amp.custom_fwd
    autocast = torch.cuda.amp.autocast