# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import logging
import os
import urllib.error
import urllib.request
from typing import Any, Optional

from kornia.config import kornia_config

__all__ = ["CachedDownloader"]

logger = logging.getLogger(__name__)


class CachedDownloader:
    """Downloads files from URLs to the local cache or .kornia_hub directory."""

    @classmethod
    def _get_file_path(cls, model_name: str, cache_dir: Optional[str], suffix: Optional[str] = None) -> str:
        """Build the local path under which a file named ``model_name`` is cached.

        Args:
            model_name: Name of the file, optionally prefixed with sub-directories separated by
                ``os.sep`` (e.g. ``operators/model_name`` on POSIX). The sub-directories are
                reproduced below ``cache_dir``.
            cache_dir: Root directory of the cache. When None, ``kornia_config.hub_cache_dir`` is used.
            suffix: Optional file suffix appended to the file name unless ``model_name`` already
                ends with it.

        Returns:
            The full local path where the file should be stored or loaded from.

        """
        if cache_dir is None:
            cache_dir = kornia_config.hub_cache_dir

        # Everything before the last separator is a sub-directory; the remainder is the file name.
        sub_dirs = model_name.split(os.sep)[:-1]
        file_name = os.path.split(model_name)[-1]
        if suffix is not None and not model_name.endswith(suffix):
            file_name = f"{file_name}{suffix}"

        # `cache_dir` is passed to `os.path.join` unsplit: splitting it on `os.sep` and re-joining
        # the pieces drops the leading separator (or drive root) and silently turns an absolute
        # cache directory into a path relative to the current working directory.
        return os.path.join(cache_dir, *sub_dirs, file_name)

    @classmethod
    def download_to_cache(cls, url: str, name: str, download: bool = True, **kwargs: Any) -> str:
        """Resolve the cache location for ``name`` and make sure the file from ``url`` is there.

        Args:
            url: The URL of the file. Must start with ``http:`` or ``https:``.
            name: Name used to build the cached file path (see ``_get_file_path``).
            download: If True, the file is downloaded when it is not already in the cache.
            **kwargs: Optional ``cache_dir`` (cache root) and ``suffix`` (file suffix) entries,
                forwarded to ``_get_file_path``. Other entries are ignored.

        Returns:
            The local path of the cached file.

        Raises:
            ValueError: If ``url`` does not start with ``http:`` or ``https:``, or for the reasons
                listed in ``download``.

        """
        if not url.startswith(("http:", "https:")):
            raise ValueError(f"URL must start with 'http:' or 'https:'. Got {url}")

        file_path = cls._get_file_path(name, kwargs.get("cache_dir"), suffix=kwargs.get("suffix"))
        cls.download(url, file_path, download_if_not_exists=download)
        return file_path

    @classmethod
    def download(
        cls,
        url: str,
        file_path: str,
        download_if_not_exists: bool = True,
    ) -> None:
        """Download the file at ``url`` to ``file_path`` unless it already exists.

        Args:
            url: The URL of the file to download.
            file_path: The local path where the downloaded file should be saved.
            download_if_not_exists: If True, the file will be downloaded if it's not already downloaded.

        Raises:
            ValueError: If the file is missing and ``download_if_not_exists`` is False, if ``url``
                does not start with ``http:`` or ``https:``, or if the server answers with an HTTP
                error. Any other exception raised by ``urllib.request.urlretrieve`` propagates
                unchanged.

        """
        if os.path.exists(file_path):
            logger.info("Loading `%s` from `%s`.", url, file_path)
            return

        if not download_if_not_exists:
            raise ValueError(f"`{file_path}` not found. You may set `download=True`.")

        # Validate the scheme before touching the file system so that a bad URL
        # does not leave an empty cache directory behind.
        if not url.startswith(("http:", "https:")):
            raise ValueError("URL must start with 'http:' or 'https:'")

        # A bare file name has an empty dirname, which `os.makedirs` rejects;
        # in that case the file goes to the current directory and nothing needs creating.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        logger.info("Downloading `%s` to `%s`.", url, file_path)
        try:
            urllib.request.urlretrieve(url, file_path)  # noqa: S310
        except urllib.error.HTTPError as e:
            raise ValueError(f"Error in resolving `{url}`.") from e