# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from enum import Enum, EnumMeta
from typing import Iterator, Type, TypeVar, Union

import torch

from kornia.core import Tensor

__all__ = ["BorderType", "DType", "Resample", "SamplePadding", "TKEnum", "pi"]

pi = torch.tensor(3.14159265358979323846)

T = TypeVar("T", bound=Enum)
# What the `.get` classmethods below accept: a member name, a member value,
# or an already-resolved member.
TKEnum = Union[str, int, T]


class _KORNIA_EnumMeta(EnumMeta):
    """Enum metaclass adding lenient membership tests and a compact ``repr``."""

    def __iter__(self) -> Iterator[Enum]:  # type: ignore[override]
        """Iterate over the canonical members (aliases are skipped by ``EnumMeta``)."""
        return super().__iter__()

    def __contains__(self, other: TKEnum[Enum]) -> bool:  # type: ignore[override]
        """Return whether ``other`` designates a member of this enum.

        Args:
            other: a member name (compared case-insensitively), a member value,
                or a member instance.

        Returns:
            ``True`` if ``other`` matches a member name (aliases included),
            a member value, or compares equal to a member.
        """
        if isinstance(other, str):
            # Iterating the class skips aliases (e.g. ``DataKey.INPUT``), while
            # name lookup via ``cls[name]`` resolves them. Checking
            # ``__members__`` keeps ``in`` consistent with what ``get`` accepts.
            wanted = other.upper()
            return any(name.upper() == wanted for name in self.__members__)
        if isinstance(other, int):
            # Aliases share the value of their canonical member, so iterating
            # the canonical members is sufficient here.
            return any(member.value == other for member in self)
        return any(member == other for member in self)

    def __repr__(self) -> str:
        """Return the canonical members joined as ``Name.A | Name.B | ...``."""
        return " | ".join(f"{self.__name__}.{val.name}" for val in self)


def _get(cls: Type[T], value: TKEnum[T]) -> T:
    """Resolve ``value`` to a member of the enum ``cls``.

    Args:
        cls: the enum class to resolve against.
        value: a member name (upper-cased before lookup), a member value,
            or a member of ``cls`` (returned unchanged).

    Returns:
        The matching member of ``cls``.

    Raises:
        KeyError: if ``value`` is a string that names no member.
        ValueError: if ``value`` is an integer that is no member's value.
        TypeError: if ``value`` is not a ``str``, an ``int`` or a ``cls`` member.
    """
    if isinstance(value, str):
        return cls[value.upper()]
    if isinstance(value, int):
        return cls(value)
    if isinstance(value, cls):
        return value
    raise TypeError(
        f"The `.get` method from `{cls}` expects a value with type `str`, `int` or `{cls}`. Gotcha {type(value)}"
    )


class Resample(Enum, metaclass=_KORNIA_EnumMeta):
    """Resampling mode identifiers, addressable by name, value or member."""

    NEAREST = 0
    BILINEAR = 1
    BICUBIC = 2

    @classmethod
    def get(cls, value: TKEnum["Resample"]) -> "Resample":
        """Resolve a name, value or member to a ``Resample`` member (see ``_get``)."""
        return _get(cls, value)


class BorderType(Enum, metaclass=_KORNIA_EnumMeta):
    """Border mode identifiers, addressable by name, value or member."""

    CONSTANT = 0
    REFLECT = 1
    REPLICATE = 2
    CIRCULAR = 3

    @classmethod
    def get(cls, value: TKEnum["BorderType"]) -> "BorderType":
        """Resolve a name, value or member to a ``BorderType`` member (see ``_get``)."""
        return _get(cls, value)


class SamplePadding(Enum, metaclass=_KORNIA_EnumMeta):
    """Sample padding mode identifiers, addressable by name, value or member."""

    ZEROS = 0
    BORDER = 1
    REFLECTION = 2

    @classmethod
    def get(cls, value: TKEnum["SamplePadding"]) -> "SamplePadding":
        """Resolve a name, value or member to a ``SamplePadding`` member (see ``_get``)."""
        return _get(cls, value)


class DType(Enum, metaclass=_KORNIA_EnumMeta):
    """Data type identifiers that can be converted to and from ``torch.dtype``."""

    INT64 = 0
    FLOAT16 = 1
    FLOAT32 = 2
    FLOAT64 = 3

    @classmethod
    def get(cls, value: Union[str, int, torch.dtype, Tensor, "DType"]) -> "DType":
        """Resolve ``value`` to a ``DType`` member.

        Args:
            value: a ``torch.dtype``, a tensor holding a single member value,
                a member name, a member value, or a ``DType`` member.

        Returns:
            The matching ``DType`` member.

        Raises:
            KeyError: if a name or ``torch.dtype`` has no corresponding member.
            ValueError: if an integer (or tensor content) is no member's value.
            TypeError: if ``value`` has an unsupported type.
        """
        if isinstance(value, torch.dtype):
            # ``str(torch.float32)`` is "torch.float32": drop the "torch."
            # prefix and look the remainder up as a member name.
            return cls[str(value)[len("torch.") :].upper()]
        if isinstance(value, Tensor):
            return cls(int(value.item()))
        if isinstance(value, (str, int, cls)):
            # Same name/value/member resolution as the other enums in this module.
            return _get(cls, value)
        raise TypeError(f"Invalid identifier {value} with type {type(value)}.")

    @classmethod
    def to_torch(cls, value: Union[str, int, torch.dtype, Tensor, "DType"]) -> torch.dtype:
        """Return the ``torch.dtype`` corresponding to ``value``.

        Args:
            value: anything accepted by :meth:`get`.

        Returns:
            The ``torch.dtype`` matching the resolved member.

        Raises:
            ValueError: if the resolved member has no associated ``torch.dtype``.
        """
        data = cls.get(value=value)
        torch_dtypes = {
            DType.INT64: torch.long,
            DType.FLOAT16: torch.float16,
            DType.FLOAT32: torch.float32,
            DType.FLOAT64: torch.float64,
        }
        try:
            return torch_dtypes[data]
        except KeyError:
            raise ValueError(f"No torch dtype is associated with {data}.") from None


# TODO: (low-priority) add INPUT3D, MASK3D, BBOX3D, LAFs etc.
class DataKey(Enum, metaclass=_KORNIA_EnumMeta):
    """Data kind identifiers, addressable by name, value or member.

    ``INPUT`` and ``CLASS`` reuse the values of ``IMAGE`` and ``LABEL``, which
    makes them aliases of those members.
    """

    IMAGE = 0
    INPUT = 0  # alias of IMAGE (same value)
    MASK = 1
    BBOX = 2
    BBOX_XYXY = 3
    BBOX_XYWH = 4
    KEYPOINTS = 5
    LABEL = 6
    CLASS = 6  # alias of LABEL (same value)

    @classmethod
    def get(cls, value: TKEnum["DataKey"]) -> "DataKey":
        """Resolve a name, value or member to a ``DataKey`` member (see ``_get``)."""
        return _get(cls, value)