# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module containing RANSAC modules."""

import math
from functools import partial
from typing import Callable, Optional, Tuple

import torch

from kornia.core import Device, Module, Tensor, zeros
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.geometry import (
    find_fundamental,
    find_homography_dlt,
    find_homography_dlt_iterated,
    find_homography_lines_dlt,
    find_homography_lines_dlt_iterated,
    symmetrical_epipolar_distance,
)
from kornia.geometry.homography import (
    line_segment_transfer_error_one_way,
    oneway_transfer_error,
    sample_is_valid_for_homography,
)

__all__ = ["RANSAC"]


class RANSAC(Module):
    """Module for robust geometry estimation with RANSAC.

    https://en.wikipedia.org/wiki/Random_sample_consensus.

    Args:
        model_type: type of model to estimate: "homography", "fundamental", "fundamental_7pt",
            "homography_from_linesegments".
        inl_th: threshold for the correspondence to be an inlier.
        batch_size: number of generated samples at once.
        max_iter: maximum batches to generate. Actual number of models to try is ``batch_size * max_iter``.
        confidence: desired confidence of the result, used for the early stopping.
        max_lo_iters: number of local optimization (polishing) iterations.

    """

    def __init__(
        self,
        model_type: str = "homography",
        inl_th: float = 2.0,
        batch_size: int = 2048,
        max_iter: int = 10,
        confidence: float = 0.99,
        max_lo_iters: int = 5,
    ) -> None:
        """Initialize the RANSAC estimator.

        Args:
            model_type: type of model to estimate: "homography", "fundamental", "fundamental_7pt",
                "homography_from_linesegments".
            inl_th: threshold for the correspondence to be an inlier.
            batch_size: number of generated samples at once.
            max_iter: maximum batches to generate. Actual number of models to try is ``batch_size * max_iter``.
            confidence: desired confidence of the result, used for the early stopping.
            max_lo_iters: number of local optimization (polishing) iterations.

        Raises:
            NotImplementedError: if ``model_type`` is not one of the supported models.

        """
        super().__init__()
        self.supported_models = ["homography", "fundamental", "fundamental_7pt", "homography_from_linesegments"]
        self.inl_th = inl_th
        self.max_iter = max_iter
        self.batch_size = batch_size
        self.model_type = model_type
        self.confidence = confidence
        self.max_lo_iters = max_lo_iters

        # Per-model configuration: residual function, minimal solver, polishing solver
        # and the number of correspondences in a minimal sample.
        self.error_fn: Callable[..., Tensor]
        self.minimal_solver: Callable[..., Tensor]
        self.polisher_solver: Callable[..., Tensor]
        if model_type == "homography":
            self.error_fn = oneway_transfer_error
            self.minimal_solver = find_homography_dlt
            self.polisher_solver = find_homography_dlt_iterated
            self.minimal_sample_size = 4
        elif model_type == "homography_from_linesegments":
            self.error_fn = line_segment_transfer_error_one_way
            self.minimal_solver = find_homography_lines_dlt
            self.polisher_solver = find_homography_lines_dlt_iterated
            self.minimal_sample_size = 4
        elif model_type == "fundamental":
            self.error_fn = symmetrical_epipolar_distance
            self.minimal_solver = find_fundamental
            self.minimal_sample_size = 8
            self.polisher_solver = find_fundamental
        elif model_type == "fundamental_7pt":
            self.error_fn = symmetrical_epipolar_distance
            self.minimal_solver = partial(find_fundamental, method="7POINT")
            self.minimal_sample_size = 7
            self.polisher_solver = find_fundamental
        else:
            raise NotImplementedError(f"{model_type} is unknown. Try one of {self.supported_models}")

    def sample(self, sample_size: int, pop_size: int, batch_size: int, device: Optional[Device] = None) -> Tensor:
        """Minimal sampler, but unlike traditional RANSAC we sample in batches.

        Yields the benefit of the parallel processing, esp. on GPU.

        Args:
            sample_size: number of samples to draw from the population.
            pop_size: size of the population to sample from.
            batch_size: number of sample sets to generate.
            device: device to place the samples on. Defaults to CPU when ``None``.

        Returns:
            Tensor of sampled indices with shape :math:`(batch_size, sample_size)`.

        """
        if device is None:
            device = torch.device("cpu")
        # Taking the top-k positions of a random matrix yields, for every row,
        # ``sample_size`` distinct indices in ``[0, pop_size)``.
        rand = torch.rand(batch_size, pop_size, device=device)
        _, out = rand.topk(k=sample_size, dim=1)
        return out

    @staticmethod
    def max_samples_by_conf(n_inl: int, num_tc: int, sample_size: int, conf: float) -> float:
        """Compute the number of samples needed to reach the desired confidence.

        Used to update ``max_iter`` and stop the iterations earlier, see
        https://en.wikipedia.org/wiki/Random_sample_consensus.

        Args:
            n_inl: number of inliers.
            num_tc: total number of correspondences.
            sample_size: size of minimal sample.
            conf: desired confidence level.

        Returns:
            Maximum number of samples needed to achieve the desired confidence.

        """
        eps = 1e-9
        if num_tc <= sample_size:
            return 1.0
        if n_inl == num_tc:
            return 1.0
        # ``eps`` clamps keep both logarithm argument and denominator away from zero.
        return math.log(1.0 - conf) / min(-eps, math.log(max(eps, 1.0 - math.pow(n_inl / num_tc, sample_size))))

    def estimate_model_from_minsample(self, kp1: Tensor, kp2: Tensor) -> Tensor:
        """Estimate models from minimal samples.

        Args:
            kp1: source keypoints with shape :math:`(batch_size, sample_size, 2)`.
            kp2: target keypoints with shape :math:`(batch_size, sample_size, 2)`.

        Returns:
            Estimated models tensor.

        """
        batch_size, sample_size = kp1.shape[:2]
        # All correspondences in a minimal sample get unit weight.
        H = self.minimal_solver(kp1, kp2, torch.ones(batch_size, sample_size, dtype=kp1.dtype, device=kp1.device))
        return H

    def verify(self, kp1: Tensor, kp2: Tensor, models: Tensor, inl_th: float) -> Tuple[Tensor, Tensor, float]:
        """Verify models by computing inliers and selecting the best model.

        Args:
            kp1: source keypoints.
            kp2: target keypoints.
            models: candidate models to verify.
            inl_th: inlier threshold.

        Returns:
            Tuple containing:
                - Best model
                - Inlier mask for the best model
                - Score of the best model

        """
        # Add a leading batch dimension so the data can be expanded to one copy per model.
        if len(kp1.shape) == 2:
            kp1 = kp1[None]
        if len(kp2.shape) == 2:
            kp2 = kp2[None]
        batch_size = models.shape[0]
        if self.model_type == "homography_from_linesegments":
            errors = self.error_fn(kp1.expand(batch_size, -1, 2, 2), kp2.expand(batch_size, -1, 2, 2), models)
        else:
            errors = self.error_fn(kp1.expand(batch_size, -1, 2), kp2.expand(batch_size, -1, 2), models)
        inl = errors <= inl_th
        # The score of a model is its number of inliers.
        models_score = inl.to(kp1).sum(dim=1)
        best_model_idx = models_score.argmax()
        best_model_score = models_score[best_model_idx].item()
        model_best = models[best_model_idx].clone()
        inliers_best = inl[best_model_idx]
        return model_best, inliers_best, best_model_score

    def remove_bad_samples(self, kp1: Tensor, kp2: Tensor) -> Tuple[Tensor, Tensor]:
        """Remove degenerate samples based on model-specific constraints.

        Only the "homography" model type is filtered at the moment; samples for every
        other model type are returned unchanged.

        Args:
            kp1: source keypoints.
            kp2: target keypoints.

        Returns:
            Tuple of filtered keypoints (kp1, kp2).

        """
        # ToDo: add (model-specific) verification of the samples,
        # E.g. constraints on not to be a degenerate sample
        if self.model_type == "homography":
            mask = sample_is_valid_for_homography(kp1, kp2)
            return kp1[mask], kp2[mask]
        return kp1, kp2

    def remove_bad_models(self, models: Tensor) -> Tensor:
        """Remove degenerate models based on simple heuristics.

        A model is kept only if every entry of its main diagonal is larger than
        ``1e-4`` in absolute value.

        Args:
            models: candidate models to filter.

        Returns:
            Filtered models tensor.

        """
        # ToDo: add more and better degenerate model rejection
        # For now it is simple and hardcoded
        main_diagonal = torch.diagonal(models, dim1=1, dim2=2)
        mask = main_diagonal.abs().min(dim=1)[0] > 1e-4
        return models[mask]

    def polish_model(self, kp1: Tensor, kp2: Tensor, inliers: Tensor) -> Tensor:
        """Polish the model using inliers through local optimization.

        Args:
            kp1: source keypoints.
            kp2: target keypoints.
            inliers: boolean mask indicating inlier correspondences.

        Returns:
            Polished model tensor.

        """
        # TODO: Replace this with MAGSAC++ polisher
        # Keep only the inliers and add a batch dimension of size one for the solver.
        kp1_inl = kp1[inliers][None]
        kp2_inl = kp2[inliers][None]
        num_inl = kp1_inl.size(1)
        model = self.polisher_solver(
            kp1_inl, kp2_inl, torch.ones(1, num_inl, dtype=kp1_inl.dtype, device=kp1_inl.device)
        )
        return model

    def validate_inputs(self, kp1: Tensor, kp2: Tensor, weights: Optional[Tensor] = None) -> None:
        """Validate input tensors for shape and size requirements.

        Point-based models ("homography", "fundamental", "fundamental_7pt") expect
        :math:`(N, 2)` inputs, "homography_from_linesegments" expects :math:`(N, 2, 2)`.
        In both cases ``N`` must match between ``kp1`` and ``kp2`` and be at least the
        minimal sample size of the model.

        Args:
            kp1: source keypoints.
            kp2: target keypoints.
            weights: optional correspondence weights (not used currently).

        Raises:
            ValueError: if the number of correspondences differs or is insufficient.
                ``KORNIA_CHECK_SHAPE`` may additionally raise on a shape mismatch.

        """
        # "fundamental_7pt" takes the same point inputs as "fundamental" and must be validated too.
        if self.model_type in ["homography", "fundamental", "fundamental_7pt"]:
            KORNIA_CHECK_SHAPE(kp1, ["N", "2"])
            KORNIA_CHECK_SHAPE(kp2, ["N", "2"])
            if (kp1.shape[0] != kp2.shape[0]) or (kp1.shape[0] < self.minimal_sample_size):
                raise ValueError(
                    "kp1 and kp2 should be equal shape at least"
                    f" [{self.minimal_sample_size}, 2], got {kp1.shape}, {kp2.shape}"
                )
        if self.model_type == "homography_from_linesegments":
            KORNIA_CHECK_SHAPE(kp1, ["N", "2", "2"])
            KORNIA_CHECK_SHAPE(kp2, ["N", "2", "2"])
            if (kp1.shape[0] != kp2.shape[0]) or (kp1.shape[0] < self.minimal_sample_size):
                raise ValueError(
                    "kp1 and kp2 should be equal shape at least"
                    f" [{self.minimal_sample_size}, 2, 2], got {kp1.shape},"
                    f" {kp2.shape}"
                )

    def forward(self, kp1: Tensor, kp2: Tensor, weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        r"""Call main forward method to execute the RANSAC algorithm.

        Args:
            kp1: source image keypoints :math:`(N, 2)`
                (:math:`(N, 2, 2)` for "homography_from_linesegments").
            kp2: destination image keypoints, same shape as ``kp1``.
            weights: optional correspondences weights. Not used now.

        Returns:
            - Estimated model, shape of :math:`(3, 3)`. All zeros if no model scored
              better than the minimal sample size.
            - The boolean inlier/outlier mask, shape of :math:`(N,)`, where N is number of
              input correspondences. If no model was accepted, an all-``False`` mask of
              shape :math:`(N, 1)` is returned instead.

        """
        self.validate_inputs(kp1, kp2, weights)
        # A model must explain more correspondences than the minimal sample to be accepted.
        best_score_total: float = float(self.minimal_sample_size)
        num_tc: int = len(kp1)
        best_model_total = zeros(3, 3, dtype=kp1.dtype, device=kp1.device)
        inliers_best_total: Tensor = zeros(num_tc, 1, device=kp1.device, dtype=torch.bool)
        for i in range(self.max_iter):
            # Sample minimal samples in batch to estimate models
            idxs = self.sample(self.minimal_sample_size, num_tc, self.batch_size, kp1.device)
            kp1_sampled = kp1[idxs]
            kp2_sampled = kp2[idxs]

            kp1_sampled, kp2_sampled = self.remove_bad_samples(kp1_sampled, kp2_sampled)
            if len(kp1_sampled) == 0:
                continue
            # Estimate models
            models = self.estimate_model_from_minsample(kp1_sampled, kp2_sampled)
            models = self.remove_bad_models(models)
            if (models is None) or (len(models) == 0):
                continue
            # Score the models and select the best one
            model, inliers, model_score = self.verify(kp1, kp2, models, self.inl_th)
            # Store far-the-best model and (optionally) do a local optimization
            if model_score > best_score_total:
                # Local optimization: re-fit on the current inliers while the score improves
                for _ in range(self.max_lo_iters):
                    model_lo = self.polish_model(kp1, kp2, inliers)
                    if (model_lo is None) or (len(model_lo) == 0):
                        continue
                    _, inliers_lo, score_lo = self.verify(kp1, kp2, model_lo, self.inl_th)
                    if score_lo > model_score:
                        # The polisher returns a batch of one model; drop the batch dimension.
                        model = model_lo.clone()[0]
                        inliers = inliers_lo.clone()
                        model_score = score_lo
                    else:
                        break
                # Now storing the best model
                best_model_total = model.clone()
                inliers_best_total = inliers.clone()
                best_score_total = model_score

                # Should we already stop?
                new_max_iter = int(
                    self.max_samples_by_conf(int(best_score_total), num_tc, self.minimal_sample_size, self.confidence)
                )
                # Stop estimation, if the model is very good
                if (i + 1) * self.batch_size >= new_max_iter:
                    break
        return best_model_total, inliers_best_total