ransac.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, Tuple, Union
  18. import torch
  19. from kornia.core import Tensor
  20. from .utils import arange_sequence, batch_2x2_ellipse, batch_2x2_inv, draw_first_k_couples, piecewise_arange
  21. def stable_sort_residuals(residuals: Tensor, ransidx: Tensor) -> Tuple[Tensor, Tensor]:
  22. """Sort residuals."""
  23. logres = torch.log(residuals + 1e-10)
  24. minlogres = torch.min(logres)
  25. maxlogres = torch.max(logres)
  26. sorting_score = ransidx.unsqueeze(0).float() + 0.99 * (logres - minlogres) / (maxlogres - minlogres)
  27. sorting_idxes = torch.argsort(sorting_score, dim=-1) # (niters, numsamples)
  28. iters_range = torch.arange(residuals.shape[0], device=residuals.device)
  29. return residuals[iters_range.unsqueeze(-1), sorting_idxes], sorting_idxes
  30. def group_sum_and_cumsum(
  31. scores_mat: Tensor, end_group_idx: Tensor, group_idx: Union[Tensor, slice, None] = None
  32. ) -> Tuple[Tensor, Union[Tensor, None]]:
  33. """Calculate cumulative sum over group."""
  34. cumulative_scores = torch.cumsum(scores_mat, dim=1)
  35. ending_cumusums = cumulative_scores[:, end_group_idx]
  36. shifted_ending_cumusums = torch.cat(
  37. [
  38. torch.zeros(size=(ending_cumusums.shape[0], 1), dtype=ending_cumusums.dtype, device=scores_mat.device),
  39. ending_cumusums[:, :-1],
  40. ],
  41. dim=1,
  42. )
  43. grouped_sums = ending_cumusums - shifted_ending_cumusums
  44. if group_idx is not None:
  45. grouped_cumsums = cumulative_scores - shifted_ending_cumusums[:, group_idx]
  46. return grouped_sums, grouped_cumsums
  47. return grouped_sums, None
  48. def confidence_based_inlier_selection(
  49. residuals: Tensor, ransidx: Tensor, rdims: Tensor, idxoffsets: Tensor, dv: torch.device, min_confidence: Tensor
  50. ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
  51. """Select inliers from confidence scores."""
  52. numransacs = rdims.shape[0]
  53. numiters = residuals.shape[0]
  54. sorted_res, sorting_idxes = stable_sort_residuals(residuals, ransidx)
  55. sorted_res_sqr = sorted_res**2
  56. too_perfect_fits = sorted_res_sqr <= 1e-8
  57. end_rans_indexing = torch.cumsum(rdims, dim=0) - 1
  58. _, inv_indices, res_dup_counts = torch.unique_consecutive(
  59. sorted_res_sqr.half().float(), dim=1, return_counts=True, return_inverse=True
  60. )
  61. duplicates_per_sample = res_dup_counts[inv_indices]
  62. inlier_weights = (1.0 / duplicates_per_sample).repeat(numiters, 1)
  63. inlier_weights[too_perfect_fits] = 0.0
  64. balanced_rdims, weights_cumsums = group_sum_and_cumsum(inlier_weights, end_rans_indexing, ransidx)
  65. if not isinstance(weights_cumsums, Tensor):
  66. raise TypeError("Expected the `weights_cumsums` to be a Tensor!")
  67. progressive_inl_rates = weights_cumsums.float() / (balanced_rdims.repeat_interleave(rdims, dim=1)).float()
  68. good_inl_mask = (sorted_res_sqr * min_confidence <= progressive_inl_rates) | too_perfect_fits
  69. inlier_weights[~good_inl_mask] = 0.0
  70. inlier_counts_matrix, _ = group_sum_and_cumsum(inlier_weights, end_rans_indexing)
  71. inl_counts, inl_iters = torch.max(inlier_counts_matrix.long(), dim=0)
  72. relative_inl_idxes = arange_sequence(inl_counts)
  73. inl_ransidx = torch.arange(numransacs, device=dv).repeat_interleave(inl_counts)
  74. inl_sampleidx = sorting_idxes[inl_iters.repeat_interleave(inl_counts), idxoffsets[inl_ransidx] + relative_inl_idxes]
  75. highest_accepted_sqr_residuals = sorted_res_sqr[inl_iters, idxoffsets + inl_counts - 1]
  76. expected_extra_inl = (
  77. balanced_rdims[inl_iters, torch.arange(numransacs, device=dv)].float() * highest_accepted_sqr_residuals
  78. )
  79. return inl_ransidx, inl_sampleidx, inl_counts, inl_iters, inl_counts.float() / expected_extra_inl
  80. def sample_padded_inliers(
  81. xsamples: Tensor,
  82. ysamples: Tensor,
  83. inlier_counts: Tensor,
  84. inl_ransidx: Tensor,
  85. inl_sampleidx: Tensor,
  86. numransacs: int,
  87. dv: torch.device,
  88. ) -> Tuple[Tensor, Tensor]:
  89. """Sample from padded inliers."""
  90. maxinliers = int(torch.max(inlier_counts).item())
  91. dtype = xsamples.dtype
  92. padded_inlier_x = torch.zeros(size=(numransacs, maxinliers, 2), device=dv, dtype=dtype)
  93. padded_inlier_y = torch.zeros(size=(numransacs, maxinliers, 2), device=dv, dtype=dtype)
  94. padded_inlier_x[inl_ransidx, piecewise_arange(inl_ransidx)] = xsamples[inl_sampleidx]
  95. padded_inlier_y[inl_ransidx, piecewise_arange(inl_ransidx)] = ysamples[inl_sampleidx]
  96. return padded_inlier_x, padded_inlier_y
  97. def ransac(
  98. xsamples: Tensor, ysamples: Tensor, rdims: Tensor, config: Dict[str, Any], iters: int = 128, refit: bool = True
  99. ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  100. """Run ransac."""
  101. DET_THR = config["detected_scale_rate_threshold"]
  102. MIN_CONFIDENCE = config["min_confidence"]
  103. dv: torch.device = config["device"]
  104. numransacs = rdims.shape[0]
  105. ransidx = torch.arange(numransacs, device=dv).repeat_interleave(rdims)
  106. idxoffsets = torch.cat([torch.tensor([0], device=dv), torch.cumsum(rdims[:-1], dim=0)], dim=0)
  107. rand_samples_rel = draw_first_k_couples(iters, rdims, dv)
  108. rand_samples_abs = rand_samples_rel + idxoffsets
  109. sampled_x = torch.transpose(
  110. xsamples[rand_samples_abs], dim0=1, dim1=2
  111. ) # (niters, 2, numransacs, 2) -> (niters, numransacs, 2, 2)
  112. sampled_y = torch.transpose(ysamples[rand_samples_abs], dim0=1, dim1=2)
  113. # minimal fit for sampled_x @ A^T = sampled_y
  114. affinities_fit = torch.transpose(batch_2x2_inv(sampled_x, check_dets=True) @ sampled_y, -1, -2)
  115. if not refit:
  116. eigenvals, _eigenvecs = batch_2x2_ellipse(affinities_fit)
  117. bad_ones = (eigenvals[..., 1] < 1 / DET_THR**2) | (eigenvals[..., 0] > DET_THR**2)
  118. affinities_fit[bad_ones] = torch.eye(2, device=dv)
  119. y_pred = (affinities_fit[:, ransidx] @ xsamples.unsqueeze(-1)).squeeze(-1)
  120. residuals = torch.norm(y_pred - ysamples, dim=-1) # (niters, numsamples)
  121. inl_ransidx, inl_sampleidx, inl_counts, inl_iters, inl_confidence = confidence_based_inlier_selection(
  122. residuals, ransidx, rdims, idxoffsets, dv=dv, min_confidence=MIN_CONFIDENCE
  123. )
  124. if len(inl_sampleidx) == 0:
  125. # If no inliers have been found, there is nothing to re-fit!
  126. refit = False
  127. if not refit:
  128. return (
  129. inl_sampleidx,
  130. affinities_fit[inl_iters, torch.arange(inl_iters.shape[0], device=dv)],
  131. inl_confidence,
  132. inl_counts,
  133. )
  134. # Organize inliers found into a matrix for efficient GPU re-fitting.
  135. # Cope with the irregular number of inliers per sample by padding with zeros
  136. padded_inlier_x, padded_inlier_y = sample_padded_inliers(
  137. xsamples, ysamples, inl_counts, inl_ransidx, inl_sampleidx, numransacs, dv
  138. )
  139. # A @ pad_x.T = pad_y.T
  140. # A = pad_y.T @ pad_x @ (pad_x.T @ pad_x)^-1
  141. refit_affinity = (
  142. padded_inlier_y.transpose(-2, -1)
  143. @ padded_inlier_x
  144. @ batch_2x2_inv(padded_inlier_x.transpose(-2, -1) @ padded_inlier_x, check_dets=True)
  145. )
  146. # Filter out degenerate affinities with large scale changes
  147. eigenvals, _eigenvecs = batch_2x2_ellipse(refit_affinity)
  148. bad_ones = (eigenvals[..., 1] < 1 / DET_THR**2) | (eigenvals[..., 0] > DET_THR**2)
  149. refit_affinity[bad_ones] = torch.eye(2, device=dv, dtype=refit_affinity.dtype)
  150. y_pred = (refit_affinity[ransidx] @ xsamples.unsqueeze(-1)).squeeze(-1)
  151. residuals = torch.norm(y_pred - ysamples, dim=-1)
  152. inl_ransidx, inl_sampleidx, inl_counts, inl_iters, inl_confidence = confidence_based_inlier_selection(
  153. residuals.unsqueeze(0), ransidx, rdims, idxoffsets, dv=dv, min_confidence=MIN_CONFIDENCE
  154. )
  155. return inl_sampleidx, refit_affinity, inl_confidence, inl_counts