# LICENSE HEADER MANAGED BY add-license-header # # Copyright 2018 Kornia Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import json import logging import os import pprint from typing import Any, Optional import kornia from kornia.config import kornia_config from kornia.core.external import numpy as np from kornia.core.external import onnx, requests from kornia.utils.download import CachedDownloader __all__ = ["ONNXLoader", "add_metadata", "io_name_conversion"] logger = logging.getLogger(__name__) class ONNXLoader(CachedDownloader): """Manages ONNX models, handling local caching, downloading from Hugging Face, and loading models.""" @classmethod def load_config(cls, url: str, download: bool = True, **kwargs: Any) -> dict[str, Any]: """Load JSON config from the specified URL. Args: url: The URL of the preprocessor config to load. download: If True, the config will be downloaded if it's not already in the local cache. kwargs: Additional download arguments. Returns: dict[str, Any]: The loaded preprocessor config. """ if url.startswith(("http:", "https:")): file_path = cls.download_to_cache( url, os.path.split(url)[-1], download=download, suffix=".json", **kwargs, ) with open(file_path) as f: json_data = json.load(f) return json_data if not download: raise RuntimeError(f"File `{url}` not found. You may set `download=True`.") raise RuntimeError(f"File `{file_path}` not found.") @classmethod def load_model(cls, model_name: str, download: bool = True, with_data: bool = False, **kwargs) -> onnx.ModelProto: # type:ignore """Load an ONNX model from the local cache or downloads it from Hugging Face if necessary. Args: model_name: The name of the ONNX model or operator. For Hugging Face-hosted models, use the format 'hf://model_name'. Valid `model_name` can be found on https://huggingface.co/kornia/ONNX_models. Or a URL to the ONNX model. download: If True, the model will be downloaded from Hugging Face if it's not already in the local cache. cache_dir: The directory where the model should be cached. Defaults to None, which will use a default `{kornia_config.hub_onnx_dir}` directory. with_data: If True, the model will be loaded with its `.onnx_data` weights. **kwargs: Additional arguments to pass to the download method, if needed. Returns: onnx.ModelProto: The loaded ONNX model. """ if model_name.startswith("hf://"): model_name = model_name[len("hf://") :] url = f"https://huggingface.co/kornia/ONNX_models/resolve/main/{model_name}.onnx" cache_dir = kwargs.get("cache_dir", None) or os.path.join( kornia_config.hub_onnx_dir, model_name.split("/")[0] ) kwargs.update({"cache_dir": cache_dir}) file_path = cls.download_to_cache( url, model_name.split("/")[1], download=download, suffix=".onnx", **kwargs ) if with_data: url_data = f"https://huggingface.co/kornia/ONNX_models/resolve/main/{model_name}.onnx_data" cls.download_to_cache( url_data, model_name.split("/")[1], download=download, suffix=".onnx_data", **kwargs ) return onnx.load(file_path) # type:ignore elif model_name.startswith("http://") or model_name.startswith("https://"): cache_dir = kwargs.get("cache_dir", None) or kornia_config.hub_onnx_dir kwargs.update({"cache_dir": cache_dir}) file_path = cls.download_to_cache( model_name, os.path.split(model_name)[-1], download=download, suffix=".onnx", **kwargs, ) if with_data: url_data = model_name[:-5] + ".onnx_data" cls.download_to_cache( url_data, os.path.split(model_name)[-1][:-5] + ".onnx_data", download=download, suffix=".onnx_data", **kwargs, ) return onnx.load(file_path) # type:ignore elif os.path.exists(model_name): return onnx.load(model_name) # type:ignore raise ValueError(f"File {model_name} not found") @staticmethod def _fetch_repo_contents(folder: str) -> list[dict[str, Any]]: """Fetch the contents of the Hugging Face repository using the Hugging Face API. Returns: A list of all files in the repository as dictionaries containing file details. """ url = f"https://huggingface.co/api/models/kornia/ONNX_models/tree/main/{folder}" response = requests.get(url, timeout=10) # type:ignore if response.status_code == 200: return response.json() # Returns the JSON content of the repo else: raise ValueError(f"Failed to fetch repository contents: {response.status_code}") @classmethod def list_operators(cls) -> None: """List all available ONNX operators in the 'operators' folder of the Hugging Face repository.""" repo_contents = cls._fetch_repo_contents("operators") # Filter for operators in the 'operators' directory operators = [file["path"] for file in repo_contents] pprint.pp(operators) @classmethod def list_models(cls) -> None: """List all available ONNX models in the 'models' folder of the Hugging Face repository.""" repo_contents = cls._fetch_repo_contents("models") # Filter for models in the 'models' directory models = [file["path"] for file in repo_contents] pprint.pp(models) def io_name_conversion( onnx_model: onnx.ModelProto, # type:ignore io_name_mapping: dict[str, str], ) -> onnx.ModelProto: # type:ignore """Convert the input and output names of an ONNX model to 'input' and 'output'. Args: onnx_model: The ONNX model to convert. io_name_mapping: A dictionary mapping the original input and output names to the new ones. """ # Convert I/O nodes for i in range(len(onnx_model.graph.input)): in_name = onnx_model.graph.input[i].name if in_name in io_name_mapping: onnx_model.graph.input[i].name = io_name_mapping[in_name] for i in range(len(onnx_model.graph.output)): out_name = onnx_model.graph.output[i].name if out_name in io_name_mapping: onnx_model.graph.output[i].name = io_name_mapping[out_name] # Convert intermediate nodes for i in range(len(onnx_model.graph.node)): for j in range(len(onnx_model.graph.node[i].input)): if onnx_model.graph.node[i].input[j] in io_name_mapping: onnx_model.graph.node[i].input[j] = io_name_mapping[in_name] for j in range(len(onnx_model.graph.node[i].output)): if onnx_model.graph.node[i].output[j] in io_name_mapping: onnx_model.graph.node[i].output[j] = io_name_mapping[out_name] return onnx_model def add_metadata( onnx_model: onnx.ModelProto, # type: ignore additional_metadata: Optional[list[tuple[str, str]]] = None, ) -> onnx.ModelProto: # type: ignore """Add metadata to an ONNX model. The metadata includes the source library (set to "kornia"), the version of kornia, and any additional metadata provided as a list of key-value pairs. Args: onnx_model: The ONNX model to add metadata to. additional_metadata: A list of tuples, where each tuple contains a key and a value for the additional metadata to add to the ONNX model. Returns: The ONNX model with the added metadata. """ if additional_metadata is None: additional_metadata = [] for key, value in [ ("source", "kornia"), ("version", kornia.__version__), *additional_metadata, ]: metadata_props = onnx_model.metadata_props.add() metadata_props.key = key metadata_props.value = str(value) return onnx_model def onnx_type_to_numpy(onnx_type: str) -> Any: type_mapping = { "tensor(float)": np.float32, "tensor(float16)": np.float16, "tensor(double)": np.float64, "tensor(int32)": np.int32, "tensor(int64)": np.int64, "tensor(uint8)": np.uint8, "tensor(int8)": np.int8, "tensor(bool)": np.bool_, } if onnx_type not in type_mapping: raise TypeError(f"ONNX type {onnx_type} not understood") return type_mapping[onnx_type]