| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- import json
- import logging
- import os
- import pprint
- from typing import Any, Optional
- import kornia
- from kornia.config import kornia_config
- from kornia.core.external import numpy as np
- from kornia.core.external import onnx, requests
- from kornia.utils.download import CachedDownloader
- __all__ = ["ONNXLoader", "add_metadata", "io_name_conversion"]
- logger = logging.getLogger(__name__)
class ONNXLoader(CachedDownloader):
    """Manages ONNX models, handling local caching, downloading from Hugging Face, and loading models."""

    @classmethod
    def load_config(cls, url: str, download: bool = True, **kwargs: Any) -> dict[str, Any]:
        """Load a JSON config from the specified URL.

        Args:
            url: The URL of the preprocessor config to load. Only ``http:`` and ``https:``
                URLs are handled; anything else raises ``RuntimeError``.
            download: If True, the config will be downloaded if it's not already in the local cache.
            kwargs: Additional download arguments, forwarded to ``download_to_cache``.

        Returns:
            dict[str, Any]: The loaded preprocessor config.

        Raises:
            RuntimeError: If ``url`` is not an HTTP(S) URL.

        """
        if url.startswith(("http:", "https:")):
            file_path = cls.download_to_cache(
                url,
                os.path.split(url)[-1],
                download=download,
                suffix=".json",
                **kwargs,
            )
            # JSON is UTF-8 by specification; do not depend on the platform default encoding.
            with open(file_path, encoding="utf-8") as f:
                return json.load(f)

        if not download:
            raise RuntimeError(f"File `{url}` not found. You may set `download=True`.")
        # No cache path exists on this branch, so the message must be built from `url`;
        # referencing the download path here would raise UnboundLocalError instead.
        raise RuntimeError(f"File `{url}` not found.")

    @classmethod
    def load_model(cls, model_name: str, download: bool = True, with_data: bool = False, **kwargs: Any) -> onnx.ModelProto:  # type:ignore
        """Load an ONNX model from the local cache or downloads it from Hugging Face if necessary.

        Args:
            model_name: The name of the ONNX model or operator. For Hugging Face-hosted models,
                use the format 'hf://model_name'. Valid `model_name` can be found on
                https://huggingface.co/kornia/ONNX_models.
                Or a URL to the ONNX model, or a path to an existing local file.
            download: If True, the model will be downloaded from Hugging Face if it's not already in the local cache.
            with_data: If True, the companion `.onnx_data` weights file is fetched as well.
            **kwargs: Additional arguments to pass to the download method, if needed.
                ``cache_dir`` selects the directory where the model is cached; when it is
                missing or falsy, a directory under ``kornia_config.hub_onnx_dir`` is used.

        Returns:
            onnx.ModelProto: The loaded ONNX model.

        Raises:
            ValueError: If ``model_name`` is neither an ``hf://`` name, an HTTP(S) URL,
                nor an existing local path.

        """
        if model_name.startswith("hf://"):
            model_name = model_name[len("hf://") :]
            url = f"https://huggingface.co/kornia/ONNX_models/resolve/main/{model_name}.onnx"
            # Names look like "<folder>/<name>": the folder selects the cache
            # sub-directory and the second component is the cached file name.
            cache_dir = kwargs.get("cache_dir", None) or os.path.join(
                kornia_config.hub_onnx_dir, model_name.split("/")[0]
            )
            kwargs.update({"cache_dir": cache_dir})
            file_path = cls.download_to_cache(
                url, model_name.split("/")[1], download=download, suffix=".onnx", **kwargs
            )
            if with_data:
                url_data = f"https://huggingface.co/kornia/ONNX_models/resolve/main/{model_name}.onnx_data"
                cls.download_to_cache(
                    url_data, model_name.split("/")[1], download=download, suffix=".onnx_data", **kwargs
                )
            return onnx.load(file_path)  # type:ignore

        if model_name.startswith(("http://", "https://")):
            cache_dir = kwargs.get("cache_dir", None) or kornia_config.hub_onnx_dir
            kwargs.update({"cache_dir": cache_dir})
            file_name = os.path.split(model_name)[-1]
            file_path = cls.download_to_cache(
                model_name,
                file_name,
                download=download,
                suffix=".onnx",
                **kwargs,
            )
            if with_data:
                # Dropping the last five characters assumes the URL ends with ".onnx".
                url_data = model_name[:-5] + ".onnx_data"
                cls.download_to_cache(
                    url_data,
                    file_name[:-5] + ".onnx_data",
                    download=download,
                    suffix=".onnx_data",
                    **kwargs,
                )
            return onnx.load(file_path)  # type:ignore

        if os.path.exists(model_name):
            return onnx.load(model_name)  # type:ignore

        raise ValueError(f"File {model_name} not found")

    @staticmethod
    def _fetch_repo_contents(folder: str) -> list[dict[str, Any]]:
        """Fetch the contents of the Hugging Face repository using the Hugging Face API.

        Args:
            folder: The folder of the ``kornia/ONNX_models`` repository (on ``main``) to list.

        Returns:
            A list of all files in the repository folder as dictionaries containing file details.

        Raises:
            ValueError: If the API does not answer with HTTP status 200.

        """
        url = f"https://huggingface.co/api/models/kornia/ONNX_models/tree/main/{folder}"
        response = requests.get(url, timeout=10)  # type:ignore
        if response.status_code != 200:
            raise ValueError(f"Failed to fetch repository contents: {response.status_code}")
        return response.json()  # The JSON content of the repo listing

    @classmethod
    def list_operators(cls) -> None:
        """List all available ONNX operators in the 'operators' folder of the Hugging Face repository."""
        repo_contents = cls._fetch_repo_contents("operators")
        # Print the `path` field of every entry returned for the 'operators' folder.
        operators = [file["path"] for file in repo_contents]
        pprint.pp(operators)

    @classmethod
    def list_models(cls) -> None:
        """List all available ONNX models in the 'models' folder of the Hugging Face repository."""
        repo_contents = cls._fetch_repo_contents("models")
        # Print the `path` field of every entry returned for the 'models' folder.
        models = [file["path"] for file in repo_contents]
        pprint.pp(models)
