| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from enum import Enum, EnumMeta
- from typing import Iterator, Type, TypeVar, Union
- import torch
- from kornia.core import Tensor
# Names exported by ``from <this module> import *``.
__all__ = [
    "BorderType",
    "DType",
    "Resample",
    "SamplePadding",
    "TKEnum",
    "pi",
]
# Pi stored as a 0-dim tensor (built from a Python float literal, so it takes
# torch's default floating dtype).
pi = torch.tensor(3.14159265358979323846)

# Type variable for any enum class, and the "enum-like" identifier accepted by
# the ``get`` helpers: a member name (str), a member value (int) or a member.
T = TypeVar("T", bound=Enum)
TKEnum = Union[str, int, T]
class _KORNIA_EnumMeta(EnumMeta):
    """Enum metaclass with lenient membership tests and a compact class ``repr``.

    ``x in SomeEnum`` accepts a member name (compared case-insensitively),
    an integer member value, or a member itself.
    """

    def __iter__(self) -> Iterator[Enum]:  # type: ignore[override]
        # Plain EnumMeta iteration: canonical members only, aliases are skipped.
        return super().__iter__()

    def __contains__(self, other: "TKEnum[Enum]") -> bool:  # type: ignore[override]
        """Return whether ``other`` identifies a member of this enum.

        Args:
            other: a member name (case-insensitive), an integer member value,
                or any other object compared for equality against the members.

        Returns:
            ``True`` if some member matches ``other``.
        """
        if isinstance(other, str):
            # Search ``__members__`` rather than iterating the class: iteration
            # skips aliases (a second name bound to an existing value), so alias
            # names such as ``INPUT`` would otherwise be reported as missing even
            # though ``cls["INPUT"]`` resolves them.
            key = other.upper()
            return any(name.upper() == key for name in self.__members__)
        if isinstance(other, int):
            return any(member.value == other for member in self)
        return any(member == other for member in self)

    def __repr__(self) -> str:
        # e.g. "Resample.NEAREST | Resample.BILINEAR | Resample.BICUBIC"
        return " | ".join(f"{self.__name__}.{member.name}" for member in self)
def _get(cls: Type[T], value: TKEnum[T]) -> T:
    """Resolve ``value`` to a member of the enum class ``cls``.

    Args:
        cls: the enum class to look ``value`` up in.
        value: a member of ``cls`` (returned unchanged), a member name
            (upper-cased before the lookup) or an integer member value.

    Returns:
        The matching member of ``cls``.

    Raises:
        KeyError: if ``value`` is a string that names no member.
        ValueError: if ``value`` is an integer that matches no member value.
        TypeError: if ``value`` has any other type.
    """
    # Check for a member first: in enums that mix in ``str`` or ``int``, members
    # also pass the ``str``/``int`` checks below, and a ``str`` member would then
    # be looked up by its value as if it were a member name.
    if isinstance(value, cls):
        return value
    if isinstance(value, str):
        return cls[value.upper()]
    if isinstance(value, int):
        return cls(value)
    raise TypeError(
        f"The `.get` method from `{cls}` expects a value with type `str`, `int` or `{cls}`. Got {type(value)}"
    )
class Resample(Enum, metaclass=_KORNIA_EnumMeta):
    """Enumeration of resampling modes.

    A member can be resolved from its name, its integer value or the member
    itself via :meth:`get`.
    """

    NEAREST = 0
    BILINEAR = 1
    BICUBIC = 2

    @classmethod
    def get(cls, value: TKEnum["Resample"]) -> "Resample":
        """Return the ``Resample`` member identified by ``value``.

        Args:
            value: a member name (upper-cased before the lookup), an integer
                member value, or a ``Resample`` member.

        Returns:
            The matching ``Resample`` member.
        """
        resolved: "Resample" = _get(cls, value)
        return resolved
class BorderType(Enum, metaclass=_KORNIA_EnumMeta):
    """Enumeration of border handling modes.

    A member can be resolved from its name, its integer value or the member
    itself via :meth:`get`.
    """

    CONSTANT = 0
    REFLECT = 1
    REPLICATE = 2
    CIRCULAR = 3

    @classmethod
    def get(cls, value: TKEnum["BorderType"]) -> "BorderType":
        """Return the ``BorderType`` member identified by ``value``.

        Args:
            value: a member name (upper-cased before the lookup), an integer
                member value, or a ``BorderType`` member.

        Returns:
            The matching ``BorderType`` member.
        """
        resolved: "BorderType" = _get(cls, value)
        return resolved
