fundamental.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module containing the functionalities for computing the Fundamental Matrix."""
  18. from typing import Literal, Optional, Tuple
  19. import torch
  20. from kornia.core import Tensor, concatenate, ones_like, stack, where, zeros
  21. from kornia.core.check import KORNIA_CHECK_SAME_SHAPE, KORNIA_CHECK_SHAPE
  22. from kornia.geometry.conversions import convert_points_from_homogeneous, convert_points_to_homogeneous
  23. from kornia.geometry.linalg import transform_points
  24. from kornia.geometry.solvers import solve_cubic
  25. from kornia.utils.helpers import _torch_svd_cast, safe_inverse_with_mask
  26. def normalize_points(points: Tensor, eps: float = 1e-8) -> Tuple[Tensor, Tensor]:
  27. r"""Normalize points (isotropic).
  28. Computes the transformation matrix such that the two principal moments of the set of points
  29. are equal to unity, forming an approximately symmetric circular cloud of points of radius 1
  30. about the origin. Reference: Hartley/Zisserman 4.4.4 pag.107
  31. This operation is an essential step before applying the DLT algorithm in order to consider
  32. the result as optimal.
  33. Args:
  34. points: Tensor containing the points to be normalized with shape :math:`(B, N, 2)`.
  35. eps: epsilon value to avoid numerical instabilities.
  36. Returns:
  37. tuple containing the normalized points in the shape :math:`(B, N, 2)` and the transformation matrix
  38. in the shape :math:`(B, 3, 3)`.
  39. """
  40. if len(points.shape) != 3:
  41. raise AssertionError(points.shape)
  42. if points.shape[-1] != 2:
  43. raise AssertionError(points.shape)
  44. x_mean = torch.mean(points, dim=1, keepdim=True) # Bx1x2
  45. scale = (points - x_mean).norm(dim=-1, p=2).mean(dim=-1) # B
  46. scale = torch.sqrt(torch.tensor(2.0)) / (scale + eps) # B
  47. ones, zeros = ones_like(scale), torch.zeros_like(scale)
  48. transform = stack(
  49. [scale, zeros, -scale * x_mean[..., 0, 0], zeros, scale, -scale * x_mean[..., 0, 1], zeros, zeros, ones], dim=-1
  50. ) # Bx9
  51. transform = transform.view(-1, 3, 3) # Bx3x3
  52. points_norm = transform_points(transform, points) # BxNx2
  53. return points_norm, transform
  54. def normalize_transformation(M: Tensor, eps: float = 1e-8) -> Tensor:
  55. r"""Normalize a given transformation matrix.
  56. The function trakes the transformation matrix and normalize so that the value in
  57. the last row and column is one.
  58. Args:
  59. M: The transformation to be normalized of any shape with a minimum size of 2x2.
  60. eps: small value to avoid unstabilities during the backpropagation.
  61. Returns:
  62. the normalized transformation matrix with same shape as the input.
  63. """
  64. if len(M.shape) < 2:
  65. raise AssertionError(M.shape)
  66. norm_val: Tensor = M[..., -1:, -1:]
  67. return where(norm_val.abs() > eps, M / (norm_val + eps), M)
  68. # Reference: Adapted from the 'run_7point' function in opencv
  69. # https://github.com/opencv/opencv/blob/4.x/modules/calib3d/src/fundam.cpp
def run_7point(points1: Tensor, points2: Tensor) -> Tensor:
    r"""Compute the fundamental matrix using the 7-point algorithm.

    The two right singular vectors of the 7x9 epipolar constraint matrix span the solution space;
    the rank-2 constraint ``det(F) = 0`` then gives a cubic whose roots select the candidates.

    Args:
        points1: A set of points in the first image with a tensor shape :math:`(B, 7, 2)`.
        points2: A set of points in the second image with a tensor shape :math:`(B, 7, 2)`.

    Returns:
        the candidate fundamental matrices. The output buffer is allocated with shape
        :math:`(B, 3, 3, 3)`: index ``m`` along dimension 1 holds the matrix built from the ``m``-th
        root of the cubic.
    """
    KORNIA_CHECK_SHAPE(points1, ["B", "7", "2"])
    KORNIA_CHECK_SHAPE(points2, ["B", "7", "2"])

    batch_size = points1.shape[0]

    points1_norm, transform1 = normalize_points(points1)
    points2_norm, transform2 = normalize_points(points2)

    # split the last dimension into separate coordinate columns
    x1, y1 = torch.chunk(points1_norm, dim=-1, chunks=2)  # Bx7x1
    x2, y2 = torch.chunk(points2_norm, dim=-1, chunks=2)  # Bx7x1

    ones = ones_like(x1)
    # form a linear system: which represents
    # the equation (x2[i], 1)*F*(x1[i], 1) = 0
    X = concatenate([x2 * x1, x2 * y1, x2, y2 * x1, y2 * y1, y2, x1, y1, ones], -1)  # Bx7x9

    # X * Fmat = 0 is singular (7 equations for 9 variables)
    # solving for nullspace of X to get two F
    # NOTE: torch.linalg.svd was found unstable here (failing gradcheck), hence the helper.
    _, _, v = _torch_svd_cast(X)

    # last two singular vectors as a basis of the solution space
    f1 = v[..., 7].view(-1, 3, 3)
    f2 = v[..., 8].view(-1, 3, 3)

    # lambda*f1 + mu*f2 is an arbitrary fundamental matrix.
    # With mu fixed to 1, det(lambda*f1 + f2) = 0 is a cubic in lambda:
    #   det(f1)*l^3 + tr(f2 @ adj(f1))*l^2 + tr(f1 @ adj(f2))*l + det(f2)
    # where adj(f) is written below as inverse(f) * det(f).
    coeffs = zeros((batch_size, 4), device=v.device, dtype=v.dtype)

    f1_det = torch.linalg.det(f1)
    f2_det = torch.linalg.det(f2)

    coeffs[:, 0] = f1_det
    coeffs[:, 1] = torch.einsum("bii->b", f2 @ safe_inverse_with_mask(f1)[0]) * f1_det
    coeffs[:, 2] = torch.einsum("bii->b", f1 @ safe_inverse_with_mask(f2)[0]) * f2_det
    coeffs[:, 3] = f2_det

    # solve the cubic equation, there can be 1 to 3 roots
    # presumably `roots` is (B, 3) with unused slots left at zero -- TODO confirm against solve_cubic
    roots = solve_cubic(coeffs)

    fmatrix = zeros((batch_size, 3, 3, 3), device=v.device, dtype=v.dtype)

    # NOTE(review): `n < 3 or n > 1` holds for every integer n, so this mask is True for every
    # batch element and filters nothing. Confirm the intended validity criterion.
    valid_root_mask = (torch.count_nonzero(roots, dim=1) < 3) | (torch.count_nonzero(roots, dim=1) > 1)

    # `_lambda` aliases `roots`: the in-place updates below also modify `roots`.
    _lambda = roots
    _mu = torch.ones_like(_lambda)

    # (2, 2) entry of lambda*f1 + f2 for every root; used to rescale each candidate so that
    # its bottom-right entry becomes 1.
    _s = f1[valid_root_mask, 2, 2].unsqueeze(dim=1) * roots[valid_root_mask] + f2[valid_root_mask, 2, 2].unsqueeze(
        dim=1
    )
    _s_non_zero_mask = ~torch.isclose(_s, torch.tensor(0.0, device=v.device, dtype=v.dtype))

    # NOTE(review): `_s_non_zero_mask` is built from the masked batch but indexes the unmasked
    # `_mu` / `_lambda`; the shapes only agree while `valid_root_mask` selects every element.
    _mu[_s_non_zero_mask] = 1.0 / _s[_s_non_zero_mask]
    _lambda[_s_non_zero_mask] = _lambda[_s_non_zero_mask] * _mu[_s_non_zero_mask]

    f1_expanded = f1.unsqueeze(1).expand(batch_size, 3, 3, 3)
    f2_expanded = f2.unsqueeze(1).expand(batch_size, 3, 3, 3)

    # one candidate matrix per root: lambda*f1 + mu*f2
    fmatrix[valid_root_mask] = (
        f1_expanded[valid_root_mask] * _lambda[valid_root_mask, :, None, None]
        + f2_expanded[valid_root_mask] * _mu[valid_root_mask, :, None, None]
    )

    # boolean selector for the bottom-right entry of each 3x3 candidate
    # NOTE(review): no device is passed here, unlike the other allocations in this function.
    mat_ind = zeros(3, 3, dtype=torch.bool)
    mat_ind[2, 2] = True
    # write the bottom-right entry explicitly: 1 where the rescaling applied, 0 otherwise
    fmatrix[_s_non_zero_mask, mat_ind] = 1.0
    fmatrix[~_s_non_zero_mask, mat_ind] = 0.0

    # undo the point normalization: F = T2^T @ F_norm @ T1, applied to all three candidates
    trans1_exp = transform1[valid_root_mask].unsqueeze(1).expand(-1, fmatrix.shape[2], -1, -1)
    trans2_exp = transform2[valid_root_mask].unsqueeze(1).expand(-1, fmatrix.shape[2], -1, -1)

    fmatrix[valid_root_mask] = torch.matmul(
        trans2_exp.transpose(-2, -1), torch.matmul(fmatrix[valid_root_mask], trans1_exp)
    )

    return normalize_transformation(fmatrix)
  139. def run_8point(points1: Tensor, points2: Tensor, weights: Optional[Tensor] = None) -> Tensor:
  140. r"""Compute the fundamental matrix using the DLT formulation.
  141. The linear system is solved by using the Weighted Least Squares Solution for the 8 Points algorithm.
  142. Args:
  143. points1: A set of points in the first image with a tensor shape :math:`(B, N, 2), N>=8`.
  144. points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=8`.
  145. weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`.
  146. Returns:
  147. the computed fundamental matrix with shape :math:`(B, 3, 3)`.
  148. """
  149. KORNIA_CHECK_SHAPE(points1, ["B", "N", "2"])
  150. KORNIA_CHECK_SHAPE(points2, ["B", "N", "2"])
  151. KORNIA_CHECK_SAME_SHAPE(points1, points2)
  152. if points1.shape[1] < 8:
  153. raise AssertionError(points1.shape)
  154. if weights is not None:
  155. KORNIA_CHECK_SHAPE(weights, ["B", "N"])
  156. if not (weights.shape[1] == points1.shape[1]):
  157. raise AssertionError(weights.shape)
  158. points1_norm, transform1 = normalize_points(points1)
  159. points2_norm, transform2 = normalize_points(points2)
  160. x1, y1 = torch.chunk(points1_norm, dim=-1, chunks=2) # Bx1xN
  161. x2, y2 = torch.chunk(points2_norm, dim=-1, chunks=2) # Bx1xN
  162. ones = ones_like(x1)
  163. # build equations system and solve DLT
  164. # https://www.cc.gatech.edu/~afb/classes/CS4495-Fall2013/slides/CS4495-09-TwoViews-2.pdf
  165. # [x * x', x * y', x, y * x', y * y', y, x', y', 1]
  166. X = torch.cat([x2 * x1, x2 * y1, x2, y2 * x1, y2 * y1, y2, x1, y1, ones], dim=-1) # BxNx9
  167. # apply the weights to the linear system
  168. if weights is None:
  169. X = X.transpose(-2, -1) @ X
  170. else:
  171. w_diag = torch.diag_embed(weights)
  172. X = X.transpose(-2, -1) @ w_diag @ X
  173. # compute eigevectors and retrieve the one with the smallest eigenvalue
  174. _, _, V = _torch_svd_cast(X)
  175. F_mat = V[..., -1].view(-1, 3, 3)
  176. # reconstruct and force the matrix to have rank2
  177. U, S, V = _torch_svd_cast(F_mat)
  178. rank_mask = torch.tensor([1.0, 1.0, 0.0], device=F_mat.device, dtype=F_mat.dtype)
  179. F_projected = U @ (torch.diag_embed(S * rank_mask) @ V.transpose(-2, -1))
  180. F_est = transform2.transpose(-2, -1) @ (F_projected @ transform1)
  181. return normalize_transformation(F_est)
  182. def find_fundamental(
  183. points1: Tensor, points2: Tensor, weights: Optional[Tensor] = None, method: Literal["8POINT", "7POINT"] = "8POINT"
  184. ) -> Tensor:
  185. r"""Find the fundamental matrix.
  186. Args:
  187. points1: A set of points in the first image with a tensor shape :math:`(B, N, 2), N>=8`.
  188. points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=8`.
  189. weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`.
  190. method: The method to use for computing the fundamental matrix. Supported methods are "7POINT" and "8POINT".
  191. Returns:
  192. the computed fundamental matrix with shape :math:`(B, 3*m, 3)`, where `m` number of fundamental matrix.
  193. Raises:
  194. ValueError: If an invalid method is provided.
  195. """
  196. if method.upper() == "7POINT":
  197. result = run_7point(points1, points2)
  198. elif method.upper() == "8POINT":
  199. result = run_8point(points1, points2, weights)
  200. else:
  201. raise ValueError(f"Invalid method: {method}. Supported methods are '7POINT' and '8POINT'.")
  202. return result
  203. def compute_correspond_epilines(points: Tensor, F_mat: Tensor) -> Tensor:
  204. r"""Compute the corresponding epipolar line for a given set of points.
  205. Args:
  206. points: tensor containing the set of points to project in the shape of :math:`(*, N, 2)` or :math:`(*, N, 3)`.
  207. F_mat: the fundamental to use for projection the points in the shape of :math:`(*, 3, 3)`.
  208. Returns:
  209. a tensor with shape :math:`(*, N, 3)` containing a vector of the epipolar
  210. lines corresponding to the points to the other image. Each line is described as
  211. :math:`ax + by + c = 0` and encoding the vectors as :math:`(a, b, c)`.
  212. """
  213. KORNIA_CHECK_SHAPE(points, ["*", "N", "DIM"])
  214. if points.shape[-1] == 2:
  215. points_h: Tensor = convert_points_to_homogeneous(points)
  216. elif points.shape[-1] == 3:
  217. points_h = points
  218. else:
  219. raise AssertionError(points.shape)
  220. KORNIA_CHECK_SHAPE(F_mat, ["*", "3", "3"])
  221. # project points and retrieve lines components
  222. points_h = torch.transpose(points_h, dim0=-2, dim1=-1)
  223. a, b, c = torch.chunk(F_mat @ points_h, dim=-2, chunks=3)
  224. # compute normal and compose equation line
  225. nu: Tensor = a * a + b * b
  226. nu = where(nu > 0.0, 1.0 / torch.sqrt(nu), torch.ones_like(nu))
  227. line = torch.cat([a * nu, b * nu, c * nu], dim=-2) # *x3xN
  228. return torch.transpose(line, dim0=-2, dim1=-1) # *xNx3
  229. def get_perpendicular(lines: Tensor, points: Tensor) -> Tensor:
  230. r"""Compute the perpendicular to a line, through the point.
  231. Args:
  232. lines: tensor containing the set of lines :math:`(*, N, 3)`.
  233. points: tensor containing the set of points :math:`(*, N, 2)`.
  234. Returns:
  235. a tensor with shape :math:`(*, N, 3)` containing a vector of the epipolar
  236. perpendicular lines. Each line is described as
  237. :math:`ax + by + c = 0` and encoding the vectors as :math:`(a, b, c)`.
  238. """
  239. KORNIA_CHECK_SHAPE(lines, ["*", "N", "3"])
  240. KORNIA_CHECK_SHAPE(points, ["*", "N", "two"])
  241. if points.shape[2] == 2:
  242. points_h: Tensor = convert_points_to_homogeneous(points)
  243. elif points.shape[2] == 3:
  244. points_h = points
  245. else:
  246. raise AssertionError(points.shape)
  247. infinity_point = lines * torch.tensor([1, 1, 0], dtype=lines.dtype, device=lines.device).view(1, 1, 3)
  248. perp: Tensor = torch.linalg.cross(points_h, infinity_point, dim=2)
  249. return perp
  250. def get_closest_point_on_epipolar_line(pts1: Tensor, pts2: Tensor, Fm: Tensor) -> Tensor:
  251. """Return closest point on the epipolar line to the correspondence, given the fundamental matrix.
  252. Args:
  253. pts1: correspondences from the left images with shape :math:`(*, N, (2|3))`. If they are not homogeneous,
  254. converted automatically.
  255. pts2: correspondences from the right images with shape :math:`(*, N, (2|3))`. If they are not homogeneous,
  256. converted automatically.
  257. Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to avoid ambiguity with torch.nn.functional.
  258. Returns:
  259. point on epipolar line :math:`(*, N, 2)`.
  260. """
  261. if not isinstance(Fm, Tensor):
  262. raise TypeError(f"Fm type is not a torch.Tensor. Got {type(Fm)}")
  263. if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
  264. raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
  265. if pts1.shape[-1] == 2:
  266. pts1 = convert_points_to_homogeneous(pts1)
  267. if pts2.shape[-1] == 2:
  268. pts2 = convert_points_to_homogeneous(pts2)
  269. line1in2 = compute_correspond_epilines(pts1, Fm)
  270. perp = get_perpendicular(line1in2, pts2)
  271. points1_in_2 = convert_points_from_homogeneous(torch.linalg.cross(line1in2, perp, dim=2))
  272. return points1_in_2
  273. def fundamental_from_essential(E_mat: Tensor, K1: Tensor, K2: Tensor) -> Tensor:
  274. r"""Get the Fundamental matrix from Essential and camera matrices.
  275. Uses the method from Hartley/Zisserman 9.6 pag 257 (formula 9.12).
  276. Args:
  277. E_mat: The essential matrix with shape of :math:`(*, 3, 3)`.
  278. K1: The camera matrix from first camera with shape :math:`(*, 3, 3)`.
  279. K2: The camera matrix from second camera with shape :math:`(*, 3, 3)`.
  280. Returns:
  281. The fundamental matrix with shape :math:`(*, 3, 3)`.
  282. """
  283. KORNIA_CHECK_SHAPE(E_mat, ["*", "3", "3"])
  284. KORNIA_CHECK_SHAPE(K1, ["*", "3", "3"])
  285. KORNIA_CHECK_SHAPE(K2, ["*", "3", "3"])
  286. if not len(E_mat.shape[:-2]) == len(K1.shape[:-2]) == len(K2.shape[:-2]):
  287. raise AssertionError
  288. return (safe_inverse_with_mask(K2)[0]).transpose(-2, -1) @ E_mat @ (safe_inverse_with_mask(K1)[0])
  289. # adapted from:
  290. # https://github.com/opencv/opencv_contrib/blob/master/modules/sfm/src/fundamental.cpp#L109
  291. # https://github.com/openMVG/openMVG/blob/160643be515007580086650f2ae7f1a42d32e9fb/src/openMVG/multiview/projection.cpp#L134
  292. def fundamental_from_projections(P1: Tensor, P2: Tensor) -> Tensor:
  293. r"""Get the Fundamental matrix from Projection matrices.
  294. Args:
  295. P1: The projection matrix from first camera with shape :math:`(*, 3, 4)`.
  296. P2: The projection matrix from second camera with shape :math:`(*, 3, 4)`.
  297. Returns:
  298. The fundamental matrix with shape :math:`(*, 3, 3)`.
  299. """
  300. KORNIA_CHECK_SHAPE(P1, ["*", "3", "4"])
  301. KORNIA_CHECK_SHAPE(P2, ["*", "3", "4"])
  302. if P1.shape[:-2] != P2.shape[:-2]:
  303. raise AssertionError
  304. def vstack(x: Tensor, y: Tensor) -> Tensor:
  305. return concatenate([x, y], dim=-2)
  306. input_dtype = P1.dtype
  307. if input_dtype not in (torch.float32, torch.float64):
  308. P1 = P1.to(torch.float32)
  309. P2 = P2.to(torch.float32)
  310. X1 = P1[..., 1:, :]
  311. X2 = vstack(P1[..., 2:3, :], P1[..., 0:1, :])
  312. X3 = P1[..., :2, :]
  313. Y1 = P2[..., 1:, :]
  314. Y2 = vstack(P2[..., 2:3, :], P2[..., 0:1, :])
  315. Y3 = P2[..., :2, :]
  316. X1Y1, X2Y1, X3Y1 = vstack(X1, Y1), vstack(X2, Y1), vstack(X3, Y1)
  317. X1Y2, X2Y2, X3Y2 = vstack(X1, Y2), vstack(X2, Y2), vstack(X3, Y2)
  318. X1Y3, X2Y3, X3Y3 = vstack(X1, Y3), vstack(X2, Y3), vstack(X3, Y3)
  319. F_vec = torch.cat(
  320. [
  321. X1Y1.det().reshape(-1, 1),
  322. X2Y1.det().reshape(-1, 1),
  323. X3Y1.det().reshape(-1, 1),
  324. X1Y2.det().reshape(-1, 1),
  325. X2Y2.det().reshape(-1, 1),
  326. X3Y2.det().reshape(-1, 1),
  327. X1Y3.det().reshape(-1, 1),
  328. X2Y3.det().reshape(-1, 1),
  329. X3Y3.det().reshape(-1, 1),
  330. ],
  331. dim=1,
  332. )
  333. return F_vec.view(*P1.shape[:-2], 3, 3).to(input_dtype)