adalam.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # Integrated from original AdaLAM repo
  18. # https://github.com/cavalli1234/AdaLAM
  19. # Copyright (c) 2020, Luca Cavalli
  20. from typing import Optional, Tuple, Union
  21. import torch
  22. from kornia.core import Tensor, as_tensor
  23. from kornia.core.check import KORNIA_CHECK_LAF, KORNIA_CHECK_SHAPE
  24. from kornia.feature.laf import get_laf_center, get_laf_orientation, get_laf_scale
  25. from .core import AdalamConfig, _no_match, adalam_core
  26. from .utils import dist_matrix
  def get_adalam_default_config() -> AdalamConfig:
      """Build the default AdaLAM configuration.

      A new config object is constructed on every call, so the caller may modify
      the result without affecting later calls. The default computation device is
      the CPU.

      Returns:
          An ``AdalamConfig`` populated with the default AdaLAM hyper-parameters.
      """
      cpu_device = torch.device("cpu")
      # Keyword order is kept stable: it determines the key order that callers may
      # display (e.g. the "Known configurations" warning in match_adalam).
      return AdalamConfig(
          area_ratio=100,
          search_expansion=4,
          ransac_iters=128,
          min_inliers=6,
          min_confidence=200,
          orientation_difference_threshold=30,
          scale_rate_threshold=1.5,
          detected_scale_rate_threshold=5,
          refit=True,
          force_seed_mnn=True,
          device=cpu_device,
      )
  def match_adalam(
      desc1: Tensor,
      desc2: Tensor,
      lafs1: Tensor,
      lafs2: Tensor,
      config: Optional[AdalamConfig] = None,
      hw1: Optional[Tuple[int, int]] = None,
      hw2: Optional[Tuple[int, int]] = None,
      dm: Optional[Tensor] = None,
  ) -> Tuple[Tensor, Tensor]:
      """Perform descriptor matching, followed by AdaLAM filtering.

      See :cite:`AdaLAM2020` for more details.

      Descriptor distances are always recomputed from ``desc1`` and ``desc2`` inside
      :py:meth:`AdalamFilter.match_and_filter`. The ``dm`` argument is accepted for
      API compatibility with other matchers but is currently not used.

      Args:
          desc1: Batch of descriptors of a shape :math:`(B1, D)`.
          desc2: Batch of descriptors of a shape :math:`(B2, D)`.
          lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`.
          lafs2: LAFs of a shape :math:`(1, B2, 2, 3)`.
          config: dict overriding entries of :py:func:`get_adalam_default_config`.
              Unknown keys are reported and ignored. If no ``"device"`` entry is
              given, computation runs on ``desc1.device``.
          hw1: Height/width of the first image.
          hw2: Height/width of the second image.
          dm: Currently ignored (see above). Distances from each descriptor in desc1
              to each descriptor in desc2, shape of :math:`(B1, B2)`.

      Return:
          - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
          - Long tensor indexes of matching descriptors in desc1 and desc2. Shape: :math:`(B3, 2)`,
            where 0 <= B3 <= B1.
      """
      KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
      KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
      KORNIA_CHECK_LAF(lafs1)
      KORNIA_CHECK_LAF(lafs2)
      config_ = get_adalam_default_config()
      # Default to the descriptors' device. An explicit "device" entry in `config`
      # still overrides this below. Previously a custom config without "device"
      # silently fell back to CPU, unlike the config=None path.
      config_["device"] = desc1.device
      if config is not None:
          for key, val in config.items():
              if key not in config_:
                  print(
                      f"WARNING: custom configuration contains a key which is not recognized ({key}). "
                      f"Known configurations are {list(config_.keys())}."
                  )
                  continue
              # TypedDict does not support variable names. https://stackoverflow.com/a/59583427/1983544
              config_[key] = val  # type: ignore
      adalam_object = AdalamFilter(config_)
      idxs, quality = adalam_object.match_and_filter(
          get_laf_center(lafs1).reshape(-1, 2),
          get_laf_center(lafs2).reshape(-1, 2),
          desc1,
          desc2,
          hw1,
          hw2,
          get_laf_orientation(lafs1).reshape(-1),
          get_laf_orientation(lafs2).reshape(-1),
          get_laf_scale(lafs1).reshape(-1),
          get_laf_scale(lafs2).reshape(-1),
          return_dist=True,
      )
      return quality, idxs
  class AdalamFilter:
      """Wrapper around the AdaLAM outlier filter holding its configuration.

      Exposes :py:meth:`filter_matches`, which filters precomputed putative matches,
      and :py:meth:`match_and_filter`, which does nearest-neighbour matching on
      descriptors and then filters.
      """

      def __init__(self, custom_config: Optional[AdalamConfig] = None) -> None:
          """Wrap the method AdaLAM for outlier filtering.

          Args:
              custom_config: dictionary overriding the default configuration returned by
                  :py:func:`get_adalam_default_config`. Missing parameters are kept as default.
                  Entries are copied into a fresh config, so later changes to the passed
                  dictionary do not affect this filter.
          """
          # Merge over the defaults rather than storing the dict as-is. The documented
          # contract is that missing keys fall back to defaults; storing a partial dict
          # directly would raise KeyError later (e.g. on self.config["device"]).
          self.config = get_adalam_default_config()
          if custom_config is not None:
              self.config.update(custom_config)

      def filter_matches(
          self,
          k1: Tensor,
          k2: Tensor,
          putative_matches: Tensor,
          scores: Tensor,
          mnn: Optional[Tensor] = None,
          im1shape: Optional[Tuple[int, int]] = None,
          im2shape: Optional[Tuple[int, int]] = None,
          o1: Optional[Tensor] = None,
          o2: Optional[Tensor] = None,
          s1: Optional[Tensor] = None,
          s2: Optional[Tensor] = None,
          return_dist: bool = False,
      ) -> Union[Tuple[Tensor, Tensor], Tensor]:
          """Call the core functionality of AdaLAM, i.e. just outlier filtering.

          No sanity check is performed on the inputs. The core runs under
          ``torch.no_grad()``.

          Args:
              k1: keypoint locations in the source image, in pixel coordinates.
                  Expected a float32 tensor with shape (num_keypoints_in_source_image, 2).
              k2: keypoint locations in the destination image, in pixel coordinates.
                  Expected a float32 tensor with shape (num_keypoints_in_destination_image, 2).
              putative_matches: Initial set of putative matches to be filtered.
                  The current implementation assumes that these are unfiltered nearest neighbor matches,
                  so it requires this to be a list of indices a_i such that the source keypoint i is
                  associated to the destination keypoint a_i. For now to use AdaLAM on different inputs a
                  workaround on the input format is required.
                  Expected a long tensor with shape (num_keypoints_in_source_image,).
              scores: Confidence scores on the putative_matches. Usually holds Lowe's ratio scores.
              mnn: A mask indicating which putative matches are also mutual nearest neighbors. See
                  'force_seed_mnn' in the config. If None, it disables the mutual nearest neighbor
                  filtering on seed point selection. Expected a bool tensor with shape
                  (num_keypoints_in_source_image,).
              im1shape: Shape of the source image. If None, it is inferred from keypoints max and min,
                  at the cost of wasted runtime, so please provide it. Expected a tuple with
                  (width, height) or (height, width) of the source image.
              im2shape: Shape of the destination image. If None, it is inferred from keypoints max and
                  min, at the cost of wasted runtime, so please provide it. Expected a tuple with
                  (width, height) or (height, width) of the destination image.
              o1: keypoint orientations in degrees. They can be None if 'orientation_difference_threshold'
                  in the config is set to None. Expected a float32 tensor with shape
                  (num_keypoints_in_source_image,).
              o2: same as o1 but for destination.
              s1: keypoint scales. They can be None if 'scale_rate_threshold' in the config is set to None.
                  Expected a float32 tensor with shape (num_keypoints_in_source_image,).
              s2: same as s1 but for destination.
              return_dist: if True, inverse confidence value is also outputted.

          Returns:
              Filtered putative matches: a long tensor with shape (num_filtered_matches, 2) with indices
              of corresponding keypoints in k1 and k2. If ``return_dist`` is True, a tuple of the
              matches and the inverse confidence values.
          """
          with torch.no_grad():
              return adalam_core(
                  k1,
                  k2,
                  fnn12=putative_matches,
                  scores1=scores,
                  mnn=mnn,
                  im1shape=im1shape,
                  im2shape=im2shape,
                  o1=o1,
                  o2=o2,
                  s1=s1,
                  s2=s2,
                  config=self.config,
                  return_dist=return_dist,
              )

      def match_and_filter(
          self,
          k1: Tensor,
          k2: Tensor,
          d1: Tensor,
          d2: Tensor,
          im1shape: Optional[Tuple[int, int]] = None,
          im2shape: Optional[Tuple[int, int]] = None,
          o1: Optional[Tensor] = None,
          o2: Optional[Tensor] = None,
          s1: Optional[Tensor] = None,
          s2: Optional[Tensor] = None,
          return_dist: bool = False,
      ) -> Union[Tuple[Tensor, Tensor], Tensor]:
          """Match and filter with AdaLAM.

          This function:
              - performs some elementary sanity check on the inputs;
              - wraps input arrays into float32 torch tensors on the configured device;
              - extracts nearest neighbors;
              - finds mutual nearest neighbors if required;
              - finally calls AdaLAM filtering.

          Args:
              k1: keypoint locations in the source image, in pixel coordinates.
                  Expected an array with shape (num_keypoints_in_source_image, 2).
              k2: keypoint locations in the destination image, in pixel coordinates.
                  Expected an array with shape (num_keypoints_in_destination_image, 2).
              d1: descriptors in the source image.
                  Expected an array with shape (num_keypoints_in_source_image, descriptor_size).
              d2: descriptors in the destination image.
                  Expected an array with shape (num_keypoints_in_destination_image, descriptor_size).
              im1shape: Shape of the source image. If None, it is inferred from keypoints max and min,
                  at the cost of wasted runtime, so please provide it. Expected a tuple with
                  (width, height) or (height, width) of the source image.
              im2shape: Shape of the destination image. If None, it is inferred from keypoints max and
                  min, at the cost of wasted runtime, so please provide it. Expected a tuple with
                  (width, height) or (height, width) of the destination image.
              o1: keypoint orientations in degrees. They can be None if 'orientation_difference_threshold'
                  in the config is set to None. Expected an array with shape (num_keypoints_in_source_image,).
              o2: Same as o1 for destination.
              s1: keypoint scales. They can be None if 'scale_rate_threshold' in the config is set to None.
                  Expected an array with shape (num_keypoints_in_source_image,).
              s2: Same as s1 for destination.
              return_dist: if True, inverse confidence value is also outputted.

          Returns:
              Filtered putative matches: a long tensor with shape (num_filtered_matches, 2) with indices
              of corresponding keypoints in k1 and k2. If ``return_dist`` is True, a tuple of the
              matches and the inverse confidence values.

          Raises:
              AttributeError: if scales/orientations are missing while the config enables the
                  corresponding filter.
          """
          if s1 is None or s2 is None:
              if self.config["scale_rate_threshold"] is not None:
                  raise AttributeError(
                      "Current configuration considers keypoint scales for filtering, but scales have not been provided.\n"  # noqa: E501
                      "Please either provide scales or set 'scale_rate_threshold' to None to disable scale filtering"
                  )
          if o1 is None or o2 is None:
              if self.config["orientation_difference_threshold"] is not None:
                  raise AttributeError(
                      "Current configuration considers keypoint orientations for filtering, but orientations have not been provided.\n"  # noqa: E501
                      "Please either provide orientations or set 'orientation_difference_threshold' to None to disable orientations filtering"  # noqa: E501
                  )
          device = self.config["device"]
          _k1 = as_tensor(k1, device=device, dtype=torch.float32)
          _k2 = as_tensor(k2, device=device, dtype=torch.float32)
          _d1 = as_tensor(d1, device=device, dtype=torch.float32)
          _d2 = as_tensor(d2, device=device, dtype=torch.float32)
          # Orientations and scales are used only as complete pairs. If either side is missing,
          # the matching filter is disabled (validated above), so pass None for both rather than
          # converting a None (which raises) or forwarding an unconverted array.
          _o1: Optional[Tensor] = None
          _o2: Optional[Tensor] = None
          if o1 is not None and o2 is not None:
              _o1 = as_tensor(o1, device=device, dtype=torch.float32)
              _o2 = as_tensor(o2, device=device, dtype=torch.float32)
          _s1: Optional[Tensor] = None
          _s2: Optional[Tensor] = None
          if s1 is not None and s2 is not None:
              _s1 = as_tensor(s1, device=device, dtype=torch.float32)
              _s2 = as_tensor(s2, device=device, dtype=torch.float32)
          # topk(k=2) below needs at least two destination candidates.
          if (len(_d2) <= 1) or (len(_d1) <= 1):
              idxs, dists = _no_match(_d1)
              if return_dist:
                  return idxs, dists
              return idxs
          distmat = dist_matrix(_d1, _d2, is_normalized=False)
          dd12, nn12 = torch.topk(distmat, k=2, dim=1, largest=False)  # (n1, 2)
          putative_matches = nn12[:, 0]
          # Ratio of nearest to second-nearest distance; the clamp guards against division by zero.
          scores = dd12[:, 0] / dd12[:, 1].clamp_min_(1e-3)
          mnn: Optional[Tensor] = None
          if self.config["force_seed_mnn"]:
              _, nn21 = torch.min(distmat, dim=0)  # (n2,)
              # Mutual NN: source row i's best match j must have i as its own best match.
              # Use the distance-matrix row count so the mask stays aligned with `scores`.
              mnn = nn21[putative_matches] == torch.arange(distmat.shape[0], device=distmat.device)
          return self.filter_matches(
              _k1, _k2, putative_matches, scores, mnn, im1shape, im2shape, _o1, _o2, _s1, _s2, return_dist
          )