# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from dataclasses import dataclass, field
from enum import Enum
  20. __all__ = ["InstallationMode", "kornia_config"]
  21. class InstallationMode(str, Enum):
  22. # Ask the user if to install the dependencies
  23. ASK = "ASK"
  24. # Install the dependencies
  25. AUTO = "AUTO"
  26. # Raise an error if the dependencies are not installed
  27. RAISE = "RAISE"
  28. def __eq__(self, other: object) -> bool:
  29. if isinstance(other, str):
  30. return self.value.lower() == other.lower() # Case-insensitive comparison
  31. return super().__eq__(other)
  32. class LazyLoaderConfig:
  33. _installation_mode: InstallationMode = InstallationMode.ASK
  34. @property
  35. def installation_mode(self) -> InstallationMode:
  36. return self._installation_mode
  37. @installation_mode.setter
  38. def installation_mode(self, value: str) -> None:
  39. # Allow setting via string by converting to the Enum
  40. if isinstance(value, str):
  41. try:
  42. self._installation_mode = InstallationMode(value.upper())
  43. except ValueError:
  44. raise ValueError(
  45. f"{value} is not a valid InstallationMode. Choose from: {list(InstallationMode)}"
  46. ) from None
  47. elif isinstance(value, InstallationMode):
  48. self._installation_mode = value
  49. else:
  50. raise TypeError("installation_mode must be a string or InstallationMode Enum.")
  51. @dataclass
  52. class KorniaConfig:
  53. hub_models_dir: str
  54. hub_onnx_dir: str
  55. output_dir: str = "kornia_outputs"
  56. hub_cache_dir: str = ".kornia_hub"
  57. lazyloader: LazyLoaderConfig = field(default_factory=LazyLoaderConfig)
  58. kornia_config = KorniaConfig(
  59. hub_models_dir=os.path.join(".kornia_hub", "models"), hub_onnx_dir=os.path.join(".kornia_hub", "onnx_models")
  60. )