disk.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional
  19. import torch
  20. from kornia.core import Module, Tensor
  21. from ._unets import Unet
  22. from .detector import heatmap_to_keypoints
  23. from .structs import DISKFeatures
  24. class DISK(Module):
  25. r"""Module which detects and described local features in an image using the DISK method.
  26. See :cite:`tyszkiewicz2020disk` for details.
  27. .. image:: _static/img/disk_outdoor_depth.jpg
  28. Args:
  29. desc_dim: The dimension of the descriptor.
  30. unet: The U-Net to use. If None, a default U-Net is used. Kornia doesn't provide the training code for DISK
  31. so this is only useful when using a custom checkpoint trained using the code released with the paper.
  32. The unet should take as input a tensor of shape :math:`(B, C, H, W)` and output a tensor of shape
  33. :math:`(B, \mathrm{desc\_dim} + 1, H, W)`.
  34. Example:
  35. >>> disk = DISK.from_pretrained('depth')
  36. >>> images = torch.rand(1, 3, 256, 256)
  37. >>> features = disk(images)
  38. """
  39. def __init__(self, desc_dim: int = 128, unet: None | Module = None) -> None:
  40. super().__init__()
  41. self.desc_dim = desc_dim
  42. if unet is None:
  43. unet = Unet(in_features=3, size=5, down=[16, 32, 64, 64, 64], up=[64, 64, 64, desc_dim + 1])
  44. self.unet = unet
  45. def heatmap_and_dense_descriptors(self, images: Tensor) -> tuple[Tensor, Tensor]:
  46. """Return the heatmap and the dense descriptors.
  47. .. image:: _static/img/DISK.png
  48. Args:
  49. images: The image to detect features in. Shape :math:`(B, 3, H, W)`.
  50. Returns:
  51. A tuple of dense detection scores and descriptors.
  52. Shapes are :math:`(B, 1, H, W)` and :math:`(B, D, H, W)`, where
  53. :math:`D` is the descriptor dimension.
  54. """
  55. unet_output = self.unet(images)
  56. if unet_output.shape[1] != self.desc_dim + 1:
  57. raise ValueError(
  58. f"U-Net output has {unet_output.shape[1]} channels, but expected self.desc_dim={self.desc_dim} + 1."
  59. )
  60. descriptors = unet_output[:, : self.desc_dim]
  61. heatmaps = unet_output[:, self.desc_dim :]
  62. return heatmaps, descriptors
  63. def forward(
  64. self,
  65. images: Tensor,
  66. n: Optional[int] = None,
  67. window_size: int = 5,
  68. score_threshold: float = 0.0,
  69. pad_if_not_divisible: bool = False,
  70. ) -> list[DISKFeatures]:
  71. """Detect features in an image, returning keypoint locations, descriptors and detection scores.
  72. Args:
  73. images: The image to detect features in. Shape :math:`(B, 3, H, W)`.
  74. n: The maximum number of keypoints to detect. If None, all keypoints are returned.
  75. window_size: The size of the non-maxima suppression window used to filter detections.
  76. score_threshold: The minimum score a detection must have to be returned.
  77. See :py:class:`DISKFeatures` for details.
  78. pad_if_not_divisible: if True, the non-16 divisible input is zero-padded to the closest 16-multiply
  79. Returns:
  80. A list of length :math:`B` containing the detected features.
  81. """
  82. B = images.shape[0]
  83. if pad_if_not_divisible:
  84. h, w = images.shape[2:]
  85. pd_h = 16 - h % 16 if h % 16 > 0 else 0
  86. pd_w = 16 - w % 16 if w % 16 > 0 else 0
  87. images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
  88. heatmaps, descriptors = self.heatmap_and_dense_descriptors(images)
  89. if pad_if_not_divisible:
  90. heatmaps = heatmaps[..., :h, :w]
  91. descriptors = descriptors[..., :h, :w]
  92. keypoints = heatmap_to_keypoints(heatmaps, n=n, window_size=window_size, score_threshold=score_threshold)
  93. features = []
  94. for i in range(B):
  95. features.append(keypoints[i].merge_with_descriptors(descriptors[i]))
  96. return features
  97. @classmethod
  98. def from_pretrained(cls, checkpoint: str = "depth", device: Optional[torch.device] = None) -> DISK:
  99. r"""Load a pretrained model.
  100. Depth model was trained using depth map supervision and is slightly more precise but biased to detect keypoints
  101. only where SfM depth is available. Epipolar model was trained using epipolar geometry supervision and
  102. is less precise but detects keypoints everywhere where they are matchable. The difference is especially
  103. pronounced on thin structures and on edges of objects.
  104. Args:
  105. checkpoint: The checkpoint to load. One of 'depth' or 'epipolar'.
  106. device: The device to load the model to.
  107. Returns:
  108. The pretrained model.
  109. """
  110. urls = {
  111. "depth": "https://raw.githubusercontent.com/cvlab-epfl/disk/master/depth-save.pth",
  112. "epipolar": "https://raw.githubusercontent.com/cvlab-epfl/disk/master/epipolar-save.pth",
  113. }
  114. if checkpoint not in urls:
  115. raise ValueError(f"Unknown pretrained model: {checkpoint}")
  116. if device is None:
  117. device = torch.device("cpu")
  118. pretrained_dict = torch.hub.load_state_dict_from_url(urls[checkpoint], map_location=device)
  119. model: DISK = cls().to(device)
  120. model.load_state_dict(pretrained_dict["extractor"])
  121. model.eval()
  122. return model