utils.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import json
  19. import logging
  20. import os
  21. import pprint
  22. from typing import Any, Optional
  23. import kornia
  24. from kornia.config import kornia_config
  25. from kornia.core.external import numpy as np
  26. from kornia.core.external import onnx, requests
  27. from kornia.utils.download import CachedDownloader
  28. __all__ = ["ONNXLoader", "add_metadata", "io_name_conversion"]
  29. logger = logging.getLogger(__name__)
  30. class ONNXLoader(CachedDownloader):
  31. """Manages ONNX models, handling local caching, downloading from Hugging Face, and loading models."""
  32. @classmethod
  33. def load_config(cls, url: str, download: bool = True, **kwargs: Any) -> dict[str, Any]:
  34. """Load JSON config from the specified URL.
  35. Args:
  36. url: The URL of the preprocessor config to load.
  37. download: If True, the config will be downloaded if it's not already in the local cache.
  38. kwargs: Additional download arguments.
  39. Returns:
  40. dict[str, Any]: The loaded preprocessor config.
  41. """
  42. if url.startswith(("http:", "https:")):
  43. file_path = cls.download_to_cache(
  44. url,
  45. os.path.split(url)[-1],
  46. download=download,
  47. suffix=".json",
  48. **kwargs,
  49. )
  50. with open(file_path) as f:
  51. json_data = json.load(f)
  52. return json_data
  53. if not download:
  54. raise RuntimeError(f"File `{url}` not found. You may set `download=True`.")
  55. raise RuntimeError(f"File `{file_path}` not found.")
  56. @classmethod
  57. def load_model(cls, model_name: str, download: bool = True, with_data: bool = False, **kwargs) -> onnx.ModelProto: # type:ignore
  58. """Load an ONNX model from the local cache or downloads it from Hugging Face if necessary.
  59. Args:
  60. model_name: The name of the ONNX model or operator. For Hugging Face-hosted models,
  61. use the format 'hf://model_name'. Valid `model_name` can be found on
  62. https://huggingface.co/kornia/ONNX_models.
  63. Or a URL to the ONNX model.
  64. download: If True, the model will be downloaded from Hugging Face if it's not already in the local cache.
  65. cache_dir: The directory where the model should be cached.
  66. Defaults to None, which will use a default `{kornia_config.hub_onnx_dir}` directory.
  67. with_data: If True, the model will be loaded with its `.onnx_data` weights.
  68. **kwargs: Additional arguments to pass to the download method, if needed.
  69. Returns:
  70. onnx.ModelProto: The loaded ONNX model.
  71. """
  72. if model_name.startswith("hf://"):
  73. model_name = model_name[len("hf://") :]
  74. url = f"https://huggingface.co/kornia/ONNX_models/resolve/main/{model_name}.onnx"
  75. cache_dir = kwargs.get("cache_dir", None) or os.path.join(
  76. kornia_config.hub_onnx_dir, model_name.split("/")[0]
  77. )
  78. kwargs.update({"cache_dir": cache_dir})
  79. file_path = cls.download_to_cache(
  80. url, model_name.split("/")[1], download=download, suffix=".onnx", **kwargs
  81. )
  82. if with_data:
  83. url_data = f"https://huggingface.co/kornia/ONNX_models/resolve/main/{model_name}.onnx_data"
  84. cls.download_to_cache(
  85. url_data, model_name.split("/")[1], download=download, suffix=".onnx_data", **kwargs
  86. )
  87. return onnx.load(file_path) # type:ignore
  88. elif model_name.startswith("http://") or model_name.startswith("https://"):
  89. cache_dir = kwargs.get("cache_dir", None) or kornia_config.hub_onnx_dir
  90. kwargs.update({"cache_dir": cache_dir})
  91. file_path = cls.download_to_cache(
  92. model_name,
  93. os.path.split(model_name)[-1],
  94. download=download,
  95. suffix=".onnx",
  96. **kwargs,
  97. )
  98. if with_data:
  99. url_data = model_name[:-5] + ".onnx_data"
  100. cls.download_to_cache(
  101. url_data,
  102. os.path.split(model_name)[-1][:-5] + ".onnx_data",
  103. download=download,
  104. suffix=".onnx_data",
  105. **kwargs,
  106. )
  107. return onnx.load(file_path) # type:ignore
  108. elif os.path.exists(model_name):
  109. return onnx.load(model_name) # type:ignore
  110. raise ValueError(f"File {model_name} not found")
  111. @staticmethod
  112. def _fetch_repo_contents(folder: str) -> list[dict[str, Any]]:
  113. """Fetch the contents of the Hugging Face repository using the Hugging Face API.
  114. Returns:
  115. A list of all files in the repository as dictionaries containing file details.
  116. """
  117. url = f"https://huggingface.co/api/models/kornia/ONNX_models/tree/main/{folder}"
  118. response = requests.get(url, timeout=10) # type:ignore
  119. if response.status_code == 200:
  120. return response.json() # Returns the JSON content of the repo
  121. else:
  122. raise ValueError(f"Failed to fetch repository contents: {response.status_code}")
  123. @classmethod
  124. def list_operators(cls) -> None:
  125. """List all available ONNX operators in the 'operators' folder of the Hugging Face repository."""
  126. repo_contents = cls._fetch_repo_contents("operators")
  127. # Filter for operators in the 'operators' directory
  128. operators = [file["path"] for file in repo_contents]
  129. pprint.pp(operators)
  130. @classmethod
  131. def list_models(cls) -> None:
  132. """List all available ONNX models in the 'models' folder of the Hugging Face repository."""
  133. repo_contents = cls._fetch_repo_contents("models")
  134. # Filter for models in the 'models' directory
  135. models = [file["path"] for file in repo_contents]
  136. pprint.pp(models)
  137. def io_name_conversion(
  138. onnx_model: onnx.ModelProto, # type:ignore
  139. io_name_mapping: dict[str, str],
  140. ) -> onnx.ModelProto: # type:ignore
  141. """Convert the input and output names of an ONNX model to 'input' and 'output'.
  142. Args:
  143. onnx_model: The ONNX model to convert.
  144. io_name_mapping: A dictionary mapping the original input and output names to the new ones.
  145. """
  146. # Convert I/O nodes
  147. for i in range(len(onnx_model.graph.input)):
  148. in_name = onnx_model.graph.input[i].name
  149. if in_name in io_name_mapping:
  150. onnx_model.graph.input[i].name = io_name_mapping[in_name]
  151. for i in range(len(onnx_model.graph.output)):
  152. out_name = onnx_model.graph.output[i].name
  153. if out_name in io_name_mapping:
  154. onnx_model.graph.output[i].name = io_name_mapping[out_name]
  155. # Convert intermediate nodes
  156. for i in range(len(onnx_model.graph.node)):
  157. for j in range(len(onnx_model.graph.node[i].input)):
  158. if onnx_model.graph.node[i].input[j] in io_name_mapping:
  159. onnx_model.graph.node[i].input[j] = io_name_mapping[in_name]
  160. for j in range(len(onnx_model.graph.node[i].output)):
  161. if onnx_model.graph.node[i].output[j] in io_name_mapping:
  162. onnx_model.graph.node[i].output[j] = io_name_mapping[out_name]
  163. return onnx_model
  164. def add_metadata(
  165. onnx_model: onnx.ModelProto, # type: ignore
  166. additional_metadata: Optional[list[tuple[str, str]]] = None,
  167. ) -> onnx.ModelProto: # type: ignore
  168. """Add metadata to an ONNX model.
  169. The metadata includes the source library (set to "kornia"), the version of kornia,
  170. and any additional metadata provided as a list of key-value pairs.
  171. Args:
  172. onnx_model: The ONNX model to add metadata to.
  173. additional_metadata: A list of tuples, where each tuple contains a key and a value
  174. for the additional metadata to add to the ONNX model.
  175. Returns:
  176. The ONNX model with the added metadata.
  177. """
  178. if additional_metadata is None:
  179. additional_metadata = []
  180. for key, value in [
  181. ("source", "kornia"),
  182. ("version", kornia.__version__),
  183. *additional_metadata,
  184. ]:
  185. metadata_props = onnx_model.metadata_props.add()
  186. metadata_props.key = key
  187. metadata_props.value = str(value)
  188. return onnx_model
  189. def onnx_type_to_numpy(onnx_type: str) -> Any:
  190. type_mapping = {
  191. "tensor(float)": np.float32,
  192. "tensor(float16)": np.float16,
  193. "tensor(double)": np.float64,
  194. "tensor(int32)": np.int32,
  195. "tensor(int64)": np.int64,
  196. "tensor(uint8)": np.uint8,
  197. "tensor(int8)": np.int8,
  198. "tensor(bool)": np.bool_,
  199. }
  200. if onnx_type not in type_mapping:
  201. raise TypeError(f"ONNX type {onnx_type} not understood")
  202. return type_mapping[onnx_type]