planar_tracker.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Dict, Optional, Tuple
  18. import torch
  19. from kornia.core import Module, Tensor
  20. from kornia.feature import DescriptorMatcher, GFTTAffNetHardNet, LocalFeatureMatcher, LoFTR
  21. from kornia.feature.integrated import LocalFeature
  22. from kornia.geometry.linalg import transform_points
  23. from kornia.geometry.ransac import RANSAC
  24. from kornia.geometry.transform import warp_perspective
  25. class HomographyTracker(Module):
  26. r"""Perform local-feature-based tracking of the target planar object in the sequence of the frames.
  27. Args:
  28. initial_matcher: image matching module, e.g. :class:`~kornia.feature.LocalFeatureMatcher`
  29. or :class:`~kornia.feature.LoFTR`. Default: :class:`~kornia.feature.GFTTAffNetHardNet`.
  30. fast_matcher: fast image matching module, e.g. :class:`~kornia.feature.LocalFeatureMatcher`
  31. or :class:`~kornia.feature.LoFTR`. Default: :class:`~kornia.feature.DescriptorMatcher`.
  32. ransac: homography estimation module. Default: :class:`~kornia.geometry.RANSAC`.
  33. minimum_inliers_num: threshold for number inliers for matching to be successful.
  34. """
  35. def __init__(
  36. self,
  37. initial_matcher: Optional[LocalFeature] = None,
  38. fast_matcher: Optional[Module] = None,
  39. ransac: Optional[Module] = None,
  40. minimum_inliers_num: int = 30,
  41. ) -> None:
  42. super().__init__()
  43. self.initial_matcher = initial_matcher or (
  44. LocalFeatureMatcher(GFTTAffNetHardNet(3000), DescriptorMatcher("smnn", 0.95))
  45. )
  46. self.fast_matcher = fast_matcher or LoFTR("outdoor")
  47. self.ransac = ransac or RANSAC("homography", inl_th=5.0, batch_size=4096, max_iter=10, max_lo_iters=10)
  48. self.minimum_inliers_num = minimum_inliers_num
  49. # placeholders
  50. self.target: Tensor
  51. self.target_initial_representation: Dict[str, Tensor] = {}
  52. self.target_fast_representation: Dict[str, Tensor] = {}
  53. self.previous_homography: Optional[Tensor] = None
  54. self.inliers_num: int = 0
  55. self.keypoints0_num: int = 0
  56. self.keypoints1_num: int = 0
  57. self.reset_tracking()
  58. @property
  59. def device(self) -> torch.device:
  60. return self.target.device
  61. @property
  62. def dtype(self) -> torch.dtype:
  63. return self.target.dtype
  64. @torch.no_grad()
  65. def set_target(self, target: Tensor) -> None:
  66. self.target = target
  67. self.target_initial_representation = {}
  68. self.target_fast_representation = {}
  69. if hasattr(self.initial_matcher, "extract_features") and isinstance(
  70. self.initial_matcher.extract_features, Module
  71. ):
  72. self.target_initial_representation = self.initial_matcher.extract_features(target)
  73. if hasattr(self.fast_matcher, "extract_features") and isinstance(self.fast_matcher.extract_features, Module):
  74. self.target_fast_representation = self.fast_matcher.extract_features(target)
  75. def reset_tracking(self) -> None:
  76. self.previous_homography = None
  77. def no_match(self) -> Tuple[Tensor, bool]:
  78. self.inliers_num = 0
  79. self.keypoints0_num = 0
  80. self.keypoints1_num = 0
  81. return torch.empty(3, 3, device=self.device, dtype=self.dtype), False
  82. def match_initial(self, x: Tensor) -> Tuple[Tensor, bool]:
  83. """Match the frame `x` with initial_matcher and verified with ransac."""
  84. input_dict: Dict[str, Tensor] = {"image0": self.target, "image1": x}
  85. for k, v in self.target_initial_representation.items():
  86. input_dict[f"{k}0"] = v
  87. match_dict: Dict[str, Tensor] = self.initial_matcher(input_dict)
  88. keypoints0 = match_dict["keypoints0"][match_dict["batch_indexes"] == 0]
  89. keypoints1 = match_dict["keypoints1"][match_dict["batch_indexes"] == 0]
  90. self.keypoints0_num = len(keypoints0)
  91. self.keypoints1_num = len(keypoints1)
  92. if self.keypoints0_num < self.minimum_inliers_num:
  93. return self.no_match()
  94. H, inliers = self.ransac(keypoints0, keypoints1)
  95. self.inliers_num = inliers.sum().item()
  96. if self.inliers_num < self.minimum_inliers_num:
  97. return self.no_match()
  98. self.previous_homography = H.clone()
  99. return H, True
  100. def track_next_frame(self, x: Tensor) -> Tuple[Tensor, bool]:
  101. """Prewarp the frame `x` according to the previous frame homography.
  102. Matched with fast_matcher verified with ransac.
  103. """
  104. if self.previous_homography is not None: # mypy, shut up
  105. Hwarp = self.previous_homography.clone()[None]
  106. # make a bit of border for safety
  107. Hwarp[:, 0:2, 0:2] = Hwarp[:, 0:2, 0:2] / 0.8
  108. Hwarp[:, 0:2, 2] -= 10.0
  109. Hinv = torch.inverse(Hwarp)
  110. h, w = self.target.shape[2:]
  111. frame_warped = warp_perspective(x, Hinv, (h, w))
  112. input_dict: Dict[str, Tensor] = {"image0": self.target, "image1": frame_warped}
  113. for k, v in self.target_fast_representation.items():
  114. input_dict[f"{k}0"] = v
  115. match_dict = self.fast_matcher(input_dict)
  116. keypoints0 = match_dict["keypoints0"][match_dict["batch_indexes"] == 0]
  117. keypoints1 = match_dict["keypoints1"][match_dict["batch_indexes"] == 0]
  118. keypoints1 = transform_points(Hwarp, keypoints1)
  119. self.keypoints0_num = len(keypoints0)
  120. self.keypoints1_num = len(keypoints1)
  121. if self.keypoints0_num < self.minimum_inliers_num:
  122. self.reset_tracking()
  123. return self.no_match()
  124. H, inliers = self.ransac(keypoints0, keypoints1)
  125. self.inliers_num = inliers.sum().item()
  126. if self.inliers_num < self.minimum_inliers_num:
  127. self.reset_tracking()
  128. return self.no_match()
  129. self.previous_homography = H.clone()
  130. return H, True
  131. def forward(self, x: Tensor) -> Tuple[Tensor, bool]:
  132. if self.previous_homography is not None:
  133. return self.track_next_frame(x)
  134. return self.match_initial(x)