constants.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from enum import Enum, EnumMeta
  18. from typing import Iterator, Type, TypeVar, Union
  19. import torch
  20. from kornia.core import Tensor
  21. __all__ = ["BorderType", "DType", "Resample", "SamplePadding", "TKEnum", "pi"]
  22. pi = torch.tensor(3.14159265358979323846)
  23. T = TypeVar("T", bound=Enum)
  24. TKEnum = Union[str, int, T]
  25. class _KORNIA_EnumMeta(EnumMeta):
  26. def __iter__(self) -> Iterator[Enum]: # type: ignore[override]
  27. return super().__iter__()
  28. def __contains__(self, other: TKEnum[Enum]) -> bool: # type: ignore[override]
  29. if isinstance(other, str):
  30. return any(val.name.upper() == other.upper() for val in self)
  31. elif isinstance(other, int):
  32. return any(val.value == other for val in self)
  33. return any(val == other for val in self)
  34. def __repr__(self) -> str:
  35. return " | ".join(f"{self.__name__}.{val.name}" for val in self)
  36. def _get(cls: Type[T], value: TKEnum[T]) -> T:
  37. if isinstance(value, str):
  38. return cls[value.upper()]
  39. elif isinstance(value, int):
  40. return cls(value)
  41. elif isinstance(value, cls):
  42. return value
  43. raise TypeError(
  44. f"The `.get` method from `{cls}` expects a value with type `str`, `int` or `{cls}`. Gotcha {type(value)}"
  45. )
  46. class Resample(Enum, metaclass=_KORNIA_EnumMeta):
  47. NEAREST = 0
  48. BILINEAR = 1
  49. BICUBIC = 2
  50. @classmethod
  51. def get(cls, value: TKEnum["Resample"]) -> "Resample":
  52. return _get(cls, value)
  53. class BorderType(Enum, metaclass=_KORNIA_EnumMeta):
  54. CONSTANT = 0
  55. REFLECT = 1
  56. REPLICATE = 2
  57. CIRCULAR = 3
  58. @classmethod
  59. def get(cls, value: TKEnum["BorderType"]) -> "BorderType":
  60. return _get(cls, value)
  61. class SamplePadding(Enum, metaclass=_KORNIA_EnumMeta):
  62. ZEROS = 0
  63. BORDER = 1
  64. REFLECTION = 2
  65. @classmethod
  66. def get(cls, value: TKEnum["SamplePadding"]) -> "SamplePadding":
  67. return _get(cls, value)
  68. class DType(Enum, metaclass=_KORNIA_EnumMeta):
  69. INT64 = 0
  70. FLOAT16 = 1
  71. FLOAT32 = 2
  72. FLOAT64 = 3
  73. @classmethod
  74. def get(cls, value: Union[str, int, torch.dtype, Tensor, "DType"]) -> "DType":
  75. if isinstance(value, torch.dtype):
  76. return cls[str(value).upper()[6:]]
  77. elif isinstance(value, Tensor):
  78. return cls(int(value.item()))
  79. elif isinstance(value, str):
  80. return cls[value.upper()]
  81. elif isinstance(value, int):
  82. return cls(value)
  83. elif isinstance(value, cls):
  84. return value
  85. raise TypeError(f"Invalid identifier {value} with type {type(value)}.")
  86. @classmethod
  87. def to_torch(cls, value: TKEnum["DType"]) -> torch.dtype:
  88. data = cls.get(value=value)
  89. if data == DType.INT64:
  90. return torch.long
  91. elif data == DType.FLOAT16:
  92. return torch.float16
  93. elif data == DType.FLOAT32:
  94. return torch.float32
  95. elif data == DType.FLOAT64:
  96. return torch.float64
  97. raise ValueError
  98. # TODO: (low-priority) add INPUT3D, MASK3D, BBOX3D, LAFs etc.
  99. class DataKey(Enum, metaclass=_KORNIA_EnumMeta):
  100. IMAGE = 0
  101. INPUT = 0
  102. MASK = 1
  103. BBOX = 2
  104. BBOX_XYXY = 3
  105. BBOX_XYWH = 4
  106. KEYPOINTS = 5
  107. LABEL = 6
  108. CLASS = 6
  109. @classmethod
  110. def get(cls, value: TKEnum["DataKey"]) -> "DataKey":
  111. return _get(cls, value)