| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- import collections.abc as collections
- from pathlib import Path
- from types import SimpleNamespace
- from typing import Callable, List, Optional, Tuple, Union
- import cv2
- import kornia
- import numpy as np
- import torch
- class ImagePreprocessor:
- default_conf = {
- "resize": None, # target edge length, None for no resizing
- "side": "long",
- "interpolation": "bilinear",
- "align_corners": None,
- "antialias": True,
- }
- def __init__(self, **conf) -> None:
- super().__init__()
- self.conf = {**self.default_conf, **conf}
- self.conf = SimpleNamespace(**self.conf)
- def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- """Resize and preprocess an image, return image and resize scale"""
- h, w = img.shape[-2:]
- if self.conf.resize is not None:
- img = kornia.geometry.transform.resize(
- img,
- self.conf.resize,
- side=self.conf.side,
- antialias=self.conf.antialias,
- align_corners=self.conf.align_corners,
- )
- scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
- return img, scale
- def map_tensor(input_, func: Callable):
- string_classes = (str, bytes)
- if isinstance(input_, string_classes):
- return input_
- elif isinstance(input_, collections.Mapping):
- return {k: map_tensor(sample, func) for k, sample in input_.items()}
- elif isinstance(input_, collections.Sequence):
- return [map_tensor(sample, func) for sample in input_]
- elif isinstance(input_, torch.Tensor):
- return func(input_)
- else:
- return input_
- def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
- """Move batch (dict) to device"""
- def _func(tensor):
- return tensor.to(device=device, non_blocking=non_blocking).detach()
- return map_tensor(batch, _func)
- def rbd(data: dict) -> dict:
- """Remove batch dimension from elements in data"""
- return {
- k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
- for k, v in data.items()
- }
- def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
- """Read an image from path as RGB or grayscale"""
- if not Path(path).exists():
- raise FileNotFoundError(f"No image at path {path}.")
- mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
- image = cv2.imread(str(path), mode)
- if image is None:
- raise IOError(f"Could not read image at {path}.")
- if not grayscale:
- image = image[..., ::-1]
- return image
- def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
- """Normalize the image tensor and reorder the dimensions."""
- if image.ndim == 3:
- image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
- elif image.ndim == 2:
- image = image[None] # add channel axis
- else:
- raise ValueError(f"Not an image: {image.shape}")
- return torch.tensor(image / 255.0, dtype=torch.float)
- def resize_image(
- image: np.ndarray,
- size: Union[List[int], int],
- fn: str = "max",
- interp: Optional[str] = "area",
- ) -> np.ndarray:
- """Resize an image to a fixed size, or according to max or min edge."""
- h, w = image.shape[:2]
- fn = {"max": max, "min": min}[fn]
- if isinstance(size, int):
- scale = size / fn(h, w)
- h_new, w_new = int(round(h * scale)), int(round(w * scale))
- scale = (w_new / w, h_new / h)
- elif isinstance(size, (tuple, list)):
- h_new, w_new = size
- scale = (w_new / w, h_new / h)
- else:
- raise ValueError(f"Incorrect new size: {size}")
- mode = {
- "linear": cv2.INTER_LINEAR,
- "cubic": cv2.INTER_CUBIC,
- "nearest": cv2.INTER_NEAREST,
- "area": cv2.INTER_AREA,
- }[interp]
- return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
- def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
- image = read_image(path)
- if resize is not None:
- image, _ = resize_image(image, resize, **kwargs)
- return numpy_image_to_torch(image)
- class Extractor(torch.nn.Module):
- def __init__(self, **conf):
- super().__init__()
- self.conf = SimpleNamespace(**{**self.default_conf, **conf})
- @torch.no_grad()
- def extract(self, img: torch.Tensor, **conf) -> dict:
- """Perform extraction with online resizing"""
- if img.dim() == 3:
- img = img[None] # add batch dim
- assert img.dim() == 4 and img.shape[0] == 1
- shape = img.shape[-2:][::-1]
- img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
- feats = self.forward({"image": img})
- feats["image_size"] = torch.tensor(shape)[None].to(img).float()
- feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
- return feats
- def match_pair(
- extractor,
- matcher,
- image0: torch.Tensor,
- image1: torch.Tensor,
- device: str = "cpu",
- **preprocess,
- ):
- """Match a pair of images (image0, image1) with an extractor and matcher"""
- feats0 = extractor.extract(image0, **preprocess)
- feats1 = extractor.extract(image1, **preprocess)
- matches01 = matcher({"image0": feats0, "image1": feats1})
- data = [feats0, feats1, matches01]
- # remove batch dim and move to target device
- feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
- return feats0, feats1, matches01
|