base.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import os
  19. from abc import ABC, abstractmethod
  20. from typing import Any, Generic, Optional, TypeVar, cast
  21. import torch
  22. from kornia.core import Module
  23. ModelConfig = TypeVar("ModelConfig")
  24. class ModelBase(ABC, Module, Generic[ModelConfig]):
  25. """Abstract model class with some utilities function."""
  26. def load_checkpoint(self, checkpoint: str, device: Optional[torch.device] = None) -> None:
  27. """Load checkpoint from a given url or file.
  28. Args:
  29. checkpoint: The url or filepath for the respective checkpoint
  30. device: The desired device to load the weights and move the model
  31. """
  32. if os.path.isfile(checkpoint):
  33. with open(checkpoint, "rb") as f:
  34. state_dict = torch.load(f, map_location=device)
  35. else:
  36. state_dict = torch.hub.load_state_dict_from_url(checkpoint, map_location=device)
  37. self.load_state_dict(state_dict)
  38. @staticmethod
  39. @abstractmethod
  40. def from_config(config: ModelConfig) -> ModelBase[ModelConfig]:
  41. """Build/load the model.
  42. Args:
  43. config: The specifications for the model be build/loaded
  44. """
  45. raise NotImplementedError
  46. def compile(
  47. self,
  48. *,
  49. fullgraph: bool = False,
  50. dynamic: bool = False,
  51. backend: str = "inductor",
  52. mode: Optional[str] = None,
  53. options: Optional[dict[Any, Any]] = None,
  54. disable: bool = False,
  55. ) -> ModelBase[ModelConfig]:
  56. compiled = torch.compile(
  57. self, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode, options=options, disable=disable
  58. )
  59. compiled = cast(ModelBase[ModelConfig], compiled)
  60. return compiled