def io_name_conversion(
    onnx_model: onnx.ModelProto,  # type:ignore
    io_name_mapping: dict[str, str],
) -> onnx.ModelProto:  # type:ignore
    """Rename the inputs and outputs of an ONNX model according to a name mapping.

    The model is modified in place. Graph-level inputs and outputs whose name is a key
    of ``io_name_mapping`` are renamed, and every node input/output that refers to one
    of those names is renamed consistently so the graph stays connected.

    Args:
        onnx_model: The ONNX model to convert.
        io_name_mapping: A dictionary mapping the original input and output names to the new ones.

    Returns:
        The same ``onnx_model`` object, with the names replaced.

    """
    graph = onnx_model.graph

    # Convert I/O nodes
    for value_info in graph.input:
        if value_info.name in io_name_mapping:
            value_info.name = io_name_mapping[value_info.name]
    for value_info in graph.output:
        if value_info.name in io_name_mapping:
            value_info.name = io_name_mapping[value_info.name]

    # Convert intermediate nodes. Each reference is looked up by its OWN name:
    # reusing a name left over from the loops above would map every match to
    # the last graph input/output (or fail when the graph has none).
    for node in graph.node:
        for names in (node.input, node.output):
            for idx, name in enumerate(names):
                if name in io_name_mapping:
                    names[idx] = io_name_mapping[name]

    return onnx_model
def add_metadata(
    onnx_model: onnx.ModelProto,  # type: ignore
    additional_metadata: Optional[list[tuple[str, str]]] = None,
) -> onnx.ModelProto:  # type: ignore
    """Attach kornia provenance metadata to an ONNX model.

    Two entries are always appended to the model's ``metadata_props``: ``source``
    (set to "kornia") and ``version`` (the kornia version). Any user-supplied
    key-value pairs are appended after them, in the order given.

    Args:
        onnx_model: The ONNX model to add metadata to.
        additional_metadata: A list of tuples, where each tuple contains a key and a value
            for the additional metadata to add to the ONNX model.

    Returns:
        The ONNX model with the added metadata.

    """
    entries = [("source", "kornia"), ("version", kornia.__version__)]
    if additional_metadata is not None:
        entries.extend(additional_metadata)

    for name, content in entries:
        prop = onnx_model.metadata_props.add()
        prop.key = name
        # Values are always stored as strings.
        prop.value = str(content)

    return onnx_model
def onnx_type_to_numpy(onnx_type: str) -> Any:
    """Translate an ONNX tensor type string into the matching NumPy scalar type.

    Args:
        onnx_type: A type string such as ``"tensor(float)"`` or ``"tensor(int64)"``.

    Returns:
        The NumPy type that corresponds to ``onnx_type``.

    Raises:
        TypeError: If ``onnx_type`` is not one of the supported type strings.

    """
    numpy_by_onnx = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
        "tensor(int32)": np.int32,
        "tensor(int64)": np.int64,
        "tensor(uint8)": np.uint8,
        "tensor(int8)": np.int8,
        "tensor(bool)": np.bool_,
    }
    # No table value is None, so a None result can only mean "unknown type".
    numpy_type = numpy_by_onnx.get(onnx_type)
    if numpy_type is None:
        raise TypeError(f"ONNX type {onnx_type} not understood")
    return numpy_type
|