pnp.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Optional, Tuple
  18. import torch
  19. from torch.linalg import qr as linalg_qr
  20. from kornia.core import arange, ones_like, where, zeros
  21. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SAME_SHAPE, KORNIA_CHECK_SHAPE
  22. from kornia.geometry.conversions import convert_points_to_homogeneous, normalize_points_with_intrinsics
  23. from kornia.geometry.linalg import transform_points
  24. from kornia.utils import eye_like
  25. from kornia.utils.helpers import _torch_linalg_svdvals, _torch_svd_cast
  26. def _mean_isotropic_scale_normalize(points: torch.Tensor, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
  27. r"""Normalize points.
  28. Args:
  29. points : Tensor containing the points to be normalized with shape :math:`(B, N, D)`.
  30. eps : Small value to avoid division by zero error.
  31. Returns:
  32. Tuple containing the normalized points in the shape :math:`(B, N, D)` and the transformation matrix
  33. in the shape :math:`(B, D+1, D+1)`.
  34. """
  35. KORNIA_CHECK_SHAPE(points, ["B", "N", "D"])
  36. x_mean = torch.mean(points, dim=1, keepdim=True) # Bx1xD
  37. scale = (points - x_mean).norm(dim=-1, p=2).mean(dim=-1) # B
  38. D_int = points.shape[-1]
  39. D_float = torch.tensor(points.shape[-1], dtype=torch.float64, device=points.device)
  40. scale = torch.sqrt(D_float) / (scale + eps) # B
  41. transform = eye_like(D_int + 1, points) # (B, D+1, D+1)
  42. idxs = arange(D_int, dtype=torch.int64, device=points.device)
  43. transform[:, idxs, idxs] = transform[:, idxs, idxs] * scale[:, None]
  44. transform[:, idxs, D_int] = transform[:, idxs, D_int] + (-scale[:, None] * x_mean[:, 0, idxs])
  45. points_norm = transform_points(transform, points) # BxNxD
  46. return (points_norm, transform)
  47. def solve_pnp_dlt(
  48. world_points: torch.Tensor,
  49. img_points: torch.Tensor,
  50. intrinsics: torch.Tensor,
  51. weights: Optional[torch.Tensor] = None,
  52. svd_eps: float = 1e-4,
  53. ) -> torch.Tensor:
  54. r"""Attempt to solve the Perspective-n-Point (PnP) problem using Direct Linear Transform (DLT).
  55. Given a batch (where batch size is :math:`B`) of :math:`N` 3D points
  56. (where :math:`N \geq 6`) in the world space, a batch of :math:`N`
  57. corresponding 2D points in the image space and a batch of
  58. intrinsic matrices, this function tries to estimate a batch of
  59. world to camera transformation matrices.
  60. This implementation needs at least 6 points (i.e. :math:`N \geq 6`) to
  61. provide solutions.
  62. This function cannot be used if all the 3D world points (of any element
  63. of the batch) lie on a line or if all the 3D world points (of any element
  64. of the batch) lie on a plane. This function attempts to check for these
  65. conditions and throws an AssertionError if found. Do note that this check
  66. is sensitive to the value of the svd_eps parameter.
  67. Another bad condition occurs when the camera and the points lie on a
  68. twisted cubic. However, this function does not check for this condition.
  69. Args:
  70. world_points : A tensor with shape :math:`(B, N, 3)` representing
  71. the points in the world space.
  72. img_points : A tensor with shape :math:`(B, N, 2)` representing
  73. the points in the image space.
  74. intrinsics : A tensor with shape :math:`(B, 3, 3)` representing
  75. the intrinsic matrices.
  76. weights : A tensor with shape :math:`(B, N)` representing the
  77. weights for each point. If None, all points are considered to be equally important.
  78. svd_eps : A small float value to avoid numerical precision issues.
  79. Returns:
  80. A tensor with shape :math:`(B, 3, 4)` representing the estimated world to
  81. camera transformation matrices (also known as the extrinsic matrices).
  82. Example:
  83. >>> world_points = torch.tensor([[
  84. ... [ 5. , -5. , 0. ], [ 0. , 0. , 1.5],
  85. ... [ 2.5, 3. , 6. ], [ 9. , -2. , 3. ],
  86. ... [-4. , 5. , 2. ], [-5. , 5. , 1. ],
  87. ... ]], dtype=torch.float64)
  88. >>>
  89. >>> img_points = torch.tensor([[
  90. ... [1409.1504, -800.936 ], [ 407.0207, -182.1229],
  91. ... [ 392.7021, 177.9428], [1016.838 , -2.9416],
  92. ... [ -63.1116, 142.9204], [-219.3874, 99.666 ],
  93. ... ]], dtype=torch.float64)
  94. >>>
  95. >>> intrinsics = torch.tensor([[
  96. ... [ 500., 0., 250.],
  97. ... [ 0., 500., 250.],
  98. ... [ 0., 0., 1.],
  99. ... ]], dtype=torch.float64)
  100. >>>
  101. >>> print(world_points.shape, img_points.shape, intrinsics.shape)
  102. torch.Size([1, 6, 3]) torch.Size([1, 6, 2]) torch.Size([1, 3, 3])
  103. >>>
  104. >>> pred_world_to_cam = kornia.geometry.solve_pnp_dlt(world_points, img_points, intrinsics)
  105. >>>
  106. >>> print(pred_world_to_cam.shape)
  107. torch.Size([1, 3, 4])
  108. >>>
  109. >>> pred_world_to_cam
  110. tensor([[[ 0.9392, -0.3432, -0.0130, 1.6734],
  111. [ 0.3390, 0.9324, -0.1254, -4.3634],
  112. [ 0.0552, 0.1134, 0.9920, 3.7785]]], dtype=torch.float64)
  113. """
  114. # This function was implemented based on ideas inspired from multiple references.
  115. # ============
  116. # References:
  117. # ============
  118. # 1. https://team.inria.fr/lagadic/camera_localization/tutorial-pose-dlt-opencv.html
  119. # 2. https://github.com/opencv/opencv/blob/68d15fc62edad980f1ffa15ee478438335f39cc3/modules/calib3d/src/calibration.cpp # noqa: E501
  120. # 3. http://rpg.ifi.uzh.ch/docs/teaching/2020/03_camera_calibration.pdf
  121. # 4. http://www.cs.cmu.edu/~16385/s17/Slides/11.3_Pose_Estimation.pdf
  122. # 5. https://www.ece.mcmaster.ca/~shirani/vision/hartley_ch7.pdf
  123. accepted_dtypes = (torch.float32, torch.float64)
  124. KORNIA_CHECK_IS_TENSOR(world_points)
  125. KORNIA_CHECK_IS_TENSOR(img_points)
  126. KORNIA_CHECK_IS_TENSOR(intrinsics)
  127. KORNIA_CHECK(isinstance(svd_eps, float))
  128. KORNIA_CHECK(world_points.dtype in accepted_dtypes)
  129. KORNIA_CHECK(img_points.dtype in accepted_dtypes)
  130. KORNIA_CHECK(intrinsics.dtype in accepted_dtypes)
  131. KORNIA_CHECK_SHAPE(world_points, ["B", "N", "3"])
  132. KORNIA_CHECK_SHAPE(img_points, ["B", "N", "2"])
  133. KORNIA_CHECK_SHAPE(intrinsics, ["B", "3", "3"])
  134. KORNIA_CHECK_SAME_SHAPE(world_points[:, :, 0], img_points[:, :, 0])
  135. KORNIA_CHECK(world_points.shape[1] >= 6)
  136. if weights is not None:
  137. KORNIA_CHECK_IS_TENSOR(weights)
  138. B, N = world_points.shape[:2]
  139. # Getting normalized world points.
  140. world_points_norm, world_transform_norm = _mean_isotropic_scale_normalize(world_points)
  141. # Checking if world_points_norm (of any element of the batch) has rank = 3. This
  142. # function cannot be used if all world points (of any element of the batch) lie
  143. # on a line or if all world points (of any element of the batch) lie on a plane.
  144. s = _torch_linalg_svdvals(world_points_norm)
  145. if torch.any(s[:, -1] < svd_eps):
  146. raise AssertionError(
  147. "The last singular value of one/more of the elements of the batch is smaller "
  148. f"than {svd_eps}. This function cannot be used if all world_points (of any "
  149. "element of the batch) lie on a line or if all world_points (of any "
  150. "element of the batch) lie on a plane."
  151. )
  152. world_points_norm_h = convert_points_to_homogeneous(world_points_norm)
  153. # Normalizing img_points_inv
  154. img_points_inv = normalize_points_with_intrinsics(img_points, intrinsics)
  155. img_points_norm, img_transform_norm = _mean_isotropic_scale_normalize(img_points_inv)
  156. inv_img_transform_norm = torch.inverse(img_transform_norm)
  157. # Setting up the system (the matrix A in Ax=0)
  158. system = zeros((B, 2 * N, 12), dtype=world_points.dtype, device=world_points.device)
  159. system[:, 0::2, 0:4] = world_points_norm_h
  160. system[:, 1::2, 4:8] = world_points_norm_h
  161. system[:, 0::2, 8:12] = world_points_norm_h * (-1) * img_points_norm[..., 0:1]
  162. system[:, 1::2, 8:12] = world_points_norm_h * (-1) * img_points_norm[..., 1:2]
  163. # Apply weights to the system if provided
  164. if weights is not None:
  165. if weights.shape != (B, N):
  166. raise AssertionError(f"Weights should have shape (B, N). Got {weights.shape}.")
  167. weights_expanded = weights.unsqueeze(-1).repeat(1, 1, 2).view(B, 2 * N, 1)
  168. # Multiply the system matrix by the expanded weights
  169. system = weights_expanded * system
  170. # Getting the solution vectors.
  171. _, _, v = _torch_svd_cast(system)
  172. solution = v[..., -1]
  173. # Reshaping the solution vectors to the correct shape.
  174. solution = solution.reshape(B, 3, 4)
  175. # Creating solution_4x4
  176. solution_4x4 = eye_like(4, solution)
  177. solution_4x4[:, :3, :] = solution
  178. # De-normalizing the solution
  179. intermediate = torch.bmm(solution_4x4, world_transform_norm)
  180. solution = torch.bmm(inv_img_transform_norm, intermediate[:, :3, :])
  181. # We obtained one solution for each element of the batch. We may
  182. # need to multiply each solution with a scalar. This is because
  183. # if x is a solution to Ax=0, then cx is also a solution. We can
  184. # find the required scalars by using the properties of
  185. # rotation matrices. We do this in two parts:
  186. # First, we fix the sign by making sure that the determinant of
  187. # the all the rotation matrices are non-negative (since determinant
  188. # of a rotation matrix should be 1).
  189. det = torch.det(solution[:, :3, :3])
  190. ones = ones_like(det)
  191. sign_fix = where(det < 0, ones * -1, ones)
  192. solution = solution * sign_fix[:, None, None]
  193. # Then, we make sure that norm of the 0th columns of the rotation
  194. # matrices are 1. Do note that the norm of any column of a rotation
  195. # matrix should be 1. Here we use the 0th column to calculate norm_col.
  196. # We then multiply solution with mul_factor.
  197. norm_col = torch.norm(input=solution[:, :3, 0], p=2, dim=1)
  198. mul_factor = (1 / norm_col)[:, None, None]
  199. temp = solution * mul_factor
  200. # To make sure that the rotation matrix would be orthogonal, we apply
  201. # QR decomposition.
  202. ortho, right = linalg_qr(temp[:, :3, :3])
  203. # We may need to fix the signs of the columns of the ortho matrix.
  204. # If right[i, j, j] is negative, then we need to flip the signs of
  205. # the column ortho[i, :, j]. The below code performs the necessary
  206. # operations in a better way.
  207. mask = eye_like(3, ortho)
  208. col_sign_fix = torch.sign(mask * right)
  209. rot_mat = torch.bmm(ortho, col_sign_fix)
  210. # Preparing the final output.
  211. pred_world_to_cam = torch.cat([rot_mat, temp[:, :3, 3:4]], dim=-1)
  212. # TODO: Implement algorithm to refine the solution.
  213. return pred_world_to_cam