essential.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module containing functionalities for the Essential matrix."""
  18. from typing import Optional, Tuple
  19. import torch
  20. from kornia.core import eye, ones_like, stack, where, zeros
  21. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SAME_SHAPE, KORNIA_CHECK_SHAPE
  22. from kornia.geometry import solvers
  23. from kornia.utils import eye_like, vec_like
  24. from kornia.utils.helpers import _torch_solve_cast, _torch_svd_cast
  25. from .numeric import cross_product_matrix, matrix_cofactor_tensor
  26. from .projection import depth_from_point, projection_from_KRt
  27. from .triangulation import triangulate_points
# Names exported as the public API of this module.
__all__ = [
    "decompose_essential_matrix",
    "decompose_essential_matrix_no_svd",
    "essential_from_Rt",
    "essential_from_fundamental",
    "find_essential",
    "motion_from_essential",
    "motion_from_essential_choose_solution",
    "relative_camera_motion",
]
  38. def run_5point(points1: torch.Tensor, points2: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
  39. r"""Compute the essential matrix using the 5-point algorithm from Nister.
  40. The linear system is solved by Nister's 5-point algorithm [@nister2004efficient],
  41. and the solver implemented referred to [@barath2020magsac++][@wei2023generalized][@wang2023vggsfm].
  42. Args:
  43. points1: A set of carlibrated points in the first image with a tensor shape :math:`(B, N, 2), N>=8`.
  44. points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=8`.
  45. weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`.
  46. Returns:
  47. the computed essential matrix with shape :math:`(B, 3, 3)`.
  48. """
  49. KORNIA_CHECK_SHAPE(points1, ["B", "N", "2"])
  50. KORNIA_CHECK_SAME_SHAPE(points1, points2)
  51. KORNIA_CHECK(points1.shape[1] >= 5, "Number of points should be >=5")
  52. if weights is not None:
  53. KORNIA_CHECK_SAME_SHAPE(points1[:, :, 0], weights)
  54. batch_size, _, _ = points1.shape
  55. x1, y1 = torch.chunk(points1, dim=-1, chunks=2) # Bx1xN
  56. x2, y2 = torch.chunk(points2, dim=-1, chunks=2) # Bx1xN
  57. ones = ones_like(x1)
  58. # build the equation system and find the null space.
  59. # https://www.cc.gatech.edu/~afb/classes/CS4495-Fall2013/slides/CS4495-09-TwoViews-2.pdf
  60. # [x * x', x * y', x, y * x', y * y', y, x', y', 1]
  61. # BxNx9
  62. X = torch.cat([x1 * x2, x1 * y2, x1, y1 * x2, y1 * y2, y1, x2, y2, ones], dim=-1)
  63. # apply the weights to the linear system
  64. if weights is None:
  65. X = X.transpose(-2, -1) @ X
  66. else:
  67. w_diag = torch.diag_embed(weights)
  68. X = X.transpose(-2, -1) @ w_diag @ X
  69. # use Nister's 5PC to solve essential matrix
  70. E_Nister = null_to_Nister_solution(X, batch_size)
  71. return E_Nister
  72. def fun_select(null_mat: torch.Tensor, i: int, j: int, ratio: int = 3) -> torch.Tensor:
  73. return null_mat[:, ratio * j + i]
  74. def null_to_Nister_solution(X: torch.Tensor, batch_size: int) -> torch.Tensor:
  75. r"""Use Nister's 5PC to solve essential matrix.
  76. The linear system is solved by Nister's 5-point algorithm [@nister2004efficient],
  77. and the solver implemented referred to [@barath2020magsac++][@wei2023generalized][@wang2023vggsfm].
  78. Args:
  79. X: Coefficients for the null space :math:`(B, N, 2), N>=8`.
  80. batch_size: batcs size of the input, the number of image pairs :math:`B`.
  81. Returns:
  82. the computed essential matrix with shape :math:`(B, 3, 3)`.
  83. Note that the returned E matrices should be the same batch size with the input.
  84. """
  85. # compute eigenvectors and retrieve the one with the smallest eigenvalue, using SVD
  86. # turn off the grad check due to the unstable gradients from SVD.
  87. # several close to zero values of eigenvalues.
  88. _, _, V = _torch_svd_cast(X) # torch.svd
  89. null_ = V[:, :, -4:] # the last four rows
  90. nullSpace = V.transpose(-1, -2)[:, -4:, :]
  91. coeffs = zeros(batch_size, 10, 20, device=null_.device, dtype=null_.dtype)
  92. d = zeros(batch_size, 60, device=null_.device, dtype=null_.dtype)
  93. # Determinant constraint
  94. coeffs[:, 9] = (
  95. solvers.multiply_deg_two_one_poly(
  96. solvers.multiply_deg_one_poly(fun_select(null_, 0, 1), fun_select(null_, 1, 2))
  97. - solvers.multiply_deg_one_poly(fun_select(null_, 0, 2), fun_select(null_, 1, 1)),
  98. fun_select(null_, 2, 0),
  99. )
  100. + solvers.multiply_deg_two_one_poly(
  101. solvers.multiply_deg_one_poly(fun_select(null_, 0, 2), fun_select(null_, 1, 0))
  102. - solvers.multiply_deg_one_poly(fun_select(null_, 0, 0), fun_select(null_, 1, 2)),
  103. fun_select(null_, 2, 1),
  104. )
  105. + solvers.multiply_deg_two_one_poly(
  106. solvers.multiply_deg_one_poly(fun_select(null_, 0, 0), fun_select(null_, 1, 1))
  107. - solvers.multiply_deg_one_poly(fun_select(null_, 0, 1), fun_select(null_, 1, 0)),
  108. fun_select(null_, 2, 2),
  109. )
  110. )
  111. indices = torch.tensor([[0, 10, 20], [10, 40, 30], [20, 30, 50]])
  112. # Compute EE^T (Eqn. 20 in the paper)
  113. for i in range(3):
  114. for j in range(3):
  115. d[:, indices[i, j] : indices[i, j] + 10] = (
  116. solvers.multiply_deg_one_poly(fun_select(null_, i, 0), fun_select(null_, j, 0))
  117. + solvers.multiply_deg_one_poly(fun_select(null_, i, 1), fun_select(null_, j, 1))
  118. + solvers.multiply_deg_one_poly(fun_select(null_, i, 2), fun_select(null_, j, 2))
  119. )
  120. for i in range(10):
  121. t = 0.5 * (d[:, indices[0, 0] + i] + d[:, indices[1, 1] + i] + d[:, indices[2, 2] + i])
  122. d[:, indices[0, 0] + i] -= t
  123. d[:, indices[1, 1] + i] -= t
  124. d[:, indices[2, 2] + i] -= t
  125. cnt = 0
  126. for i in range(3):
  127. for j in range(3):
  128. row = (
  129. solvers.multiply_deg_two_one_poly(d[:, indices[i, 0] : indices[i, 0] + 10], fun_select(null_, 0, j))
  130. + solvers.multiply_deg_two_one_poly(d[:, indices[i, 1] : indices[i, 1] + 10], fun_select(null_, 1, j))
  131. + solvers.multiply_deg_two_one_poly(d[:, indices[i, 2] : indices[i, 2] + 10], fun_select(null_, 2, j))
  132. )
  133. coeffs[:, cnt] = row
  134. cnt += 1
  135. b = coeffs[:, :, 10:]
  136. singular_filter = torch.linalg.matrix_rank(coeffs[:, :, :10]) >= torch.max(
  137. torch.linalg.matrix_rank(coeffs), ones_like(torch.linalg.matrix_rank(coeffs[:, :, :10])) * 10
  138. )
  139. # check if there is no solution
  140. if singular_filter.sum() == 0:
  141. return torch.eye(3, dtype=coeffs.dtype, device=coeffs.device)[None].expand(batch_size, 10, -1, -1).clone()
  142. eliminated_mat = _torch_solve_cast(coeffs[singular_filter, :, :10], b[singular_filter])
  143. coeffs_ = torch.cat((coeffs[singular_filter, :, :10], eliminated_mat), dim=-1)
  144. # check the batch size after singular filter, for batch operation afterwards
  145. batch_size_filtered = coeffs_.shape[0]
  146. A = zeros(coeffs_.shape[0], 3, 13, device=coeffs_.device, dtype=coeffs_.dtype)
  147. for i in range(3):
  148. A[:, i, 0] = 0.0
  149. A[:, i : i + 1, 1:4] = coeffs_[:, 4 + 2 * i : 5 + 2 * i, 10:13]
  150. A[:, i : i + 1, 0:3] -= coeffs_[:, 5 + 2 * i : 6 + 2 * i, 10:13]
  151. A[:, i, 4] = 0.0
  152. A[:, i : i + 1, 5:8] = coeffs_[:, 4 + 2 * i : 5 + 2 * i, 13:16]
  153. A[:, i : i + 1, 4:7] -= coeffs_[:, 5 + 2 * i : 6 + 2 * i, 13:16]
  154. A[:, i, 8] = 0.0
  155. A[:, i : i + 1, 9:13] = coeffs_[:, 4 + 2 * i : 5 + 2 * i, 16:20]
  156. A[:, i : i + 1, 8:12] -= coeffs_[:, 5 + 2 * i : 6 + 2 * i, 16:20]
  157. # Bx11
  158. cs = solvers.determinant_to_polynomial(A)
  159. # A: Bx3x13
  160. # nullSpace: Bx4x9
  161. # companion matrices to solve the polynomial, in batch
  162. C = zeros((batch_size_filtered, 10, 10), device=cs.device, dtype=cs.dtype)
  163. eye_mat = eye(C[0, 0:-1, 0:-1].shape[0], device=cs.device, dtype=cs.dtype)
  164. C[:, 0:-1, 1:] = eye_mat
  165. cs_de = cs[:, -1].unsqueeze(-1)
  166. cs_de = torch.where(cs_de == 0, torch.tensor(1e-8, dtype=cs_de.dtype), cs_de)
  167. C[:, -1, :] = -cs[:, :-1] / cs_de
  168. roots = torch.real(torch.linalg.eigvals(C))
  169. roots_unsqu = roots.unsqueeze(1)
  170. Bs = stack(
  171. (
  172. A[:, :3, :1] * (roots_unsqu**3)
  173. + A[:, :3, 1:2] * roots_unsqu.square()
  174. + A[:, 0:3, 2:3] * roots_unsqu
  175. + A[:, 0:3, 3:4],
  176. A[:, 0:3, 4:5] * (roots_unsqu**3)
  177. + A[:, 0:3, 5:6] * roots_unsqu.square()
  178. + A[:, 0:3, 6:7] * roots_unsqu
  179. + A[:, 0:3, 7:8],
  180. ),
  181. dim=1,
  182. )
  183. Bs = Bs.transpose(1, -1)
  184. bs = (
  185. (
  186. A[:, 0:3, 8:9] * (roots_unsqu**4)
  187. + A[:, 0:3, 9:10] * (roots_unsqu**3)
  188. + A[:, 0:3, 10:11] * roots_unsqu.square()
  189. + A[:, 0:3, 11:12] * roots_unsqu
  190. + A[:, 0:3, 12:13]
  191. )
  192. .transpose(1, 2)
  193. .unsqueeze(-1)
  194. )
  195. xzs = torch.matmul(torch.linalg.inv(Bs[:, :, 0:2, 0:2]), bs[:, :, 0:2])
  196. mask = (abs(Bs[:, 2].unsqueeze(1) @ xzs - bs[:, 2].unsqueeze(1)) > 1e-3).flatten()
  197. # mask: bx10x1x1
  198. mask = (
  199. abs(torch.matmul(Bs[:, :, 2, :].unsqueeze(2), xzs) - bs[:, :, 2, :].unsqueeze(2)) > 1e-3
  200. ) # .flatten(start_dim=1)
  201. # bx10
  202. mask = mask.squeeze(3).squeeze(2)
  203. if torch.any(mask):
  204. q_batch, r_batch = torch.linalg.qr(Bs[mask])
  205. xyz_to_feed = torch.linalg.solve(r_batch, torch.matmul(q_batch.transpose(-1, -2), bs[mask]))
  206. xzs[mask] = xyz_to_feed
  207. nullSpace_filtered = nullSpace[singular_filter]
  208. Es = (
  209. nullSpace_filtered[:, 0:1] * (-xzs[:, :, 0])
  210. + nullSpace_filtered[:, 1:2] * (-xzs[:, :, 1])
  211. + nullSpace_filtered[:, 2:3] * roots.unsqueeze(-1)
  212. + nullSpace_filtered[:, 3:4]
  213. )
  214. inv = 1.0 / torch.sqrt((-xzs[:, :, 0]) ** 2 + (-xzs[:, :, 1]) ** 2 + roots.unsqueeze(-1) ** 2 + 1.0)
  215. Es *= inv
  216. Es = Es.view(batch_size_filtered, -1, 3, 3).transpose(-1, -2)
  217. # make sure the returned batch size equals to that of inputs
  218. E_return = torch.eye(3, dtype=Es.dtype, device=Es.device)[None].expand(batch_size, 10, -1, -1).clone()
  219. E_return[singular_filter] = Es
  220. return E_return
  221. def essential_from_fundamental(F_mat: torch.Tensor, K1: torch.Tensor, K2: torch.Tensor) -> torch.Tensor:
  222. r"""Get Essential matrix from Fundamental and Camera matrices.
  223. Uses the method from Hartley/Zisserman 9.6 pag 257 (formula 9.12).
  224. Args:
  225. F_mat: The fundamental matrix with shape of :math:`(*, 3, 3)`.
  226. K1: The camera matrix from first camera with shape :math:`(*, 3, 3)`.
  227. K2: The camera matrix from second camera with shape :math:`(*, 3, 3)`.
  228. Returns:
  229. The essential matrix with shape :math:`(*, 3, 3)`.
  230. """
  231. KORNIA_CHECK_SHAPE(F_mat, ["*", "3", "3"])
  232. KORNIA_CHECK_SHAPE(K1, ["*", "3", "3"])
  233. KORNIA_CHECK_SHAPE(K2, ["*", "3", "3"])
  234. return K2.transpose(-2, -1) @ F_mat @ K1
  235. def decompose_essential_matrix(E_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  236. r"""Decompose an essential matrix to possible rotations and translation.
  237. This function decomposes the essential matrix E using svd decomposition [96]
  238. and give the possible solutions: :math:`R1, R2, t`.
  239. Args:
  240. E_mat: The essential matrix in the form of :math:`(*, 3, 3)`.
  241. Returns:
  242. A tuple containing the first and second possible rotation matrices and the translation vector.
  243. The shape of the tensors with be same input :math:`[(*, 3, 3), (*, 3, 3), (*, 3, 1)]`.
  244. """
  245. KORNIA_CHECK_SHAPE(E_mat, ["*", "3", "3"])
  246. # decompose matrix by its singular values
  247. U, _, V = _torch_svd_cast(E_mat)
  248. Vt = V.transpose(-2, -1)
  249. mask = ones_like(E_mat)
  250. mask[..., -1:] *= -1.0 # fill last column with negative values
  251. maskt = mask.transpose(-2, -1)
  252. # avoid singularities
  253. U = where((torch.det(U) < 0.0)[..., None, None], U * mask, U)
  254. Vt = where((torch.det(Vt) < 0.0)[..., None, None], Vt * maskt, Vt)
  255. W = cross_product_matrix(torch.tensor([[0.0, 0.0, 1.0]]).type_as(E_mat))
  256. W[..., 2, 2] += 1.0
  257. # reconstruct rotations and retrieve translation vector
  258. U_W_Vt = U @ W @ Vt
  259. U_Wt_Vt = U @ W.transpose(-2, -1) @ Vt
  260. # return values
  261. R1 = U_W_Vt
  262. R2 = U_Wt_Vt
  263. T = U[..., -1:]
  264. return (R1, R2, T)
  265. def decompose_essential_matrix_no_svd(E_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  266. r"""Decompose the essential matrix to rotation and translation.
  267. Recover rotation and translation from essential matrices without SVD
  268. reference: Horn, Berthold KP. Recovering baseline and orientation from essential matrix[J].
  269. J. Opt. Soc. Am, 1990, 110.
  270. Args:
  271. E_mat: The essential matrix in the form of :math:`(*, 3, 3)`.
  272. Returns:
  273. A tuple containing the first and second possible rotation matrices and the translation vector.
  274. The shape of the tensors with be same input :math:`[(*, 3, 3), (*, 3, 3), (*, 3, 1)]`.
  275. """
  276. KORNIA_CHECK_SHAPE(E_mat, ["*", "3", "3"])
  277. if len(E_mat.shape) != 3:
  278. E_mat = E_mat.view(-1, 3, 3)
  279. B = E_mat.shape[0]
  280. # Eq.18, choose the largest of the three possible pairwise cross-products
  281. e1, e2, e3 = E_mat[..., 0], E_mat[..., 1], E_mat[..., 2]
  282. # sqrt(1/2 trace(EE^T)), B
  283. scale_factor = torch.sqrt(0.5 * torch.diagonal(E_mat @ E_mat.transpose(-1, -2), dim1=-1, dim2=-2).sum(-1))
  284. # B, 3, 3
  285. cross_products = torch.stack(
  286. [torch.linalg.cross(e1, e2, dim=-1), torch.linalg.cross(e2, e3, dim=-1), torch.linalg.cross(e3, e1, dim=-1)],
  287. dim=1,
  288. )
  289. # B, 3, 1
  290. norms = torch.norm(cross_products, dim=-1, keepdim=True)
  291. # B, to select which b1
  292. largest = torch.argmax(norms, dim=-2)
  293. # B, 3, 3
  294. e_cross_products = scale_factor[:, None, None] * cross_products / norms
  295. # broadcast the index
  296. index_expanded = largest.unsqueeze(-1).expand(-1, -1, e_cross_products.size(-1))
  297. # slice at dim=1, select for each batch one b (e1*e2 or e2*e3 or e3*e1), B, 1, 3
  298. b1 = torch.gather(e_cross_products, dim=1, index=index_expanded).squeeze(1)
  299. # normalization
  300. b1_ = b1 / torch.norm(b1, dim=-1, keepdim=True)
  301. # skew-symmetric matrix
  302. B1 = torch.zeros((B, 3, 3), device=E_mat.device, dtype=E_mat.dtype)
  303. t0, t1, t2 = b1[:, 0], b1[:, 1], b1[:, 2]
  304. B1[:, 0, 1], B1[:, 1, 0] = -t2, t2
  305. B1[:, 0, 2], B1[:, 2, 0] = t1, -t1
  306. B1[:, 1, 2], B1[:, 2, 1] = -t0, t0
  307. # the second translation and rotation
  308. B2 = -B1
  309. b2 = -b1
  310. # Eq.24, recover R
  311. # (bb)R = Cofactors(E)^T - BE
  312. R1 = (matrix_cofactor_tensor(E_mat) - B1 @ E_mat) / (b1 * b1).sum().unsqueeze(-1)
  313. R2 = (matrix_cofactor_tensor(E_mat) - B2 @ E_mat) / (b2 * b2).sum().unsqueeze(-1)
  314. return (R1, R2, b1_.unsqueeze(-1))
  315. def essential_from_Rt(R1: torch.Tensor, t1: torch.Tensor, R2: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
  316. r"""Get the Essential matrix from Camera motion (Rs and ts).
  317. Reference: Hartley/Zisserman 9.6 pag 257 (formula 9.12)
  318. Args:
  319. R1: The first camera rotation matrix with shape :math:`(*, 3, 3)`.
  320. t1: The first camera translation vector with shape :math:`(*, 3, 1)`.
  321. R2: The second camera rotation matrix with shape :math:`(*, 3, 3)`.
  322. t2: The second camera translation vector with shape :math:`(*, 3, 1)`.
  323. Returns:
  324. The Essential matrix with the shape :math:`(*, 3, 3)`.
  325. """
  326. KORNIA_CHECK_SHAPE(R1, ["*", "3", "3"])
  327. KORNIA_CHECK_SHAPE(R2, ["*", "3", "3"])
  328. KORNIA_CHECK_SHAPE(t1, ["*", "3", "1"])
  329. KORNIA_CHECK_SHAPE(t2, ["*", "3", "1"])
  330. # first compute the camera relative motion
  331. R, t = relative_camera_motion(R1, t1, R2, t2)
  332. # get the cross product from relative translation vector
  333. Tx = cross_product_matrix(t[..., 0])
  334. return Tx @ R
  335. def motion_from_essential(E_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  336. r"""Get Motion (R's and t's ) from Essential matrix.
  337. Computes and return four possible poses exist for the decomposition of the Essential
  338. matrix. The possible solutions are :math:`[R1,t], [R1,-t], [R2,t], [R2,-t]`.
  339. Args:
  340. E_mat: The essential matrix in the form of :math:`(*, 3, 3)`.
  341. Returns:
  342. The rotation and translation containing the four possible combination for the retrieved motion.
  343. The tuple is as following :math:`[(*, 4, 3, 3), (*, 4, 3, 1)]`.
  344. """
  345. KORNIA_CHECK_SHAPE(E_mat, ["*", "3", "3"])
  346. # decompose the essential matrix by its possible poses
  347. R1, R2, t = decompose_essential_matrix(E_mat)
  348. # compbine and returns the four possible solutions
  349. Rs = stack([R1, R1, R2, R2], dim=-3)
  350. Ts = stack([t, -t, t, -t], dim=-3)
  351. return Rs, Ts
  352. def motion_from_essential_choose_solution(
  353. E_mat: torch.Tensor,
  354. K1: torch.Tensor,
  355. K2: torch.Tensor,
  356. x1: torch.Tensor,
  357. x2: torch.Tensor,
  358. mask: Optional[torch.Tensor] = None,
  359. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  360. r"""Recover the relative camera rotation and the translation from an estimated essential matrix.
  361. The method checks the corresponding points in two images and also returns the triangulated
  362. 3d points. Internally uses :py:meth:`~kornia.geometry.epipolar.decompose_essential_matrix` and then chooses
  363. the best solution based on the combination that gives more 3d points in front of the camera plane from
  364. :py:meth:`~kornia.geometry.epipolar.triangulate_points`.
  365. Args:
  366. E_mat: The essential matrix in the form of :math:`(*, 3, 3)`.
  367. K1: The camera matrix from first camera with shape :math:`(*, 3, 3)`.
  368. K2: The camera matrix from second camera with shape :math:`(*, 3, 3)`.
  369. x1: The set of points seen from the first camera frame in the camera plane
  370. coordinates with shape :math:`(*, N, 2)`.
  371. x2: The set of points seen from the first camera frame in the camera plane
  372. coordinates with shape :math:`(*, N, 2)`.
  373. mask: A boolean mask which can be used to exclude some points from choosing
  374. the best solution. This is useful for using this function with sets of points of
  375. different cardinality (for instance after filtering with RANSAC) while keeping batch
  376. semantics. Mask is of shape :math:`(*, N)`.
  377. Returns:
  378. The rotation and translation plus the 3d triangulated points.
  379. The tuple is as following :math:`[(*, 3, 3), (*, 3, 1), (*, N, 3)]`.
  380. """
  381. KORNIA_CHECK_SHAPE(E_mat, ["*", "3", "3"])
  382. KORNIA_CHECK_SHAPE(K1, ["*", "3", "3"])
  383. KORNIA_CHECK_SHAPE(K2, ["*", "3", "3"])
  384. KORNIA_CHECK_SHAPE(x1, ["*", "N", "2"])
  385. KORNIA_CHECK_SHAPE(x2, ["*", "N", "2"])
  386. KORNIA_CHECK(len(E_mat.shape[:-2]) == len(K1.shape[:-2]) == len(K2.shape[:-2]))
  387. if mask is not None:
  388. KORNIA_CHECK_SHAPE(mask, ["*", "N"])
  389. KORNIA_CHECK(mask.shape == x1.shape[:-1])
  390. unbatched = len(E_mat.shape) == 2
  391. if unbatched:
  392. # add a leading batch dimension. We will remove it at the end, before
  393. # returning the results
  394. E_mat = E_mat[None]
  395. K1 = K1[None]
  396. K2 = K2[None]
  397. x1 = x1[None]
  398. x2 = x2[None]
  399. if mask is not None:
  400. mask = mask[None]
  401. # compute four possible pose solutions
  402. Rs, ts = motion_from_essential(E_mat)
  403. # set reference view pose and compute projection matrix
  404. R1 = eye_like(3, E_mat) # Bx3x3
  405. t1 = vec_like(3, E_mat) # Bx3x1
  406. # compute the projection matrices for first camera
  407. R1 = R1[:, None].expand(-1, 4, -1, -1)
  408. t1 = t1[:, None].expand(-1, 4, -1, -1)
  409. K1 = K1[:, None].expand(-1, 4, -1, -1)
  410. P1 = projection_from_KRt(K1, R1, t1) # 1x4x4x4
  411. # compute the projection matrices for second camera
  412. R2 = Rs
  413. t2 = ts
  414. K2 = K2[:, None].expand(-1, 4, -1, -1)
  415. P2 = projection_from_KRt(K2, R2, t2) # Bx4x4x4
  416. # triangulate the points
  417. x1 = x1[:, None].expand(-1, 4, -1, -1)
  418. x2 = x2[:, None].expand(-1, 4, -1, -1)
  419. X = triangulate_points(P1, P2, x1, x2) # Bx4xNx3
  420. # project points and compute their depth values
  421. d1 = depth_from_point(R1, t1, X)
  422. d2 = depth_from_point(R2, t2, X)
  423. # verify the point values that have a positive depth value
  424. depth_mask = (d1 > 0.0) & (d2 > 0.0)
  425. if mask is not None:
  426. depth_mask &= mask.unsqueeze(1)
  427. mask_indices = torch.max(depth_mask.sum(-1), dim=-1, keepdim=True)[1]
  428. # get pose and points 3d and return
  429. R_out = Rs[:, mask_indices][:, 0, 0]
  430. t_out = ts[:, mask_indices][:, 0, 0]
  431. points3d_out = X[:, mask_indices][:, 0, 0]
  432. if unbatched:
  433. R_out = R_out[0]
  434. t_out = t_out[0]
  435. points3d_out = points3d_out[0]
  436. return R_out, t_out, points3d_out
  437. def relative_camera_motion(
  438. R1: torch.Tensor, t1: torch.Tensor, R2: torch.Tensor, t2: torch.Tensor
  439. ) -> Tuple[torch.Tensor, torch.Tensor]:
  440. r"""Compute the relative camera motion between two cameras.
  441. Given the motion parameters of two cameras, computes the motion parameters of the second
  442. one assuming the first one to be at the origin. If :math:`T1` and :math:`T2` are the camera motions,
  443. the computed relative motion is :math:`T = T_{2}T^{-1}_{1}`.
  444. Args:
  445. R1: The first camera rotation matrix with shape :math:`(*, 3, 3)`.
  446. t1: The first camera translation vector with shape :math:`(*, 3, 1)`.
  447. R2: The second camera rotation matrix with shape :math:`(*, 3, 3)`.
  448. t2: The second camera translation vector with shape :math:`(*, 3, 1)`.
  449. Returns:
  450. A tuple with the relative rotation matrix and
  451. translation vector with the shape of :math:`[(*, 3, 3), (*, 3, 1)]`.
  452. """
  453. KORNIA_CHECK_SHAPE(R1, ["*", "3", "3"])
  454. KORNIA_CHECK_SHAPE(R2, ["*", "3", "3"])
  455. KORNIA_CHECK_SHAPE(t1, ["*", "3", "1"])
  456. KORNIA_CHECK_SHAPE(t2, ["*", "3", "1"])
  457. # compute first the relative rotation
  458. R = R2 @ R1.transpose(-2, -1)
  459. # compute the relative translation vector
  460. t = t2 - R @ t1
  461. return R, t
  462. def find_essential(
  463. points1: torch.Tensor, points2: torch.Tensor, weights: Optional[torch.Tensor] = None
  464. ) -> torch.Tensor:
  465. r"""Find essential matrices.
  466. Args:
  467. points1: A set of points in the first image with a tensor shape :math:`(B, N, 2), N>=5`.
  468. points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=5`.
  469. weights: Tensor containing the weights per point correspondence with a shape of :math:`(5, N)`.
  470. Returns:
  471. the computed essential matrices with shape :math:`(B, 10, 3, 3)`.
  472. Note that all possible solutions are returned, i.e., 10 essential matrices for each image pair.
  473. To choose the best one out of 10, try to check the one with the lowest Sampson distance.
  474. """
  475. E = run_5point(points1, points2, weights).to(points1.dtype)
  476. return E