integrated.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import warnings
  18. from typing import ClassVar, Dict, List, Optional, Tuple
  19. import torch
  20. from kornia.color import rgb_to_grayscale
  21. from kornia.constants import pi
  22. from kornia.core import Device, Module, Tensor, concatenate, deg2rad
  23. from kornia.core.check import KORNIA_CHECK_LAF
  24. from kornia.geometry.subpix import ConvQuadInterp3d
  25. from kornia.geometry.transform import ScalePyramid
  26. from .affine_shape import LAFAffNetShapeEstimator
  27. from .hardnet import HardNet
  28. from .keynet import KeyNetDetector
  29. from .laf import extract_patches_from_pyramid, get_laf_center, get_laf_orientation, get_laf_scale, scale_laf
  30. from .lightglue import LightGlue
  31. from .matching import GeometryAwareDescriptorMatcher, _no_match
  32. from .orientation import LAFOrienter, OriNet, PassLAF
  33. from .responses import BlobDoG, BlobDoGSingle, BlobHessian, CornerGFTT
  34. from .scale_space_detector import (
  35. Detector_config,
  36. MultiResolutionDetector,
  37. ScaleSpaceDetector,
  38. get_default_detector_config,
  39. )
  40. from .siftdesc import SIFTDescriptor
  41. def get_laf_descriptors(
  42. img: Tensor, lafs: Tensor, patch_descriptor: Module, patch_size: int = 32, grayscale_descriptor: bool = True
  43. ) -> Tensor:
  44. r"""Get local descriptors, corresponding to LAFs (keypoints).
  45. Args:
  46. img: image features with shape :math:`(B,C,H,W)`.
  47. lafs: local affine frames :math:`(B,N,2,3)`.
  48. patch_descriptor: patch descriptor module, e.g. :class:`~kornia.feature.SIFTDescriptor`
  49. or :class:`~kornia.feature.HardNet`.
  50. patch_size: patch size in pixels, which descriptor expects.
  51. grayscale_descriptor: True if ``patch_descriptor`` expects single-channel image.
  52. Returns:
  53. Local descriptors of shape :math:`(B,N,D)` where :math:`D` is descriptor size.
  54. """
  55. KORNIA_CHECK_LAF(lafs)
  56. patch_descriptor = patch_descriptor.to(img)
  57. patch_descriptor.eval()
  58. timg: Tensor = img
  59. if lafs.shape[1] == 0:
  60. warnings.warn(f"LAF contains no keypoints {lafs.shape}, returning empty tensor", stacklevel=1)
  61. return torch.empty(lafs.shape[0], lafs.shape[1], 128, dtype=lafs.dtype, device=lafs.device)
  62. if grayscale_descriptor and img.size(1) == 3:
  63. timg = rgb_to_grayscale(img)
  64. patches: Tensor = extract_patches_from_pyramid(timg, lafs, patch_size)
  65. # Descriptor accepts standard tensor [B, CH, H, W], while patches are [B, N, CH, H, W] shape
  66. # So we need to reshape a bit :)
  67. B, N, CH, H, W = patches.size()
  68. return patch_descriptor(patches.view(B * N, CH, H, W)).view(B, N, -1)
  69. class LAFDescriptor(Module):
  70. r"""Module to get local descriptors, corresponding to LAFs (keypoints).
  71. Internally uses :func:`~kornia.feature.get_laf_descriptors`.
  72. Args:
  73. patch_descriptor_module: patch descriptor module, e.g. :class:`~kornia.feature.SIFTDescriptor`
  74. or :class:`~kornia.feature.HardNet`. Default: :class:`~kornia.feature.HardNet`.
  75. patch_size: patch size in pixels, which descriptor expects.
  76. grayscale_descriptor: ``True`` if patch_descriptor expects single-channel image.
  77. """
  78. def __init__(
  79. self, patch_descriptor_module: Optional[Module] = None, patch_size: int = 32, grayscale_descriptor: bool = True
  80. ) -> None:
  81. super().__init__()
  82. if patch_descriptor_module is None:
  83. patch_descriptor_module = HardNet(True)
  84. self.descriptor = patch_descriptor_module
  85. self.patch_size = patch_size
  86. self.grayscale_descriptor = grayscale_descriptor
  87. def __repr__(self) -> str:
  88. return (
  89. f"{self.__class__.__name__}"
  90. f"(descriptor={self.descriptor.__repr__()}, "
  91. f"patch_size={self.patch_size}, "
  92. f"grayscale_descriptor='{self.grayscale_descriptor})"
  93. )
  94. def forward(self, img: Tensor, lafs: Tensor) -> Tensor:
  95. r"""Three stage local feature detection.
  96. First the location and scale of interest points are determined by
  97. detect function. Then affine shape and orientation.
  98. Args:
  99. img: image features with shape :math:`(B,C,H,W)`.
  100. lafs: local affine frames :math:`(B,N,2,3)`.
  101. Returns:
  102. Local descriptors of shape :math:`(B,N,D)` where :math:`D` is descriptor size.
  103. """
  104. return get_laf_descriptors(img, lafs, self.descriptor, self.patch_size, self.grayscale_descriptor)
  105. class LocalFeature(Module):
  106. """Module, which combines local feature detector and descriptor.
  107. Args:
  108. detector: the detection module.
  109. descriptor: the descriptor module.
  110. scaling_coef: multiplier for change default detector scale (e.g. it is too small for KeyNet by default)
  111. """
  112. def __init__(self, detector: Module, descriptor: LAFDescriptor, scaling_coef: float = 1.0) -> None:
  113. super().__init__()
  114. self.detector = detector
  115. self.descriptor = descriptor
  116. if scaling_coef <= 0:
  117. raise ValueError(f"Scaling coef should be >= 0, got {scaling_coef}")
  118. self.scaling_coef = scaling_coef
  119. def forward(self, img: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
  120. """Run forward.
  121. Args:
  122. img: image to extract features with shape :math:`(B,C,H,W)`.
  123. mask: a mask with weights where to apply the response function.
  124. The shape must be the same as the input image.
  125. Returns:
  126. - Detected local affine frames with shape :math:`(B,N,2,3)`.
  127. - Response function values for corresponding lafs with shape :math:`(B,N,1)`.
  128. - Local descriptors of shape :math:`(B,N,D)` where :math:`D` is descriptor size.
  129. """
  130. lafs, responses = self.detector(img, mask)
  131. lafs = scale_laf(lafs, self.scaling_coef)
  132. descs = self.descriptor(img, lafs)
  133. return (lafs, responses, descs)
  134. class SIFTFeature(LocalFeature):
  135. """Convenience module, which implements DoG detector + (Root)SIFT descriptor.
  136. Using `kornia.feature.MultiResolutionDetector` without blur pyramid Still not as good as OpenCV/VLFeat because of
  137. https://github.com/kornia/kornia/pull/884,
  138. but we are working on it
  139. """
  140. def __init__(
  141. self,
  142. num_features: int = 8000,
  143. upright: bool = False,
  144. rootsift: bool = True,
  145. device: Optional[Device] = None,
  146. config: Optional[Detector_config] = None,
  147. ) -> None:
  148. patch_size: int = 41
  149. if device is None:
  150. device = torch.device("cpu")
  151. if config is None:
  152. config = get_default_detector_config()
  153. detector = MultiResolutionDetector(
  154. BlobDoGSingle(1.0, 1.6),
  155. num_features,
  156. config,
  157. ori_module=PassLAF() if upright else LAFOrienter(19),
  158. aff_module=PassLAF(),
  159. ).to(device)
  160. descriptor = LAFDescriptor(
  161. SIFTDescriptor(patch_size=patch_size, rootsift=rootsift), patch_size=patch_size, grayscale_descriptor=True
  162. ).to(device)
  163. super().__init__(detector, descriptor)
  164. class SIFTFeatureScaleSpace(LocalFeature):
  165. """Convenience module, which implements DoG detector + (Root)SIFT descriptor.
  166. Using `kornia.feature.ScaleSpaceDetector` with blur pyramid.
  167. Still not as good as OpenCV/VLFeat because of https://github.com/kornia/kornia/pull/884, but we are working on it
  168. """
  169. def __init__(
  170. self,
  171. num_features: int = 8000,
  172. upright: bool = False,
  173. rootsift: bool = True,
  174. device: Optional[Device] = None,
  175. ) -> None:
  176. if device is None:
  177. device = torch.device("cpu")
  178. patch_size: int = 41
  179. detector = ScaleSpaceDetector(
  180. num_features,
  181. resp_module=BlobDoG(),
  182. nms_module=ConvQuadInterp3d(10),
  183. scale_pyr_module=ScalePyramid(3, 1.6, 32, double_image=True),
  184. ori_module=PassLAF() if upright else LAFOrienter(19),
  185. scale_space_response=True,
  186. minima_are_also_good=True,
  187. mr_size=6.0,
  188. ).to(device)
  189. descriptor = LAFDescriptor(
  190. SIFTDescriptor(patch_size=patch_size, rootsift=rootsift), patch_size=patch_size, grayscale_descriptor=True
  191. ).to(device)
  192. super().__init__(detector, descriptor)
  193. class GFTTAffNetHardNet(LocalFeature):
  194. """Convenience module, which implements GFTT detector + AffNet-HardNet descriptor."""
  195. def __init__(
  196. self,
  197. num_features: int = 8000,
  198. upright: bool = False,
  199. device: Optional[Device] = None,
  200. config: Optional[Detector_config] = None,
  201. ) -> None:
  202. if device is None:
  203. device = torch.device("cpu")
  204. if config is None:
  205. config = get_default_detector_config()
  206. detector = MultiResolutionDetector(
  207. CornerGFTT(),
  208. num_features,
  209. config,
  210. ori_module=PassLAF() if upright else LAFOrienter(19),
  211. aff_module=LAFAffNetShapeEstimator(True).eval(),
  212. ).to(device)
  213. descriptor = LAFDescriptor(None, patch_size=32, grayscale_descriptor=True).to(device)
  214. super().__init__(detector, descriptor)
  215. class HesAffNetHardNet(LocalFeature):
  216. """Convenience module, which implements GFTT detector + AffNet-HardNet descriptor."""
  217. def __init__(
  218. self,
  219. num_features: int = 2048,
  220. upright: bool = False,
  221. device: Optional[Device] = None,
  222. config: Optional[Detector_config] = None,
  223. ) -> None:
  224. if device is None:
  225. device = torch.device("cpu")
  226. if config is None:
  227. config = get_default_detector_config()
  228. detector = MultiResolutionDetector(
  229. BlobHessian(),
  230. num_features,
  231. config,
  232. ori_module=PassLAF() if upright else LAFOrienter(19),
  233. aff_module=LAFAffNetShapeEstimator(True).eval(),
  234. ).to(device)
  235. descriptor = LAFDescriptor(None, patch_size=32, grayscale_descriptor=True).to(device)
  236. super().__init__(detector, descriptor)
  237. class KeyNetHardNet(LocalFeature):
  238. """Convenience module, which implements KeyNet detector + HardNet descriptor."""
  239. def __init__(
  240. self,
  241. num_features: int = 8000,
  242. upright: bool = False,
  243. device: Optional[Device] = None,
  244. scale_laf: float = 1.0,
  245. ) -> None:
  246. if device is None:
  247. device = torch.device("cpu")
  248. ori_module = PassLAF() if upright else LAFOrienter(angle_detector=OriNet(True))
  249. detector = KeyNetDetector(True, num_features=num_features, ori_module=ori_module).to(device)
  250. descriptor = LAFDescriptor(None, patch_size=32, grayscale_descriptor=True).to(device)
  251. super().__init__(detector, descriptor, scale_laf)
  252. class KeyNetAffNetHardNet(LocalFeature):
  253. """Convenience module, which implements KeyNet detector + AffNet + HardNet descriptor.
  254. .. image:: _static/img/keynet_affnet.jpg
  255. """
  256. def __init__(
  257. self,
  258. num_features: int = 8000,
  259. upright: bool = False,
  260. device: Optional[Device] = None,
  261. scale_laf: float = 1.0,
  262. ) -> None:
  263. if device is None:
  264. device = torch.device("cpu")
  265. ori_module = PassLAF() if upright else LAFOrienter(angle_detector=OriNet(True))
  266. detector = KeyNetDetector(
  267. True, num_features=num_features, ori_module=ori_module, aff_module=LAFAffNetShapeEstimator(True).eval()
  268. ).to(device)
  269. descriptor = LAFDescriptor(None, patch_size=32, grayscale_descriptor=True).to(device)
  270. super().__init__(detector, descriptor, scale_laf)
  271. class LocalFeatureMatcher(Module):
  272. r"""Module, which finds correspondences between two images based on local features.
  273. Args:
  274. local_feature: Local feature detector. See :class:`~kornia.feature.GFTTAffNetHardNet`.
  275. matcher: Descriptor matcher, see :class:`~kornia.feature.DescriptorMatcher`.
  276. Returns:
  277. Dict[str, Tensor]: Dictionary with image correspondences and confidence scores.
  278. Example:
  279. >>> img1 = torch.rand(1, 1, 320, 200)
  280. >>> img2 = torch.rand(1, 1, 128, 128)
  281. >>> input = {"image0": img1, "image1": img2}
  282. >>> gftt_hardnet_matcher = LocalFeatureMatcher(
  283. ... GFTTAffNetHardNet(10), kornia.feature.DescriptorMatcher('snn', 0.8)
  284. ... )
  285. >>> out = gftt_hardnet_matcher(input)
  286. """
  287. def __init__(self, local_feature: Module, matcher: Module) -> None:
  288. super().__init__()
  289. self.local_feature = local_feature
  290. self.matcher = matcher
  291. self.eval()
  292. def extract_features(self, image: Tensor, mask: Optional[Tensor] = None) -> Dict[str, Tensor]:
  293. """Extract features from simple image."""
  294. lafs0, resps0, descs0 = self.local_feature(image, mask)
  295. return {"lafs": lafs0, "responses": resps0, "descriptors": descs0}
  296. def no_match_output(self, device: Device, dtype: torch.dtype) -> Dict[str, Tensor]:
  297. return {
  298. "keypoints0": torch.empty(0, 2, device=device, dtype=dtype),
  299. "keypoints1": torch.empty(0, 2, device=device, dtype=dtype),
  300. "lafs0": torch.empty(0, 0, 2, 3, device=device, dtype=dtype),
  301. "lafs1": torch.empty(0, 0, 2, 3, device=device, dtype=dtype),
  302. "confidence": torch.empty(0, device=device, dtype=dtype),
  303. "batch_indexes": torch.empty(0, device=device, dtype=torch.long),
  304. }
  305. def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
  306. """Run forward.
  307. Args:
  308. data: dictionary containing the input data in the following format:
  309. Keyword Args:
  310. image0: left image with shape :math:`(N, 1, H1, W1)`.
  311. image1: right image with shape :math:`(N, 1, H2, W2)`.
  312. mask0 (optional): left image mask. '0' indicates a padded position :math:`(N, H1, W1)`.
  313. mask1 (optional): right image mask. '0' indicates a padded position :math:`(N, H2, W2)`.
  314. Returns:
  315. - ``keypoints0``, matching keypoints from image0 :math:`(NC, 2)`.
  316. - ``keypoints1``, matching keypoints from image1 :math:`(NC, 2)`.
  317. - ``confidence``, confidence score [0, 1] :math:`(NC)`.
  318. - ``lafs0``, matching LAFs from image0 :math:`(1, NC, 2, 3)`.
  319. - ``lafs1``, matching LAFs from image1 :math:`(1, NC, 2, 3)`.
  320. - ``batch_indexes``, batch indexes for the keypoints and lafs :math:`(NC)`.
  321. """
  322. num_image_pairs: int = data["image0"].shape[0]
  323. if ("lafs0" not in data.keys()) or ("descriptors0" not in data.keys()):
  324. # One can supply pre-extracted local features
  325. feats_dict0: Dict[str, Tensor] = self.extract_features(data["image0"])
  326. lafs0, descs0 = feats_dict0["lafs"], feats_dict0["descriptors"]
  327. else:
  328. lafs0, descs0 = data["lafs0"], data["descriptors0"]
  329. if ("lafs1" not in data.keys()) or ("descriptors1" not in data.keys()):
  330. feats_dict1: Dict[str, Tensor] = self.extract_features(data["image1"])
  331. lafs1, descs1 = feats_dict1["lafs"], feats_dict1["descriptors"]
  332. else:
  333. lafs1, descs1 = data["lafs1"], data["descriptors1"]
  334. keypoints0: Tensor = get_laf_center(lafs0)
  335. keypoints1: Tensor = get_laf_center(lafs1)
  336. out_keypoints0: List[Tensor] = []
  337. out_keypoints1: List[Tensor] = []
  338. out_confidence: List[Tensor] = []
  339. out_batch_indexes: List[Tensor] = []
  340. out_lafs0: List[Tensor] = []
  341. out_lafs1: List[Tensor] = []
  342. for batch_idx in range(num_image_pairs):
  343. dists, idxs = self.matcher(descs0[batch_idx], descs1[batch_idx])
  344. if len(idxs) == 0:
  345. continue
  346. current_keypoints_0 = keypoints0[batch_idx, idxs[:, 0]]
  347. current_keypoints_1 = keypoints1[batch_idx, idxs[:, 1]]
  348. current_lafs_0 = lafs0[batch_idx, idxs[:, 0]]
  349. current_lafs_1 = lafs1[batch_idx, idxs[:, 1]]
  350. out_confidence.append(1.0 - dists)
  351. batch_idxs = batch_idx * torch.ones(len(dists), device=keypoints0.device, dtype=torch.long)
  352. out_keypoints0.append(current_keypoints_0)
  353. out_keypoints1.append(current_keypoints_1)
  354. out_lafs0.append(current_lafs_0)
  355. out_lafs1.append(current_lafs_1)
  356. out_batch_indexes.append(batch_idxs)
  357. if len(out_batch_indexes) == 0:
  358. return self.no_match_output(data["image0"].device, data["image0"].dtype)
  359. return {
  360. "keypoints0": concatenate(out_keypoints0, dim=0).view(-1, 2),
  361. "keypoints1": concatenate(out_keypoints1, dim=0).view(-1, 2),
  362. "lafs0": concatenate(out_lafs0, dim=0).view(1, -1, 2, 3),
  363. "lafs1": concatenate(out_lafs1, dim=0).view(1, -1, 2, 3),
  364. "confidence": concatenate(out_confidence, dim=0).view(-1),
  365. "batch_indexes": concatenate(out_batch_indexes, dim=0).view(-1),
  366. }
  367. class LightGlueMatcher(GeometryAwareDescriptorMatcher):
  368. """LightGlue-based matcher in kornia API.
  369. This is based on the original code from paper "LightGlue: Local Feature Matching at Light Speed".
  370. See :cite:`LightGlue2023` for more details.
  371. Args:
  372. feature_name: type of feature for matching, can be `disk` or `superpoint`.
  373. params: LightGlue params.
  374. """
  375. known_modes: ClassVar[List[str]] = [
  376. "aliked",
  377. "dedodeb",
  378. "dedodeg",
  379. "disk",
  380. "dog_affnet_hardnet",
  381. "doghardnet",
  382. "keynet_affnet_hardnet",
  383. "sift",
  384. "superpoint",
  385. ]
  386. def __init__(self, feature_name: str = "disk", params: Optional[Dict] = None) -> None: # type: ignore
  387. feature_name_: str = feature_name.lower()
  388. super().__init__(feature_name_)
  389. self.feature_name = feature_name_
  390. if params is None:
  391. params = {}
  392. self.params = params
  393. self.matcher = LightGlue(self.feature_name, **params)
  394. def forward(
  395. self,
  396. desc1: Tensor,
  397. desc2: Tensor,
  398. lafs1: Tensor,
  399. lafs2: Tensor,
  400. hw1: Optional[Tuple[int, int]] = None,
  401. hw2: Optional[Tuple[int, int]] = None,
  402. ) -> Tuple[Tensor, Tensor]:
  403. """Run forward.
  404. Args:
  405. desc1: Batch of descriptors of a shape :math:`(B1, D)`.
  406. desc2: Batch of descriptors of a shape :math:`(B2, D)`.
  407. lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`.
  408. lafs2: LAFs of a shape :math:`(1, B2, 2, 3)`.
  409. hw1: Height/width of image.
  410. hw2: Height/width of image.
  411. Return:
  412. - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
  413. - Long tensor indexes of matching descriptors in desc1 and desc2,
  414. shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
  415. """
  416. if (desc1.shape[0] < 2) or (desc2.shape[0] < 2):
  417. return _no_match(desc1)
  418. keypoints1 = get_laf_center(lafs1)
  419. keypoints2 = get_laf_center(lafs2)
  420. if len(desc1.shape) == 2:
  421. desc1 = desc1.unsqueeze(0)
  422. if len(desc2.shape) == 2:
  423. desc2 = desc2.unsqueeze(0)
  424. dev = lafs1.device
  425. if hw1 is None:
  426. hw1_ = keypoints1.max(dim=1)[0].squeeze().flip(0)
  427. else:
  428. hw1_ = torch.tensor(hw1, device=dev)
  429. if hw2 is None:
  430. hw2_ = keypoints2.max(dim=1)[0].squeeze().flip(0)
  431. else:
  432. hw2_ = torch.tensor(hw2, device=dev)
  433. ori0 = deg2rad(get_laf_orientation(lafs1).reshape(1, -1))
  434. ori0[ori0 < 0] += 2.0 * pi
  435. ori1 = deg2rad(get_laf_orientation(lafs2).reshape(1, -1))
  436. ori1[ori1 < 0] += 2.0 * pi
  437. input_dict = {
  438. "image0": {
  439. "keypoints": keypoints1,
  440. "scales": get_laf_scale(lafs1).reshape(1, -1),
  441. "oris": ori0,
  442. "lafs": lafs1,
  443. "descriptors": desc1,
  444. "image_size": hw1_.flip(0).reshape(-1, 2).to(dev),
  445. },
  446. "image1": {
  447. "keypoints": keypoints2,
  448. "lafs": lafs2,
  449. "scales": get_laf_scale(lafs2).reshape(1, -1),
  450. "oris": ori1,
  451. "descriptors": desc2,
  452. "image_size": hw2_.flip(0).reshape(-1, 2).to(dev),
  453. },
  454. }
  455. pred = self.matcher(input_dict)
  456. matches0, mscores0 = pred["matches0"], pred["matching_scores0"]
  457. valid = matches0 > -1
  458. matches = torch.stack([torch.where(valid)[1], matches0[valid]], -1)
  459. return mscores0[valid].reshape(-1, 1), matches