| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import logging
- import os
- from typing import Any, List, Optional, Tuple, Union
- import kornia
- from kornia.core import Tensor, stack
- from kornia.core.external import PILImage as Image
- from kornia.core.external import requests
- from kornia.io import load_image
# Names exported by a star-import of this module; only the high-level loader is listed.
__all__ = [
    "get_sample_images",
]
# Default image sources used by `get_sample_images` when no `paths` are given.
# Every entry is a raw-file URL pointing into the `kornia/data` GitHub repository (`main` branch).
IMAGE_URLS: List[str] = [
    "https://raw.githubusercontent.com/kornia/data/main/panda.jpg",
    "https://raw.githubusercontent.com/kornia/data/main/simba.png",
    "https://raw.githubusercontent.com/kornia/data/main/girona.png",
    "https://raw.githubusercontent.com/kornia/data/main/baby_giraffe.png",
    "https://raw.githubusercontent.com/kornia/data/main/persistencia_memoria.jpg",
    "https://raw.githubusercontent.com/kornia/data/main/delorean.png",
]
def download_image(url: str, save_to: str) -> None:
    """Download an image from a given URL and save it to a specified file path.

    The response body is opened with PIL and written back out with ``Image.save``.

    Args:
        url: The URL of the image to download.
        save_to: The file path where the downloaded image will be saved.

    Raises:
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx status code.
    """
    # The context manager releases the streamed connection even if decoding or saving fails.
    with requests.get(url, stream=True, timeout=30) as response:
        # Surface the HTTP status instead of letting PIL fail on an error page.
        response.raise_for_status()
        # `raw` bypasses requests' transparent decoding; ask urllib3 to undo any
        # gzip/deflate Content-Encoding so PIL receives the actual image bytes.
        response.raw.decode_content = True
        im = Image.open(response.raw)  # type:ignore
        # Save while the response is still open so any deferred reads can complete.
        im.save(save_to)
def get_sample_images(
    resize: Optional[Tuple[int, int]] = None,
    paths: List[str] = IMAGE_URLS,
    download: bool = True,
    cache_dir: Optional[str] = None,
    as_list: Optional[bool] = None,
    divisible_factor: Optional[int] = None,
    **kwargs: Any,
) -> Union[Tensor, List[Tensor]]:
    """Load multiple images from local paths or URLs.

    URLs are resolved to a file inside ``cache_dir`` (named after the last path component
    of the URL) and downloaded there when missing. Each image is then loaded, optionally
    resized, and the results are returned as a list or as one stacked tensor.

    Args:
        resize: Optional target size for resizing all images as a tuple (height, width).
            If not provided, the images will not be resized, and their original sizes will be retained.
        paths: A list of paths or URLs from which to load or download images. Entries starting
            with ``http://`` or ``https://`` are treated as URLs, everything else as a local path.
            A single string is accepted and treated as a one-element list.
            Defaults to the pre-defined constant `IMAGE_URLS`.
        download: Whether to download the images if they are not already cached. Defaults to True.
            When False and a URL is not cached, an error is logged and loading of the
            missing cache file is still attempted.
        cache_dir: The directory where the downloaded images will be cached. It is created if
            it does not exist. Defaults to ".kornia_hub/images".
        as_list: Whether to keep the output as a list. When falsy (None or False), the output is
            stacked into a single tensor if `resize` is given and stays a list otherwise.
            When True, the output is always a list.
        divisible_factor: If not None, each image is additionally passed through
            `kornia.geometry.transform.resize_to_be_divisible` with this factor.
        **kwargs: Additional keyword arguments forwarded to `kornia.geometry.resize` and,
            when `divisible_factor` is set, to `resize_to_be_divisible` as well.

    Returns:
        torch.Tensor | list[torch.Tensor]:
            A single stacked tensor with shape (B, C, H, W) when `resize` is provided and
            `as_list` is falsy. Otherwise a list of tensors, one per entry of `paths`.
    """
    if cache_dir is None:
        cache_dir = ".kornia_hub/images"
    os.makedirs(cache_dir, exist_ok=True)

    # A bare string would otherwise be iterated character by character.
    if isinstance(paths, str):
        paths = [paths]

    tensors = []
    for path in paths:
        # Match the full scheme: a plain "http" prefix would also capture local
        # files such as "http_cache/img.png".
        if path.startswith(("http://", "https://")):
            fname = os.path.join(cache_dir, os.path.basename(path))
            if not os.path.exists(fname):
                if download:
                    logging.info(f"Downloading `{path}` to `{fname}`.")
                    download_image(path, fname)
                else:
                    # Only logged: the load below is still attempted on the missing file.
                    logging.error(
                        f"Image `{path}` not found at `{fname}`. You may want to set `download=True` to download it."
                    )
        else:
            fname = path

        img_tensor = load_image(fname)
        if resize is not None:
            img_tensor = kornia.geometry.resize(img_tensor, resize, **kwargs)
        if divisible_factor is not None:
            img_tensor = kornia.geometry.transform.resize_to_be_divisible(img_tensor, divisible_factor, **kwargs)
        tensors.append(img_tensor)

    # Stacking is only attempted when a common `resize` target was requested.
    if not as_list and resize is not None:
        return stack(tensors)
    return tensors
|