matching.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, ClassVar, Dict, List, Optional, Tuple
  18. import torch
  19. from kornia.core import Module, Tensor, concatenate
  20. from kornia.core.check import KORNIA_CHECK_DM_DESC, KORNIA_CHECK_SHAPE
  21. from kornia.feature.laf import get_laf_center
  22. from kornia.feature.steerers import DiscreteSteerer
  23. from kornia.utils.helpers import is_mps_tensor_safe
  24. from .adalam import get_adalam_default_config, match_adalam
  25. def _cdist(d1: Tensor, d2: Tensor) -> Tensor:
  26. r"""Manual `torch.cdist` for M1."""
  27. if (not is_mps_tensor_safe(d1)) and (not is_mps_tensor_safe(d2)):
  28. return torch.cdist(d1, d2)
  29. d1_sq = (d1**2).sum(dim=1, keepdim=True)
  30. d2_sq = (d2**2).sum(dim=1, keepdim=True)
  31. dm = d1_sq.repeat(1, d2.size(0)) + d2_sq.repeat(1, d1.size(0)).t() - 2.0 * d1 @ d2.t()
  32. dm = dm.clamp(min=0.0).sqrt()
  33. return dm
  34. def _get_default_fginn_params() -> Dict[str, Any]:
  35. config = {"th": 0.85, "mutual": False, "spatial_th": 10.0}
  36. return config
  37. def _get_lazy_distance_matrix(desc1: Tensor, desc2: Tensor, dm_: Optional[Tensor] = None) -> Tensor:
  38. """Check validity of provided distance matrix, or calculates L2-distance matrix if dm is not provided.
  39. Args:
  40. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  41. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  42. dm_: Tensor containing the distances from each descriptor in desc1
  43. to each descriptor in desc2, shape of :math:`(B1, B2)`.
  44. """
  45. if dm_ is None:
  46. dm = _cdist(desc1, desc2)
  47. else:
  48. KORNIA_CHECK_DM_DESC(desc1, desc2, dm_)
  49. dm = dm_
  50. return dm
  51. def _no_match(dm: Tensor) -> Tuple[Tensor, Tensor]:
  52. """Output empty tensors.
  53. Returns:
  54. - Descriptor distance of matching descriptors, shape of :math:`(0, 1)`.
  55. - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(0, 2)`.
  56. """
  57. dists = torch.empty(0, 1, device=dm.device, dtype=dm.dtype)
  58. idxs = torch.empty(0, 2, device=dm.device, dtype=torch.long)
  59. return dists, idxs
  60. def match_nn(desc1: Tensor, desc2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  61. r"""Find nearest neighbors in desc2 for each vector in desc1.
  62. If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.
  63. Args:
  64. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  65. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  66. dm: Tensor containing the distances from each descriptor in desc1
  67. to each descriptor in desc2, shape of :math:`(B1, B2)`.
  68. Returns:
  69. - Descriptor distance of matching descriptors, shape of :math:`(B1, 1)`.
  70. - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(B1, 2)`.
  71. """
  72. KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
  73. KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
  74. if (len(desc1) == 0) or (len(desc2) == 0):
  75. return _no_match(desc1)
  76. distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)
  77. match_dists, idxs_in_2 = torch.min(distance_matrix, dim=1)
  78. idxs_in1 = torch.arange(0, idxs_in_2.size(0), device=idxs_in_2.device)
  79. matches_idxs = concatenate([idxs_in1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)
  80. return match_dists.view(-1, 1), matches_idxs.view(-1, 2)
  81. def match_mnn(desc1: Tensor, desc2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  82. """Find mutual nearest neighbors in desc2 for each vector in desc1.
  83. If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.
  84. Args:
  85. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  86. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  87. dm: Tensor containing the distances from each descriptor in desc1
  88. to each descriptor in desc2, shape of :math:`(B1, B2)`.
  89. Return:
  90. - Descriptor distance of matching descriptors, shape of. :math:`(B3, 1)`.
  91. - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(B3, 2)`,
  92. where 0 <= B3 <= min(B1, B2)
  93. """
  94. KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
  95. KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
  96. if (len(desc1) == 0) or (len(desc2) == 0):
  97. return _no_match(desc1)
  98. distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)
  99. ms = min(distance_matrix.size(0), distance_matrix.size(1))
  100. match_dists, idxs_in_2 = torch.min(distance_matrix, dim=1)
  101. match_dists2, idxs_in_1 = torch.min(distance_matrix, dim=0)
  102. minsize_idxs = torch.arange(ms, device=distance_matrix.device)
  103. if distance_matrix.size(0) <= distance_matrix.size(1):
  104. mutual_nns = minsize_idxs == idxs_in_1[idxs_in_2][:ms]
  105. matches_idxs = concatenate([minsize_idxs.view(-1, 1), idxs_in_2.view(-1, 1)], 1)[mutual_nns]
  106. match_dists = match_dists[mutual_nns]
  107. else:
  108. mutual_nns = minsize_idxs == idxs_in_2[idxs_in_1][:ms]
  109. matches_idxs = concatenate([idxs_in_1.view(-1, 1), minsize_idxs.view(-1, 1)], 1)[mutual_nns]
  110. match_dists = match_dists2[mutual_nns]
  111. return match_dists.view(-1, 1), matches_idxs.view(-1, 2)
  112. def match_snn(desc1: Tensor, desc2: Tensor, th: float = 0.8, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  113. """Find nearest neighbors in desc2 for each vector in desc1.
  114. The method satisfies first to second nearest neighbor distance <= th.
  115. If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.
  116. Args:
  117. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  118. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  119. th: distance ratio threshold.
  120. dm: Tensor containing the distances from each descriptor in desc1
  121. to each descriptor in desc2, shape of :math:`(B1, B2)`.
  122. Return:
  123. - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
  124. - Long tensor indexes of matching descriptors in desc1 and desc2. Shape: :math:`(B3, 2)`,
  125. where 0 <= B3 <= B1.
  126. """
  127. KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
  128. KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
  129. if desc2.shape[0] < 2: # We cannot perform snn check, so output empty matches
  130. return _no_match(desc1)
  131. distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)
  132. vals, idxs_in_2 = torch.topk(distance_matrix, 2, dim=1, largest=False)
  133. ratio = vals[:, 0] / vals[:, 1]
  134. mask = ratio <= th
  135. match_dists = ratio[mask]
  136. if len(match_dists) == 0:
  137. return _no_match(distance_matrix)
  138. idxs_in1 = torch.arange(0, idxs_in_2.size(0), device=distance_matrix.device)[mask]
  139. idxs_in_2 = idxs_in_2[:, 0][mask]
  140. matches_idxs = concatenate([idxs_in1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)
  141. return match_dists.view(-1, 1), matches_idxs.view(-1, 2)
  142. def match_smnn(desc1: Tensor, desc2: Tensor, th: float = 0.95, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  143. """Find mutual nearest neighbors in desc2 for each vector in desc1.
  144. the method satisfies first to second nearest neighbor distance <= th.
  145. If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.
  146. Args:
  147. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  148. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  149. th: distance ratio threshold.
  150. dm: Tensor containing the distances from each descriptor in desc1
  151. to each descriptor in desc2, shape of :math:`(B1, B2)`.
  152. Return:
  153. - Descriptor distance of matching descriptors, shape of. :math:`(B3, 1)`.
  154. - Long tensor indexes of matching descriptors in desc1 and desc2,
  155. shape of :math:`(B3, 2)` where 0 <= B3 <= B1.
  156. """
  157. KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
  158. KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
  159. if (desc1.shape[0] < 2) or (desc2.shape[0] < 2):
  160. return _no_match(desc1)
  161. distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)
  162. dists1, idx1 = match_snn(desc1, desc2, th, distance_matrix)
  163. dists2, idx2 = match_snn(desc2, desc1, th, distance_matrix.t())
  164. if len(dists2) > 0 and len(dists1) > 0:
  165. idx2 = idx2.flip(1)
  166. if not is_mps_tensor_safe(idx1):
  167. idxs_dm = torch.cdist(idx1.float(), idx2.float(), p=1.0)
  168. else:
  169. idxs1_rep = idx1.to(desc1).repeat_interleave(idx2.size(0), dim=0)
  170. idxs_dm = (idx2.to(desc2).repeat(idx1.size(0), 1) - idxs1_rep).abs().sum(dim=1)
  171. idxs_dm = idxs_dm.reshape(idx1.size(0), idx2.size(0))
  172. mutual_idxs1 = idxs_dm.min(dim=1)[0] < 1e-8
  173. mutual_idxs2 = idxs_dm.min(dim=0)[0] < 1e-8
  174. good_idxs1 = idx1[mutual_idxs1.view(-1)]
  175. good_idxs2 = idx2[mutual_idxs2.view(-1)]
  176. dists1_good = dists1[mutual_idxs1.view(-1)]
  177. dists2_good = dists2[mutual_idxs2.view(-1)]
  178. _, idx_upl1 = torch.sort(good_idxs1[:, 0])
  179. _, idx_upl2 = torch.sort(good_idxs2[:, 0])
  180. good_idxs1 = good_idxs1[idx_upl1]
  181. match_dists = torch.max(dists1_good[idx_upl1], dists2_good[idx_upl2])
  182. matches_idxs = good_idxs1
  183. match_dists, matches_idxs = match_dists.view(-1, 1), matches_idxs.view(-1, 2)
  184. else:
  185. match_dists, matches_idxs = _no_match(distance_matrix)
  186. return match_dists, matches_idxs
  187. def match_fginn(
  188. desc1: Tensor,
  189. desc2: Tensor,
  190. lafs1: Tensor,
  191. lafs2: Tensor,
  192. th: float = 0.8,
  193. spatial_th: float = 10.0,
  194. mutual: bool = False,
  195. dm: Optional[Tensor] = None,
  196. ) -> Tuple[Tensor, Tensor]:
  197. """Find nearest neighbors in desc2 for each vector in desc1.
  198. The method satisfies first to second nearest neighbor distance <= th,
  199. and assures 2nd nearest neighbor is geometrically inconsistent with the 1st one
  200. (see :cite:`MODS2015` for more details)
  201. If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.
  202. Args:
  203. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  204. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  205. lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`.
  206. lafs2: LAFs of a shape :math:`(1, B2, 2, 3)`.
  207. th: distance ratio threshold.
  208. spatial_th: minimal distance in pixels to 2nd nearest neighbor.
  209. mutual: also perform mutual nearest neighbor check
  210. dm: Tensor containing the distances from each descriptor in desc1
  211. to each descriptor in desc2, shape of :math:`(B1, B2)`.
  212. Return:
  213. - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
  214. - Long tensor indexes of matching descriptors in desc1 and desc2. Shape: :math:`(B3, 2)`,
  215. where 0 <= B3 <= B1.
  216. """
  217. KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
  218. KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
  219. BIG_NUMBER = 1000000.0
  220. distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)
  221. dtype = distance_matrix.dtype
  222. if desc2.shape[0] < 2: # We cannot perform snn check, so output empty matches
  223. return _no_match(distance_matrix)
  224. num_candidates = max(2, min(10, desc2.shape[0]))
  225. vals_cand, idxs_in_2 = torch.topk(distance_matrix, num_candidates, dim=1, largest=False)
  226. vals = vals_cand[:, 0]
  227. xy2 = get_laf_center(lafs2).view(-1, 2)
  228. candidates_xy = xy2[idxs_in_2]
  229. kdist = torch.norm(candidates_xy - candidates_xy[0:1], p=2, dim=2)
  230. fginn_vals = vals_cand[:, 1:] + (kdist[:, 1:] < spatial_th).to(dtype) * BIG_NUMBER
  231. fginn_vals_best, _fginn_idxs_best = fginn_vals.min(dim=1)
  232. # orig_idxs = idxs_in_2.gather(1, fginn_idxs_best.unsqueeze(1))[0]
  233. # if you need to know fginn indexes - uncomment
  234. vals_2nd = fginn_vals_best
  235. idxs_in_2 = idxs_in_2[:, 0]
  236. ratio = vals / vals_2nd
  237. mask = ratio <= th
  238. match_dists = ratio[mask]
  239. if len(match_dists) == 0:
  240. return _no_match(distance_matrix)
  241. idxs_in1 = torch.arange(0, idxs_in_2.size(0), device=distance_matrix.device)[mask]
  242. idxs_in_2 = idxs_in_2[mask]
  243. matches_idxs = concatenate([idxs_in1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)
  244. match_dists, matches_idxs = match_dists.view(-1, 1), matches_idxs.view(-1, 2)
  245. if not mutual: # returning 1-way matches
  246. return match_dists, matches_idxs
  247. _, idxs_in_1_mut = torch.min(distance_matrix, dim=0)
  248. good_mask = matches_idxs[:, 0] == idxs_in_1_mut[matches_idxs[:, 1]]
  249. return match_dists[good_mask], matches_idxs[good_mask]
  250. class DescriptorMatcher(Module):
  251. """Module version of matching functions.
  252. See :func:`~kornia.feature.match_nn`, :func:`~kornia.feature.match_snn`,
  253. :func:`~kornia.feature.match_mnn` or :func:`~kornia.feature.match_smnn` for more details.
  254. Args:
  255. match_mode: type of matching, can be `nn`, `snn`, `mnn`, `smnn`.
  256. th: threshold on distance ratio, or other quality measure.
  257. """
  258. def __init__(self, match_mode: str = "snn", th: float = 0.8) -> None:
  259. super().__init__()
  260. _match_mode: str = match_mode.lower()
  261. self.known_modes = ["nn", "mnn", "snn", "smnn"]
  262. if _match_mode not in self.known_modes:
  263. raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
  264. self.match_mode = _match_mode
  265. self.th = th
  266. def forward(self, desc1: Tensor, desc2: Tensor) -> Tuple[Tensor, Tensor]:
  267. """Run forward.
  268. Args:
  269. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  270. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  271. Returns:
  272. - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
  273. - Long tensor indexes of matching descriptors in desc1 and desc2,
  274. shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
  275. """
  276. if self.match_mode == "nn":
  277. out = match_nn(desc1, desc2)
  278. elif self.match_mode == "mnn":
  279. out = match_mnn(desc1, desc2)
  280. elif self.match_mode == "snn":
  281. out = match_snn(desc1, desc2, self.th)
  282. elif self.match_mode == "smnn":
  283. out = match_smnn(desc1, desc2, self.th)
  284. else:
  285. raise NotImplementedError
  286. return out
  287. class DescriptorMatcherWithSteerer(Module):
  288. """Matching that is invariant under rotations, using Steerers.
  289. Args:
  290. steerer: An instance of :func:`kornia.feature.steerers.DiscreteSteerer`.
  291. steerer_order: order of discretisation of rotation angles, e.g. 4 leads to quarter rotations.
  292. steer_mode: can be `global`, `local`.
  293. `global` means that the we output matches from the global rotation with most matches.
  294. `local` means that we output matches from a distance matrix
  295. where the distance between each descriptor pair is the minimal over rotations.
  296. match_mode: type of matching, can be `nn`, `snn`, `mnn`, `smnn`.
  297. WARNING: using steer_mode `global` with match_mode `nn` will lead to bad results
  298. since `nn` doesn't generate different amount of matches depending on goodness of fit.
  299. th: threshold on distance ratio, or other quality measure.
  300. Example:
  301. >>> import kornia as K
  302. >>> import kornia.feature as KF
  303. >>> device = K.utils.get_cuda_or_mps_device_if_available()
  304. >>> img1 = torch.randn([1, 3, 768, 768], device=device)
  305. >>> img2 = torch.randn([1, 3, 768, 768], device=device)
  306. >>> dedode = KF.DeDoDe.from_pretrained(detector_weights="L-C4-v2", descriptor_weights="B-SO2").to(device)
  307. >>> steerer_order = 8 # discretisation order of rotation angles
  308. >>> steerer = KF.steerers.DiscreteSteerer.create_dedode_default(
  309. ... generator_type="SO2", steerer_order=steerer_order
  310. ... )
  311. >>> steerer = steerer.to(device)
  312. >>> matcher = KF.matching.DescriptorMatcherWithSteerer(
  313. ... steerer=steerer, steerer_order=steerer_order, steer_mode="global", match_mode="smnn", th=0.98
  314. ... )
  315. >>> with torch.inference_mode():
  316. ... kps1, scores1, descs1 = dedode(img1, n=20_000)
  317. ... kps2, scores2, descs2 = dedode(img2, n=20_000)
  318. ... kps1, kps2, descs1, descs2 = kps1[0], kps2[0], descs1[0], descs2[0]
  319. ... dists, idxs, num_rot = matcher(
  320. ... descs1, descs2, normalize=True, subset_size=1000,
  321. ... )
  322. >>> # print(f"{idxs.shape[0]} tentative matches with steered DeDoDe")
  323. >>> # print(f"at rotation of {num_rot * 360 / steerer_order} degrees")
  324. """
  325. def __init__(
  326. self,
  327. steerer: DiscreteSteerer,
  328. steerer_order: int,
  329. steer_mode: str = "global",
  330. match_mode: str = "snn",
  331. th: float = 0.8,
  332. ) -> None:
  333. super().__init__()
  334. self.steerer = steerer
  335. self.steerer_order = steerer_order
  336. _steer_mode: str = steer_mode.lower()
  337. self.known_steer_modes = ["global", "local"]
  338. if _steer_mode not in self.known_steer_modes:
  339. raise NotImplementedError(f"{steer_mode} is not supported. Try one of {self.known_steer_modes}")
  340. self.steer_mode = _steer_mode
  341. _match_mode: str = match_mode.lower()
  342. self.known_modes = ["nn", "mnn", "snn", "smnn"]
  343. if _match_mode not in self.known_modes:
  344. raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
  345. self.match_mode = _match_mode
  346. self.th = th
  347. def matching_function(self, d1: Tensor, d2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  348. if self.match_mode == "nn":
  349. return match_nn(d1, d2, dm=dm)
  350. elif self.match_mode == "mnn":
  351. return match_mnn(d1, d2, dm=dm)
  352. elif self.match_mode == "snn":
  353. return match_snn(d1, d2, self.th, dm=dm)
  354. elif self.match_mode == "smnn":
  355. return match_smnn(d1, d2, self.th, dm=dm)
  356. else:
  357. raise NotImplementedError
  358. def forward(
  359. self,
  360. desc1: Tensor,
  361. desc2: Tensor,
  362. normalize: bool = False,
  363. subset_size: Optional[int] = None,
  364. ) -> Tuple[Tensor, Tensor, Optional[int]]:
  365. """Run forward.
  366. Args:
  367. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  368. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  369. normalize: bool to decide whether to normalize descriptors to unit norm.
  370. subset_size: If set, the subset size to use for determining optimal
  371. number of rotations. Smaller subset size leads to faster but less
  372. accurate matching. Only used when `self.steer_mode` is `"global"`.
  373. Returns:
  374. - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
  375. - Long tensor indexes of matching descriptors in desc1 and desc2,
  376. shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
  377. - Number of global rotations from desc1 to desc2, in terms of `self.steerer_order`
  378. (will be `None` if `self.steer_mode` is `local`).
  379. """
  380. rot1to2 = None
  381. if normalize:
  382. desc1 = torch.nn.functional.normalize(desc1, dim=-1)
  383. desc2 = torch.nn.functional.normalize(desc2, dim=-1)
  384. if self.steer_mode == "global":
  385. if subset_size is not None:
  386. subsample1 = torch.randperm(desc1.shape[0])[:subset_size]
  387. subsample2 = torch.randperm(desc2.shape[0])[:subset_size]
  388. _, _, rot1to2 = self(
  389. desc1[subsample1],
  390. desc2[subsample2],
  391. normalize=normalize,
  392. )
  393. desc1 = self.steerer.steer_descriptions(
  394. desc1,
  395. steerer_power=rot1to2,
  396. normalize=normalize,
  397. )
  398. dist, idx = self.matching_function(desc1, desc2, None)
  399. return dist, idx, rot1to2
  400. dist, idx = self.matching_function(desc1, desc2, None)
  401. rot1to2 = 0
  402. for r in range(1, self.steerer_order):
  403. desc1 = self.steerer.steer_descriptions(desc1, normalize=normalize)
  404. dist_new, idx_new = self.matching_function(desc1, desc2, None)
  405. if idx_new.shape[0] > idx.shape[0]:
  406. dist, idx, rot1to2 = dist_new, idx_new, r
  407. elif self.steer_mode == "local":
  408. dm = _cdist(desc1, desc2)
  409. for _ in range(1, self.steerer_order):
  410. desc1 = self.steerer.steer_descriptions(desc1, normalize=normalize)
  411. dm_new = _cdist(desc1, desc2)
  412. dm = torch.minimum(dm, dm_new)
  413. dist, idx = self.matching_function(desc1, desc2, dm)
  414. else:
  415. raise NotImplementedError
  416. return dist, idx, rot1to2
  417. class GeometryAwareDescriptorMatcher(Module):
  418. """Module version of matching functions.
  419. See :func:`~kornia.feature.match_nn`, :func:`~kornia.feature.match_snn`,
  420. :func:`~kornia.feature.match_mnn` or :func:`~kornia.feature.match_smnn` for more details.
  421. Args:
  422. match_mode: type of matching, can be `fginn`.
  423. th: threshold on distance ratio, or other quality measure.
  424. """
  425. known_modes: ClassVar[List[str]] = ["fginn", "adalam"]
  426. def __init__(self, match_mode: str = "fginn", params: Optional[Dict[str, Tensor]] = None) -> None:
  427. super().__init__()
  428. _match_mode: str = match_mode.lower()
  429. if _match_mode not in self.known_modes:
  430. raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
  431. self.match_mode = _match_mode
  432. self.params = params or {}
  433. def forward(self, desc1: Tensor, desc2: Tensor, lafs1: Tensor, lafs2: Tensor) -> Tuple[Tensor, Tensor]:
  434. """Run forward.
  435. Args:
  436. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  437. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  438. lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`.
  439. lafs2: LAFs of a shape :math:`(1, B2, 2, 3)`.
  440. Returns:
  441. - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
  442. - Long tensor indexes of matching descriptors in desc1 and desc2,
  443. shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
  444. """
  445. if self.match_mode == "fginn":
  446. params = _get_default_fginn_params()
  447. params.update(self.params)
  448. out = match_fginn(desc1, desc2, lafs1, lafs2, params["th"], params["spatial_th"], params["mutual"])
  449. elif self.match_mode == "adalam":
  450. _params = get_adalam_default_config()
  451. _params.update(self.params) # type: ignore[typeddict-item]
  452. out = match_adalam(desc1, desc2, lafs1, lafs2, config=_params)
  453. else:
  454. raise NotImplementedError
  455. return out