- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Dict, Optional, Tuple
- import torch
- from kornia.core import Module, Tensor
- from kornia.feature import DescriptorMatcher, GFTTAffNetHardNet, LocalFeatureMatcher, LoFTR
- from kornia.feature.integrated import LocalFeature
- from kornia.geometry.linalg import transform_points
- from kornia.geometry.ransac import RANSAC
- from kornia.geometry.transform import warp_perspective
- class HomographyTracker(Module):
-     r"""Perform local-feature-based tracking of a planar target object across a sequence of frames.
-     Args:
-         initial_matcher: image matching module, e.g. :class:`~kornia.feature.LocalFeatureMatcher`
-             or :class:`~kornia.feature.LoFTR`. Default: :class:`~kornia.feature.LocalFeatureMatcher`
-             built from :class:`~kornia.feature.GFTTAffNetHardNet` and
-             :class:`~kornia.feature.DescriptorMatcher`.
-         fast_matcher: fast image matching module, e.g. :class:`~kornia.feature.LocalFeatureMatcher`
-             or :class:`~kornia.feature.LoFTR`. Default: :class:`~kornia.feature.LoFTR` (``"outdoor"``).
-         ransac: homography estimation module. Default: :class:`~kornia.geometry.RANSAC`.
-         minimum_inliers_num: minimum number of RANSAC inliers required for a match to be
-             considered successful.
-     """
-     def __init__(
-         self,
-         initial_matcher: Optional[LocalFeature] = None,
-         fast_matcher: Optional[Module] = None,
-         ransac: Optional[Module] = None,
-         minimum_inliers_num: int = 30,
-     ) -> None:
-         super().__init__()
-         self.initial_matcher = initial_matcher or (
-             LocalFeatureMatcher(GFTTAffNetHardNet(3000), DescriptorMatcher("smnn", 0.95))
-         )
-         self.fast_matcher = fast_matcher or LoFTR("outdoor")
-         self.ransac = ransac or RANSAC("homography", inl_th=5.0, batch_size=4096, max_iter=10, max_lo_iters=10)
-         self.minimum_inliers_num = minimum_inliers_num
-         # placeholders
-         self.target: Tensor
-         self.target_initial_representation: Dict[str, Tensor] = {}
-         self.target_fast_representation: Dict[str, Tensor] = {}
-         self.previous_homography: Optional[Tensor] = None
-         self.inliers_num: int = 0
-         self.keypoints0_num: int = 0
-         self.keypoints1_num: int = 0
-         self.reset_tracking()
-     @property
-     def device(self) -> torch.device:
-         return self.target.device
-     @property
-     def dtype(self) -> torch.dtype:
-         return self.target.dtype
-     @torch.no_grad()
-     def set_target(self, target: Tensor) -> None:
-         self.target = target
-         self.target_initial_representation = {}
-         self.target_fast_representation = {}
-         if hasattr(self.initial_matcher, "extract_features") and isinstance(
-             self.initial_matcher.extract_features, Module
-         ):
-             self.target_initial_representation = self.initial_matcher.extract_features(target)
-         if hasattr(self.fast_matcher, "extract_features") and isinstance(self.fast_matcher.extract_features, Module):
-             self.target_fast_representation = self.fast_matcher.extract_features(target)
-     def reset_tracking(self) -> None:
-         self.previous_homography = None
-     def no_match(self) -> Tuple[Tensor, bool]:
-         self.inliers_num = 0
-         self.keypoints0_num = 0
-         self.keypoints1_num = 0
-         return torch.empty(3, 3, device=self.device, dtype=self.dtype), False
-     def match_initial(self, x: Tensor) -> Tuple[Tensor, bool]:
-         """Match frame `x` against the target with `initial_matcher` and verify the matches with RANSAC."""
-         input_dict: Dict[str, Tensor] = {"image0": self.target, "image1": x}
-         for k, v in self.target_initial_representation.items():
-             input_dict[f"{k}0"] = v
-         match_dict: Dict[str, Tensor] = self.initial_matcher(input_dict)
-         keypoints0 = match_dict["keypoints0"][match_dict["batch_indexes"] == 0]
-         keypoints1 = match_dict["keypoints1"][match_dict["batch_indexes"] == 0]
-         self.keypoints0_num = len(keypoints0)
-         self.keypoints1_num = len(keypoints1)
-         if self.keypoints0_num < self.minimum_inliers_num:
-             return self.no_match()
-         H, inliers = self.ransac(keypoints0, keypoints1)
-         self.inliers_num = inliers.sum().item()
-         if self.inliers_num < self.minimum_inliers_num:
-             return self.no_match()
-         self.previous_homography = H.clone()
-         return H, True
-     def track_next_frame(self, x: Tensor) -> Tuple[Tensor, bool]:
-         """Prewarp frame `x` with the homography estimated for the previous frame.
-         The warped frame is then matched with `fast_matcher` and the matches are verified with RANSAC.
-         """
-         if self.previous_homography is not None:  # narrows Optional[Tensor] for mypy; forward() checks this first
-             Hwarp = self.previous_homography.clone()[None]
-         # enlarge and shift the warp so the warped frame keeps a safety border around the target
-         Hwarp[:, 0:2, 0:2] = Hwarp[:, 0:2, 0:2] / 0.8
-         Hwarp[:, 0:2, 2] -= 10.0
-         Hinv = torch.inverse(Hwarp)
-         h, w = self.target.shape[2:]
-         frame_warped = warp_perspective(x, Hinv, (h, w))
-         input_dict: Dict[str, Tensor] = {"image0": self.target, "image1": frame_warped}
-         for k, v in self.target_fast_representation.items():
-             input_dict[f"{k}0"] = v
-         match_dict = self.fast_matcher(input_dict)
-         keypoints0 = match_dict["keypoints0"][match_dict["batch_indexes"] == 0]
-         keypoints1 = match_dict["keypoints1"][match_dict["batch_indexes"] == 0]
-         # map the matches from the warped frame back to the coordinates of the original frame
-         keypoints1 = transform_points(Hwarp, keypoints1)
-         self.keypoints0_num = len(keypoints0)
-         self.keypoints1_num = len(keypoints1)
-         if self.keypoints0_num < self.minimum_inliers_num:
-             self.reset_tracking()
-             return self.no_match()
-         H, inliers = self.ransac(keypoints0, keypoints1)
-         self.inliers_num = inliers.sum().item()
-         if self.inliers_num < self.minimum_inliers_num:
-             self.reset_tracking()
-             return self.no_match()
-         self.previous_homography = H.clone()
-         return H, True
-     def forward(self, x: Tensor) -> Tuple[Tensor, bool]:
-         if self.previous_homography is not None:
-             return self.track_next_frame(x)
-         return self.match_initial(x)
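
The dispatch in `forward` can be sketched in plain Python. This is a toy model and not the kornia API: `TrackerSketch`, its `initial` and `fast` callables, and the string stand-ins for homographies are all hypothetical, and the inlier counts are supplied directly instead of coming from RANSAC. It only illustrates how a failed fast match sends the next frame back to the initial matcher.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical stand-in for a matcher result: (homography placeholder, inlier count).
Estimate = Tuple[str, int]


class TrackerSketch:
    """Toy model of the initial-match / fast-track dispatch (assumption, not kornia code)."""

    def __init__(
        self,
        initial: Callable[[str], Estimate],
        fast: Callable[[str, str], Estimate],
        minimum_inliers_num: int = 30,
    ) -> None:
        self.initial = initial
        self.fast = fast
        self.minimum_inliers_num = minimum_inliers_num
        self.previous_homography: Optional[str] = None

    def forward(self, frame: str) -> Tuple[str, bool]:
        # Use the fast path only while a homography from the previous frame exists.
        if self.previous_homography is not None:
            mode = "fast"
            H, inliers = self.fast(frame, self.previous_homography)
        else:
            mode = "initial"
            H, inliers = self.initial(frame)
        if inliers < self.minimum_inliers_num:
            # Too few inliers: drop the stored homography so the next frame
            # goes through the initial matcher again.
            self.previous_homography = None
            return mode, False
        self.previous_homography = H
        return mode, True


# Inlier counts for four consecutive frames; the third one is below the threshold.
inlier_counts = iter([50, 40, 5, 60])
tracker = TrackerSketch(
    initial=lambda frame: (f"H({frame})", next(inlier_counts)),
    fast=lambda frame, prev: (f"H({frame})", next(inlier_counts)),
)
history: List[Tuple[str, bool]] = [tracker.forward(f"frame{i}") for i in range(4)]
print(history)
```

In this run the third frame fails on the fast path, so the fourth frame is handled by the initial matcher again.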
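
The border adjustment in `track_next_frame` (dividing the upper-left 2x2 block by 0.8 and subtracting 10 from the translation) can be checked by hand on a small example. The sketch below uses plain Python lists instead of tensors; `add_safety_margin` and `apply_homography` are hypothetical helper names, and the identity matrix and the 100x100 target size are assumed values chosen only to keep the arithmetic easy to follow.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def add_safety_margin(H: Matrix, scale: float = 0.8, shift: float = 10.0) -> Matrix:
    """Copy H, divide its upper-left 2x2 block by `scale`, and subtract `shift` from the translation."""
    Hw = [row[:] for row in H]
    for i in range(2):
        for j in range(2):
            Hw[i][j] = Hw[i][j] / scale
        Hw[i][2] -= shift
    return Hw


def apply_homography(H: Matrix, point: Tuple[float, float]) -> Tuple[float, float]:
    """Apply a 3x3 homography to a 2D point using homogeneous coordinates."""
    x, y = point
    u = H[0][0] * x + H[0][1] * y + H[0][2]
    v = H[1][0] * x + H[1][1] * y + H[1][2]
    w = H[2][0] * x + H[2][1] * y + H[2][2]
    return u / w, v / w


# Assumed example: the previous homography is the identity and the target is 100x100.
identity: Matrix = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
Hwarp = add_safety_margin(identity)
corners = [(0.0, 0.0), (100.0, 0.0), (100.0, 100.0), (0.0, 100.0)]
warped_corners = [apply_homography(Hwarp, c) for c in corners]
print(warped_corners)
```

With these assumed values the target corners map to a region from -10 to 115 along each axis, so the area sampled from the frame is larger than the target footprint on every side.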