ransac.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module containing RANSAC modules."""
  18. import math
  19. from functools import partial
  20. from typing import Callable, Optional, Tuple
  21. import torch
  22. from kornia.core import Device, Module, Tensor, zeros
  23. from kornia.core.check import KORNIA_CHECK_SHAPE
  24. from kornia.geometry import (
  25. find_fundamental,
  26. find_homography_dlt,
  27. find_homography_dlt_iterated,
  28. find_homography_lines_dlt,
  29. find_homography_lines_dlt_iterated,
  30. symmetrical_epipolar_distance,
  31. )
  32. from kornia.geometry.homography import (
  33. line_segment_transfer_error_one_way,
  34. oneway_transfer_error,
  35. sample_is_valid_for_homography,
  36. )
  37. __all__ = ["RANSAC"]
  38. class RANSAC(Module):
  39. """Module for robust geometry estimation with RANSAC. https://en.wikipedia.org/wiki/Random_sample_consensus.
  40. Args:
  41. model_type: type of model to estimate: "homography", "fundamental", "fundamental_7pt",
  42. "homography_from_linesegments".
  43. inliers_threshold: threshold for the correspondence to be an inlier.
  44. batch_size: number of generated samples at once.
  45. max_iterations: maximum batches to generate. Actual number of models to try is ``batch_size * max_iterations``.
  46. confidence: desired confidence of the result, used for the early stopping.
  47. max_local_iterations: number of local optimization (polishing) iterations.
  48. """
  49. def __init__(
  50. self,
  51. model_type: str = "homography",
  52. inl_th: float = 2.0,
  53. batch_size: int = 2048,
  54. max_iter: int = 10,
  55. confidence: float = 0.99,
  56. max_lo_iters: int = 5,
  57. ) -> None:
  58. """Initialize the RANSAC estimator.
  59. Args:
  60. model_type: type of model to estimate: "homography", "fundamental", "fundamental_7pt",
  61. "homography_from_linesegments".
  62. inl_th: threshold for the correspondence to be an inlier.
  63. batch_size: number of generated samples at once.
  64. max_iter: maximum batches to generate. Actual number of models to try is ``batch_size * max_iter``.
  65. confidence: desired confidence of the result, used for the early stopping.
  66. max_lo_iters: number of local optimization (polishing) iterations.
  67. """
  68. super().__init__()
  69. self.supported_models = ["homography", "fundamental", "fundamental_7pt", "homography_from_linesegments"]
  70. self.inl_th = inl_th
  71. self.max_iter = max_iter
  72. self.batch_size = batch_size
  73. self.model_type = model_type
  74. self.confidence = confidence
  75. self.max_lo_iters = max_lo_iters
  76. self.model_type = model_type
  77. self.error_fn: Callable[..., Tensor]
  78. self.minimal_solver: Callable[..., Tensor]
  79. self.polisher_solver: Callable[..., Tensor]
  80. if model_type == "homography":
  81. self.error_fn = oneway_transfer_error
  82. self.minimal_solver = find_homography_dlt
  83. self.polisher_solver = find_homography_dlt_iterated
  84. self.minimal_sample_size = 4
  85. elif model_type == "homography_from_linesegments":
  86. self.error_fn = line_segment_transfer_error_one_way
  87. self.minimal_solver = find_homography_lines_dlt
  88. self.polisher_solver = find_homography_lines_dlt_iterated
  89. self.minimal_sample_size = 4
  90. elif model_type == "fundamental":
  91. self.error_fn = symmetrical_epipolar_distance
  92. self.minimal_solver = find_fundamental
  93. self.minimal_sample_size = 8
  94. self.polisher_solver = find_fundamental
  95. elif model_type == "fundamental_7pt":
  96. self.error_fn = symmetrical_epipolar_distance
  97. self.minimal_solver = partial(find_fundamental, method="7POINT")
  98. self.minimal_sample_size = 7
  99. self.polisher_solver = find_fundamental
  100. else:
  101. raise NotImplementedError(f"{model_type} is unknown. Try one of {self.supported_models}")
  102. def sample(self, sample_size: int, pop_size: int, batch_size: int, device: Optional[Device] = None) -> Tensor:
  103. """Minimal sampler, but unlike traditional RANSAC we sample in batches.
  104. Yields the benefit of the parallel processing, esp. on GPU.
  105. Args:
  106. sample_size: number of samples to draw from the population.
  107. pop_size: size of the population to sample from.
  108. batch_size: number of sample sets to generate.
  109. device: device to place the samples on.
  110. Returns:
  111. Tensor of sampled indices with shape :math:`(batch_size, sample_size)`.
  112. """
  113. if device is None:
  114. device = torch.device("cpu")
  115. rand = torch.rand(batch_size, pop_size, device=device)
  116. _, out = rand.topk(k=sample_size, dim=1)
  117. return out
  118. @staticmethod
  119. def max_samples_by_conf(n_inl: int, num_tc: int, sample_size: int, conf: float) -> float:
  120. """Update max_iter to stop iterations earlier https://en.wikipedia.org/wiki/Random_sample_consensus.
  121. Args:
  122. n_inl: number of inliers.
  123. num_tc: total number of correspondences.
  124. sample_size: size of minimal sample.
  125. conf: desired confidence level.
  126. Returns:
  127. Maximum number of samples needed to achieve the desired confidence.
  128. """
  129. eps = 1e-9
  130. if num_tc <= sample_size:
  131. return 1.0
  132. if n_inl == num_tc:
  133. return 1.0
  134. return math.log(1.0 - conf) / min(-eps, math.log(max(eps, 1.0 - math.pow(n_inl / num_tc, sample_size))))
  135. def estimate_model_from_minsample(self, kp1: Tensor, kp2: Tensor) -> Tensor:
  136. """Estimate models from minimal samples.
  137. Args:
  138. kp1: source keypoints with shape :math:`(batch_size, sample_size, 2)`.
  139. kp2: target keypoints with shape :math:`(batch_size, sample_size, 2)`.
  140. Returns:
  141. Estimated models tensor.
  142. """
  143. batch_size, sample_size = kp1.shape[:2]
  144. H = self.minimal_solver(kp1, kp2, torch.ones(batch_size, sample_size, dtype=kp1.dtype, device=kp1.device))
  145. return H
  146. def verify(self, kp1: Tensor, kp2: Tensor, models: Tensor, inl_th: float) -> Tuple[Tensor, Tensor, float]:
  147. """Verify models by computing inliers and selecting the best model.
  148. Args:
  149. kp1: source keypoints.
  150. kp2: target keypoints.
  151. models: candidate models to verify.
  152. inl_th: inlier threshold.
  153. Returns:
  154. Tuple containing:
  155. - Best model
  156. - Inlier mask for the best model
  157. - Score of the best model
  158. """
  159. if len(kp1.shape) == 2:
  160. kp1 = kp1[None]
  161. if len(kp2.shape) == 2:
  162. kp2 = kp2[None]
  163. batch_size = models.shape[0]
  164. if self.model_type == "homography_from_linesegments":
  165. errors = self.error_fn(kp1.expand(batch_size, -1, 2, 2), kp2.expand(batch_size, -1, 2, 2), models)
  166. else:
  167. errors = self.error_fn(kp1.expand(batch_size, -1, 2), kp2.expand(batch_size, -1, 2), models)
  168. inl = errors <= inl_th
  169. models_score = inl.to(kp1).sum(dim=1)
  170. best_model_idx = models_score.argmax()
  171. best_model_score = models_score[best_model_idx].item()
  172. model_best = models[best_model_idx].clone()
  173. inliers_best = inl[best_model_idx]
  174. return model_best, inliers_best, best_model_score
  175. def remove_bad_samples(self, kp1: Tensor, kp2: Tensor) -> Tuple[Tensor, Tensor]:
  176. """Remove degenerate samples based on model-specific constraints.
  177. Args:
  178. kp1: source keypoints.
  179. kp2: target keypoints.
  180. Returns:
  181. Tuple of filtered keypoints (kp1, kp2).
  182. """
  183. # ToDo: add (model-specific) verification of the samples,
  184. # E.g. constraints on not to be a degenerate sample
  185. if self.model_type == "homography":
  186. mask = sample_is_valid_for_homography(kp1, kp2)
  187. return kp1[mask], kp2[mask]
  188. return kp1, kp2
  189. def remove_bad_models(self, models: Tensor) -> Tensor:
  190. """Remove degenerate models based on simple heuristics.
  191. Args:
  192. models: candidate models to filter.
  193. Returns:
  194. Filtered models tensor.
  195. """
  196. # ToDo: add more and better degenerate model rejection
  197. # For now it is simple and hardcoded
  198. main_diagonal = torch.diagonal(models, dim1=1, dim2=2)
  199. mask = main_diagonal.abs().min(dim=1)[0] > 1e-4
  200. return models[mask]
  201. def polish_model(self, kp1: Tensor, kp2: Tensor, inliers: Tensor) -> Tensor:
  202. """Polish the model using inliers through local optimization.
  203. Args:
  204. kp1: source keypoints.
  205. kp2: target keypoints.
  206. inliers: boolean mask indicating inlier correspondences.
  207. Returns:
  208. Polished model tensor.
  209. """
  210. # TODO: Replace this with MAGSAC++ polisher
  211. kp1_inl = kp1[inliers][None]
  212. kp2_inl = kp2[inliers][None]
  213. num_inl = kp1_inl.size(1)
  214. model = self.polisher_solver(
  215. kp1_inl, kp2_inl, torch.ones(1, num_inl, dtype=kp1_inl.dtype, device=kp1_inl.device)
  216. )
  217. return model
  218. def validate_inputs(self, kp1: Tensor, kp2: Tensor, weights: Optional[Tensor] = None) -> None:
  219. """Validate input tensors for shape and size requirements.
  220. Args:
  221. kp1: source keypoints.
  222. kp2: target keypoints.
  223. weights: optional correspondence weights (not used currently).
  224. Raises:
  225. ValueError: if input shapes are invalid or insufficient correspondences.
  226. """
  227. if self.model_type in ["homography", "fundamental"]:
  228. KORNIA_CHECK_SHAPE(kp1, ["N", "2"])
  229. KORNIA_CHECK_SHAPE(kp2, ["N", "2"])
  230. if not (kp1.shape[0] == kp2.shape[0]) or (kp1.shape[0] < self.minimal_sample_size):
  231. raise ValueError(
  232. "kp1 and kp2 should be equal shape at least"
  233. f" [{self.minimal_sample_size}, 2], got {kp1.shape}, {kp2.shape}"
  234. )
  235. if self.model_type == "homography_from_linesegments":
  236. KORNIA_CHECK_SHAPE(kp1, ["N", "2", "2"])
  237. KORNIA_CHECK_SHAPE(kp2, ["N", "2", "2"])
  238. if not (kp1.shape[0] == kp2.shape[0]) or (kp1.shape[0] < self.minimal_sample_size):
  239. raise ValueError(
  240. "kp1 and kp2 should be equal shape at least"
  241. f" [{self.minimal_sample_size}, 2, 2], got {kp1.shape},"
  242. f" {kp2.shape}"
  243. )
  244. def forward(self, kp1: Tensor, kp2: Tensor, weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  245. r"""Call main forward method to execute the RANSAC algorithm.
  246. Args:
  247. kp1: source image keypoints :math:`(N, 2)`.
  248. kp2: distance image keypoints :math:`(N, 2)`.
  249. weights: optional correspondences weights. Not used now.
  250. Returns:
  251. - Estimated model, shape of :math:`(1, 3, 3)`.
  252. - The inlier/outlier mask, shape of :math:`(1, N)`, where N is number of input correspondences.
  253. """
  254. self.validate_inputs(kp1, kp2, weights)
  255. best_score_total: float = float(self.minimal_sample_size)
  256. num_tc: int = len(kp1)
  257. best_model_total = zeros(3, 3, dtype=kp1.dtype, device=kp1.device)
  258. inliers_best_total: Tensor = zeros(num_tc, 1, device=kp1.device, dtype=torch.bool)
  259. for i in range(self.max_iter):
  260. # Sample minimal samples in batch to estimate models
  261. idxs = self.sample(self.minimal_sample_size, num_tc, self.batch_size, kp1.device)
  262. kp1_sampled = kp1[idxs]
  263. kp2_sampled = kp2[idxs]
  264. kp1_sampled, kp2_sampled = self.remove_bad_samples(kp1_sampled, kp2_sampled)
  265. if len(kp1_sampled) == 0:
  266. continue
  267. # Estimate models
  268. models = self.estimate_model_from_minsample(kp1_sampled, kp2_sampled)
  269. models = self.remove_bad_models(models)
  270. if (models is None) or (len(models) == 0):
  271. continue
  272. # Score the models and select the best one
  273. model, inliers, model_score = self.verify(kp1, kp2, models, self.inl_th)
  274. # Store far-the-best model and (optionally) do a local optimization
  275. if model_score > best_score_total:
  276. # Local optimization
  277. for _ in range(self.max_lo_iters):
  278. model_lo = self.polish_model(kp1, kp2, inliers)
  279. if (model_lo is None) or (len(model_lo) == 0):
  280. continue
  281. _, inliers_lo, score_lo = self.verify(kp1, kp2, model_lo, self.inl_th)
  282. # print (f"Orig score = {best_model_score}, LO score = {score_lo} TC={num_tc}")
  283. if score_lo > model_score:
  284. model = model_lo.clone()[0]
  285. inliers = inliers_lo.clone()
  286. model_score = score_lo
  287. else:
  288. break
  289. # Now storing the best model
  290. best_model_total = model.clone()
  291. inliers_best_total = inliers.clone()
  292. best_score_total = model_score
  293. # Should we already stop?
  294. new_max_iter = int(
  295. self.max_samples_by_conf(int(best_score_total), num_tc, self.minimal_sample_size, self.confidence)
  296. )
  297. # print (f"New max_iter = {new_max_iter}")
  298. # Stop estimation, if the model is very good
  299. if (i + 1) * self.batch_size >= new_max_iter:
  300. break
  301. # local optimization with all inliers for better precision
  302. return best_model_total, inliers_best_total