homography.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import warnings
  18. from typing import Optional, Tuple
  19. import torch
  20. from kornia.core import Tensor
  21. from kornia.core.check import KORNIA_CHECK_SHAPE
  22. from kornia.utils import _extract_device_dtype, safe_inverse_with_mask, safe_solve_with_mask
  23. from kornia.utils.helpers import _torch_svd_cast
  24. from .conversions import convert_points_from_homogeneous, convert_points_to_homogeneous
  25. from .epipolar import normalize_points
  26. from .linalg import transform_points
# Type alias for a pair of tensors.
TupleTensor = Tuple[Tensor, Tensor]
  28. def oneway_transfer_error(pts1: Tensor, pts2: Tensor, H: Tensor, squared: bool = True, eps: float = 1e-8) -> Tensor:
  29. r"""Return transfer error in image 2 for correspondences given the homography matrix.
  30. Args:
  31. pts1: correspondences from the left images with shape
  32. (B, N, 2 or 3). If they are homogeneous, converted automatically.
  33. pts2: correspondences from the right images with shape
  34. (B, N, 2 or 3). If they are homogeneous, converted automatically.
  35. H: Homographies with shape :math:`(B, 3, 3)`.
  36. squared: if True (default), the squared distance is returned.
  37. eps: Small constant for safe sqrt.
  38. Returns:
  39. the computed distance with shape :math:`(B, N)`.
  40. """
  41. KORNIA_CHECK_SHAPE(H, ["B", "3", "3"])
  42. if pts1.size(-1) == 3:
  43. pts1 = convert_points_from_homogeneous(pts1)
  44. if pts2.size(-1) == 3:
  45. pts2 = convert_points_from_homogeneous(pts2)
  46. # From Hartley and Zisserman, Error in one image (4.6)
  47. # dist = \sum_{i} ( d(x', Hx)**2)
  48. pts1_in_2: Tensor = transform_points(H, pts1)
  49. error_squared: Tensor = (pts1_in_2 - pts2).pow(2).sum(dim=-1)
  50. if squared:
  51. return error_squared
  52. return (error_squared + eps).sqrt()
  53. def symmetric_transfer_error(pts1: Tensor, pts2: Tensor, H: Tensor, squared: bool = True, eps: float = 1e-8) -> Tensor:
  54. r"""Return Symmetric transfer error for correspondences given the homography matrix.
  55. Args:
  56. pts1: correspondences from the left images with shape
  57. (B, N, 2 or 3). If they are homogeneous, converted automatically.
  58. pts2: correspondences from the right images with shape
  59. (B, N, 2 or 3). If they are homogeneous, converted automatically.
  60. H: Homographies with shape :math:`(B, 3, 3)`.
  61. squared: if True (default), the squared distance is returned.
  62. eps: Small constant for safe sqrt.
  63. Returns:
  64. the computed distance with shape :math:`(B, N)`.
  65. """
  66. KORNIA_CHECK_SHAPE(H, ["B", "3", "3"])
  67. if pts1.size(-1) == 3:
  68. pts1 = convert_points_from_homogeneous(pts1)
  69. if pts2.size(-1) == 3:
  70. pts2 = convert_points_from_homogeneous(pts2)
  71. max_num = torch.finfo(pts1.dtype).max
  72. # From Hartley and Zisserman, Symmetric transfer error (4.7)
  73. # dist = \sum_{i} (d(x, H^-1 x')**2 + d(x', Hx)**2)
  74. H_inv, good_H = safe_inverse_with_mask(H)
  75. there: Tensor = oneway_transfer_error(pts1, pts2, H, True, eps)
  76. back: Tensor = oneway_transfer_error(pts2, pts1, H_inv, True, eps)
  77. good_H_reshape: Tensor = good_H.view(-1, 1).expand_as(there)
  78. out = (there + back) * good_H_reshape.to(there.dtype) + max_num * (~good_H_reshape).to(there.dtype)
  79. if squared:
  80. return out
  81. return (out + eps).sqrt()
  82. def line_segment_transfer_error_one_way(ls1: Tensor, ls2: Tensor, H: Tensor, squared: bool = False) -> Tensor:
  83. r"""Return transfer error in image 2 for line segment correspondences given the homography matrix.
  84. Line segment end points are reprojected into image 2, and point-to-line error is calculated w.r.t. line,
  85. induced by line segment in image 2. See :cite:`homolines2001` for details.
  86. Args:
  87. ls1: line segment correspondences from the left images with shape
  88. (B, N, 2, 2).
  89. ls2: line segment correspondences from the right images with shape
  90. (B, N, 2, 2).
  91. H: Homographies with shape :math:`(B, 3, 3)`.
  92. squared: if True (default is False), the squared distance is returned.
  93. Returns:
  94. the computed distance with shape :math:`(B, N)`.
  95. """
  96. KORNIA_CHECK_SHAPE(H, ["B", "3", "3"])
  97. KORNIA_CHECK_SHAPE(ls1, ["B", "N", "2", "2"])
  98. KORNIA_CHECK_SHAPE(ls2, ["B", "N", "2", "2"])
  99. B, N = ls1.shape[:2]
  100. ps1, pe1 = torch.chunk(ls1, dim=2, chunks=2)
  101. ps2, pe2 = torch.chunk(ls2, dim=2, chunks=2)
  102. ps2_h = convert_points_to_homogeneous(ps2)
  103. pe2_h = convert_points_to_homogeneous(pe2)
  104. ln2 = torch.linalg.cross(ps2_h, pe2_h, dim=3)
  105. ps1_in2 = convert_points_to_homogeneous(transform_points(H, ps1))
  106. pe1_in2 = convert_points_to_homogeneous(transform_points(H, pe1))
  107. er_st1 = (ln2 @ ps1_in2.transpose(-2, -1)).view(B, N).abs()
  108. er_end1 = (ln2 @ pe1_in2.transpose(-2, -1)).view(B, N).abs()
  109. error = 0.5 * (er_st1 + er_end1)
  110. if squared:
  111. error = error**2
  112. return error
  113. def find_homography_dlt(
  114. points1: torch.Tensor, points2: torch.Tensor, weights: Optional[torch.Tensor] = None, solver: str = "lu"
  115. ) -> torch.Tensor:
  116. r"""Compute the homography matrix using the DLT formulation.
  117. The linear system is solved by using the Weighted Least Squares Solution for the 4 Points algorithm.
  118. Args:
  119. points1: A set of points in the first image with a tensor shape :math:`(B, N, 2)`.
  120. points2: A set of points in the second image with a tensor shape :math:`(B, N, 2)`.
  121. weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`.
  122. solver: variants: svd, lu.
  123. Returns:
  124. the computed homography matrix with shape :math:`(B, 3, 3)`.
  125. """
  126. if points1.shape != points2.shape:
  127. raise AssertionError(points1.shape)
  128. if points1.shape[1] < 4:
  129. raise AssertionError(points1.shape)
  130. KORNIA_CHECK_SHAPE(points1, ["B", "N", "2"])
  131. KORNIA_CHECK_SHAPE(points2, ["B", "N", "2"])
  132. device, dtype = _extract_device_dtype([points1, points2])
  133. eps: float = 1e-8
  134. points1_norm, transform1 = normalize_points(points1)
  135. points2_norm, transform2 = normalize_points(points2)
  136. x1, y1 = torch.chunk(points1_norm, dim=-1, chunks=2) # BxNx1
  137. x2, y2 = torch.chunk(points2_norm, dim=-1, chunks=2) # BxNx1
  138. ones, zeros = torch.ones_like(x1), torch.zeros_like(x1)
  139. # DIAPO 11: https://www.uio.no/studier/emner/matnat/its/nedlagte-emner/UNIK4690/v16/forelesninger/lecture_4_3-estimating-homographies-from-feature-correspondences.pdf # noqa: E501
  140. ax = torch.cat([zeros, zeros, zeros, -x1, -y1, -ones, y2 * x1, y2 * y1, y2], dim=-1)
  141. ay = torch.cat([x1, y1, ones, zeros, zeros, zeros, -x2 * x1, -x2 * y1, -x2], dim=-1)
  142. A = torch.cat((ax, ay), dim=-1).reshape(ax.shape[0], -1, ax.shape[-1])
  143. if weights is None:
  144. # All points are equally important
  145. A = A.transpose(-2, -1) @ A
  146. else:
  147. # We should use provided weights
  148. if not (len(weights.shape) == 2 and weights.shape == points1.shape[:2]):
  149. raise AssertionError(weights.shape)
  150. w_full = weights.repeat_interleave(2, dim=1).unsqueeze(1)
  151. A = (A.transpose(-2, -1) * w_full) @ A
  152. if solver == "svd":
  153. try:
  154. _, _, V = _torch_svd_cast(A)
  155. except RuntimeError:
  156. warnings.warn("SVD did not converge", RuntimeWarning, stacklevel=1)
  157. return torch.empty((points1_norm.size(0), 3, 3), device=device, dtype=dtype)
  158. H = V[..., -1].view(-1, 3, 3)
  159. elif solver == "lu":
  160. B = torch.ones(A.shape[0], A.shape[1], device=device, dtype=dtype)
  161. sol, _, _ = safe_solve_with_mask(B, A)
  162. H = sol.reshape(-1, 3, 3)
  163. else:
  164. raise NotImplementedError
  165. H = safe_inverse_with_mask(transform2)[0] @ (H @ transform1)
  166. H_norm = H / (H[..., -1:, -1:] + eps)
  167. return H_norm
  168. def find_homography_dlt_iterated(
  169. points1: Tensor, points2: Tensor, weights: Tensor, soft_inl_th: float = 3.0, n_iter: int = 5
  170. ) -> Tensor:
  171. r"""Compute the homography matrix using the iteratively-reweighted least squares (IRWLS).
  172. The linear system is solved by using the Reweighted Least Squares Solution for the 4 Points algorithm.
  173. Args:
  174. points1: A set of points in the first image with a tensor shape :math:`(B, N, 2)`.
  175. points2: A set of points in the second image with a tensor shape :math:`(B, N, 2)`.
  176. weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`.
  177. Used for the first iteration of the IRWLS.
  178. soft_inl_th: Soft inlier threshold used for weight calculation.
  179. n_iter: number of iterations.
  180. Returns:
  181. the computed homography matrix with shape :math:`(B, 3, 3)`.
  182. """
  183. H: Tensor = find_homography_dlt(points1, points2, weights)
  184. for _ in range(n_iter - 1):
  185. errors: Tensor = symmetric_transfer_error(points1, points2, H, False)
  186. weights_new: Tensor = torch.exp(-errors / (2.0 * (soft_inl_th**2)))
  187. H = find_homography_dlt(points1, points2, weights_new)
  188. return H
  189. def sample_is_valid_for_homography(points1: Tensor, points2: Tensor) -> Tensor:
  190. """Implement oriented constraint check from :cite:`Marquez-Neila2015`.
  191. Analogous to https://github.com/opencv/opencv/blob/4.x/modules/calib3d/src/usac/degeneracy.cpp#L88
  192. Args:
  193. points1: A set of points in the first image with a tensor shape :math:`(B, 4, 2)`.
  194. points2: A set of points in the second image with a tensor shape :math:`(B, 4, 2)`.
  195. Returns:
  196. Mask with the minimal sample is good for homography estimation:math:`(B, 3, 3)`.
  197. """
  198. if points1.shape != points2.shape:
  199. raise AssertionError(points1.shape)
  200. KORNIA_CHECK_SHAPE(points1, ["B", "4", "2"])
  201. KORNIA_CHECK_SHAPE(points2, ["B", "4", "2"])
  202. device = points1.device
  203. idx_perm = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=torch.long, device=device)
  204. points_src_h = convert_points_to_homogeneous(points1)
  205. points_dst_h = convert_points_to_homogeneous(points2)
  206. src_perm = points_src_h[:, idx_perm]
  207. dst_perm = points_dst_h[:, idx_perm]
  208. left_sign = (
  209. torch.linalg.cross(src_perm[..., 1:2, :], src_perm[..., 2:3, :], dim=-1)
  210. @ src_perm[..., 0:1, :].permute(0, 1, 3, 2)
  211. ).sign()
  212. right_sign = (
  213. torch.linalg.cross(dst_perm[..., 1:2, :], dst_perm[..., 2:3, :], dim=-1)
  214. @ dst_perm[..., 0:1, :].permute(0, 1, 3, 2)
  215. ).sign()
  216. sample_is_valid = (left_sign == right_sign).view(-1, 4).min(dim=1)[0]
  217. return sample_is_valid
  218. def find_homography_lines_dlt(ls1: Tensor, ls2: Tensor, weights: Optional[Tensor] = None) -> Tensor:
  219. """Compute the homography matrix using the DLT formulation for line correspondences.
  220. See :cite:`homolines2001` for details.
  221. The linear system is solved by using the Weighted Least Squares Solution for the 4 Line correspondences algorithm.
  222. Args:
  223. ls1: A set of line segments in the first image with a tensor shape :math:`(B, N, 2, 2)`.
  224. ls2: A set of line segments in the second image with a tensor shape :math:`(B, N, 2, 2)`.
  225. weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`.
  226. Returns:
  227. the computed homography matrix with shape :math:`(B, 3, 3)`.
  228. """
  229. if len(ls1.shape) == 3:
  230. ls1 = ls1[None]
  231. if len(ls2.shape) == 3:
  232. ls2 = ls2[None]
  233. KORNIA_CHECK_SHAPE(ls1, ["B", "N", "2", "2"])
  234. KORNIA_CHECK_SHAPE(ls2, ["B", "N", "2", "2"])
  235. BS, N = ls1.shape[:2]
  236. device, dtype = _extract_device_dtype([ls1, ls2])
  237. points1 = ls1.reshape(BS, 2 * N, 2)
  238. points2 = ls2.reshape(BS, 2 * N, 2)
  239. points1_norm, transform1 = normalize_points(points1)
  240. points2_norm, transform2 = normalize_points(points2)
  241. lst1, le1 = torch.chunk(points1_norm, dim=1, chunks=2)
  242. lst2, le2 = torch.chunk(points2_norm, dim=1, chunks=2)
  243. xs1, ys1 = torch.chunk(lst1, dim=-1, chunks=2) # BxNx1
  244. xs2, ys2 = torch.chunk(lst2, dim=-1, chunks=2) # BxNx1
  245. xe1, ye1 = torch.chunk(le1, dim=-1, chunks=2) # BxNx1
  246. xe2, ye2 = torch.chunk(le2, dim=-1, chunks=2) # BxNx1
  247. A = ys2 - ye2
  248. B = xe2 - xs2
  249. C = xs2 * ye2 - xe2 * ys2
  250. eps: float = 1e-8
  251. # http://diis.unizar.es/biblioteca/00/09/000902.pdf
  252. ax = torch.cat([A * xs1, A * ys1, A, B * xs1, B * ys1, B, C * xs1, C * ys1, C], dim=-1)
  253. ay = torch.cat([A * xe1, A * ye1, A, B * xe1, B * ye1, B, C * xe1, C * ye1, C], dim=-1)
  254. A = torch.cat((ax, ay), dim=-1).reshape(ax.shape[0], -1, ax.shape[-1])
  255. if weights is None:
  256. # All points are equally important
  257. A = A.transpose(-2, -1) @ A
  258. else:
  259. # We should use provided weights
  260. if not ((len(weights.shape) == 2) and (weights.shape == ls1.shape[:2])):
  261. raise AssertionError(weights.shape)
  262. w_diag = torch.diag_embed(weights.unsqueeze(dim=-1).repeat(1, 1, 2).reshape(weights.shape[0], -1))
  263. A = A.transpose(-2, -1) @ w_diag @ A
  264. try:
  265. _, _, V = _torch_svd_cast(A)
  266. except RuntimeError:
  267. warnings.warn("SVD did not converge", RuntimeWarning, stacklevel=1)
  268. return torch.empty((points1_norm.size(0), 3, 3), device=device, dtype=dtype)
  269. H = V[..., -1].view(-1, 3, 3)
  270. H = safe_inverse_with_mask(transform2)[0] @ (H @ transform1)
  271. H_norm = H / (H[..., -1:, -1:] + eps)
  272. return H_norm
  273. def find_homography_lines_dlt_iterated(
  274. ls1: Tensor, ls2: Tensor, weights: Tensor, soft_inl_th: float = 4.0, n_iter: int = 5
  275. ) -> Tensor:
  276. r"""Compute the homography matrix using the iteratively-reweighted least squares (IRWLS) from line segments.
  277. The linear system is solved by using the Reweighted Least Squares Solution for the 4 line segments algorithm.
  278. Args:
  279. ls1: A set of line segments in the first image with a tensor shape :math:`(B, N, 2, 2)`.
  280. ls2: A set of line segments in the second image with a tensor shape :math:`(B, N, 2, 2)`.
  281. weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`.
  282. Used for the first iteration of the IRWLS.
  283. soft_inl_th: Soft inlier threshold used for weight calculation.
  284. n_iter: number of iterations.
  285. Returns:
  286. the computed homography matrix with shape :math:`(B, 3, 3)`.
  287. """
  288. H: Tensor = find_homography_lines_dlt(ls1, ls2, weights)
  289. for _ in range(n_iter - 1):
  290. errors: Tensor = line_segment_transfer_error_one_way(ls1, ls2, H, False)
  291. weights_new: Tensor = torch.exp(-errors / (2.0 * (soft_inl_th**2)))
  292. H = find_homography_lines_dlt(ls1, ls2, weights_new)
  293. return H