| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from typing import Optional
- import torch
- from kornia.core import Module, Tensor
- from ._unets import Unet
- from .detector import heatmap_to_keypoints
- from .structs import DISKFeatures
class DISK(Module):
    r"""Module which detects and describes local features in an image using the DISK method.

    See :cite:`tyszkiewicz2020disk` for details.

    .. image:: _static/img/disk_outdoor_depth.jpg

    Args:
        desc_dim: The dimension of the descriptor.
        unet: The U-Net to use. If None, a default U-Net is used. Kornia doesn't provide the training code for DISK
            so this is only useful when using a custom checkpoint trained using the code released with the paper.
            The unet should take as input a tensor of shape :math:`(B, C, H, W)` and output a tensor of shape
            :math:`(B, \mathrm{desc\_dim} + 1, H, W)`.

    Example:
        >>> disk = DISK.from_pretrained('depth')
        >>> images = torch.rand(1, 3, 256, 256)
        >>> features = disk(images)

    """

    def __init__(self, desc_dim: int = 128, unet: None | Module = None) -> None:
        super().__init__()

        self.desc_dim = desc_dim

        if unet is None:
            # The last decoder stage emits ``desc_dim`` descriptor channels plus one detection-score channel.
            unet = Unet(in_features=3, size=5, down=[16, 32, 64, 64, 64], up=[64, 64, 64, desc_dim + 1])
        self.unet = unet

    def heatmap_and_dense_descriptors(self, images: Tensor) -> tuple[Tensor, Tensor]:
        """Return the heatmap and the dense descriptors.

        .. image:: _static/img/DISK.png

        Args:
            images: The image to detect features in. Shape :math:`(B, 3, H, W)`.

        Returns:
            A tuple of dense detection scores and descriptors.
            Shapes are :math:`(B, 1, H, W)` and :math:`(B, D, H, W)`, where
            :math:`D` is the descriptor dimension.

        Raises:
            ValueError: If the U-Net output does not have ``desc_dim + 1`` channels.

        """
        unet_output = self.unet(images)

        if unet_output.shape[1] != self.desc_dim + 1:
            raise ValueError(
                f"U-Net output has {unet_output.shape[1]} channels, but expected self.desc_dim={self.desc_dim} + 1."
            )

        # Channel layout: the first ``desc_dim`` channels are descriptors, the remaining one is the heatmap.
        descriptors = unet_output[:, : self.desc_dim]
        heatmaps = unet_output[:, self.desc_dim :]

        return heatmaps, descriptors

    def forward(
        self,
        images: Tensor,
        n: Optional[int] = None,
        window_size: int = 5,
        score_threshold: float = 0.0,
        pad_if_not_divisible: bool = False,
    ) -> list[DISKFeatures]:
        """Detect features in an image, returning keypoint locations, descriptors and detection scores.

        Args:
            images: The image to detect features in. Shape :math:`(B, 3, H, W)`.
            n: The maximum number of keypoints to detect. If None, all keypoints are returned.
            window_size: The size of the non-maxima suppression window used to filter detections.
            score_threshold: The minimum score a detection must have to be returned.
                See :py:class:`DISKFeatures` for details.
            pad_if_not_divisible: If True, an input whose height or width is not divisible by 16 is zero-padded
                (bottom/right) to the next multiple of 16, and the dense outputs are cropped back to the
                original size.

        Returns:
            A list of length :math:`B` containing the detected features.

        """
        batch_size = images.shape[0]

        if pad_if_not_divisible:
            h, w = images.shape[2:]
            # ``(-x) % 16`` is the distance to the next multiple of 16, and 0 when ``x`` is already aligned.
            pad_h = (-h) % 16
            pad_w = (-w) % 16
            # Skip the pad call entirely when there is nothing to add.
            if pad_h > 0 or pad_w > 0:
                images = torch.nn.functional.pad(images, (0, pad_w, 0, pad_h), value=0.0)

        heatmaps, descriptors = self.heatmap_and_dense_descriptors(images)

        if pad_if_not_divisible:
            # Drop the rows/columns that correspond to the zero padding.
            heatmaps = heatmaps[..., :h, :w]
            descriptors = descriptors[..., :h, :w]

        keypoints = heatmap_to_keypoints(heatmaps, n=n, window_size=window_size, score_threshold=score_threshold)

        # Pair each image's keypoints with that image's dense descriptor map.
        return [keypoints[i].merge_with_descriptors(descriptors[i]) for i in range(batch_size)]

    @classmethod
    def from_pretrained(cls, checkpoint: str = "depth", device: Optional[torch.device] = None) -> DISK:
        r"""Load a pretrained model.

        Depth model was trained using depth map supervision and is slightly more precise but biased to detect keypoints
        only where SfM depth is available. Epipolar model was trained using epipolar geometry supervision and
        is less precise but detects keypoints everywhere where they are matchable. The difference is especially
        pronounced on thin structures and on edges of objects.

        Args:
            checkpoint: The checkpoint to load. One of 'depth' or 'epipolar'.
            device: The device to load the model to. Defaults to CPU when None.

        Returns:
            The pretrained model, in evaluation mode.

        Raises:
            ValueError: If ``checkpoint`` is not one of the known pretrained models.

        """
        urls = {
            "depth": "https://raw.githubusercontent.com/cvlab-epfl/disk/master/depth-save.pth",
            "epipolar": "https://raw.githubusercontent.com/cvlab-epfl/disk/master/epipolar-save.pth",
        }

        # Validate before any download or model construction happens.
        if checkpoint not in urls:
            raise ValueError(f"Unknown pretrained model: {checkpoint}")

        if device is None:
            device = torch.device("cpu")

        pretrained_dict = torch.hub.load_state_dict_from_url(urls[checkpoint], map_location=device)

        model: DISK = cls().to(device)
        # The weights are read from the "extractor" entry of the downloaded checkpoint.
        model.load_state_dict(pretrained_dict["extractor"])
        model.eval()

        return model
|