class SamplePadding(Enum, metaclass=_KORNIA_EnumMeta):
    """Enumeration of padding modes used when sampling.

    A member can be resolved from its name, its integer value or the member
    itself via :meth:`get`.
    """

    ZEROS = 0
    BORDER = 1
    REFLECTION = 2

    @classmethod
    def get(cls, value: TKEnum["SamplePadding"]) -> "SamplePadding":
        """Return the ``SamplePadding`` member identified by ``value``.

        Args:
            value: a member name (upper-cased before the lookup), an integer
                member value, or a ``SamplePadding`` member.

        Returns:
            The matching ``SamplePadding`` member.
        """
        resolved: "SamplePadding" = _get(cls, value)
        return resolved
class DType(Enum, metaclass=_KORNIA_EnumMeta):
    """Enumeration of the supported data types.

    Members can be resolved from several identifiers with :meth:`get` and
    converted to the matching ``torch.dtype`` with :meth:`to_torch`.
    """

    INT64 = 0
    FLOAT16 = 1
    FLOAT32 = 2
    FLOAT64 = 3

    @classmethod
    def get(cls, value: Union[str, int, torch.dtype, Tensor, "DType"]) -> "DType":
        """Resolve ``value`` to a ``DType`` member.

        Args:
            value: one of
                - a ``torch.dtype``, matched by its name (e.g. ``torch.long``,
                  whose name is ``int64``, gives ``INT64``);
                - a single-element tensor holding a member's integer value;
                - a member name (upper-cased before the lookup);
                - an integer member value;
                - a ``DType`` member, returned unchanged.

        Returns:
            The matching ``DType`` member.

        Raises:
            KeyError: if a name (or ``torch.dtype`` name) matches no member.
            ValueError: if an integer value matches no member.
            TypeError: if ``value`` has an unsupported type.
        """
        if isinstance(value, torch.dtype):
            # ``str(torch.float32)`` is ``"torch.float32"``; dropping the
            # 6-character ``"TORCH."`` prefix leaves the member name.
            return cls[str(value).upper()[6:]]
        if isinstance(value, Tensor):
            # A tensor is read as holding a member's integer value, not a dtype.
            return cls(int(value.item()))
        if isinstance(value, str):
            return cls[value.upper()]
        if isinstance(value, int):
            return cls(value)
        if isinstance(value, cls):
            return value
        raise TypeError(f"Invalid identifier {value} with type {type(value)}.")

    @classmethod
    def to_torch(cls, value: Union[str, int, torch.dtype, Tensor, "DType"]) -> torch.dtype:
        """Return the ``torch.dtype`` corresponding to ``value``.

        Args:
            value: any identifier accepted by :meth:`get`.

        Returns:
            ``torch.int64`` (``torch.long``), ``torch.float16``, ``torch.float32``
            or ``torch.float64``.

        Raises:
            ValueError: if the resolved member has no ``torch.dtype`` counterpart.
            Any exception raised by :meth:`get` for an unrecognised ``value``.
        """
        data = cls.get(value=value)
        # Built per call: a mapping assigned in the class body would itself be
        # turned into an enum member.
        torch_dtypes = {
            cls.INT64: torch.long,
            cls.FLOAT16: torch.float16,
            cls.FLOAT32: torch.float32,
            cls.FLOAT64: torch.float64,
        }
        if data not in torch_dtypes:
            raise ValueError(f"{data} has no corresponding torch.dtype.")
        return torch_dtypes[data]
- # TODO: (low-priority) add INPUT3D, MASK3D, BBOX3D, LAFs etc.
class DataKey(Enum, metaclass=_KORNIA_EnumMeta):
    """Enumeration of keys identifying the kind of each data item.

    ``INPUT`` shares its value with ``IMAGE`` and ``CLASS`` with ``LABEL``, so
    they are enum aliases (e.g. ``DataKey.INPUT is DataKey.IMAGE``) and are not
    produced when iterating the enum.
    """

    IMAGE = 0
    INPUT = 0  # alias of IMAGE
    MASK = 1
    BBOX = 2
    BBOX_XYXY = 3
    BBOX_XYWH = 4
    KEYPOINTS = 5
    LABEL = 6
    CLASS = 6  # alias of LABEL

    @classmethod
    def get(cls, value: TKEnum["DataKey"]) -> "DataKey":
        """Return the ``DataKey`` member identified by ``value``.

        Args:
            value: a member or alias name (upper-cased before the lookup), an
                integer member value, or a ``DataKey`` member.

        Returns:
            The matching ``DataKey`` member (aliases resolve to their canonical
            member).
        """
        resolved: "DataKey" = _get(cls, value)
        return resolved
|