projection.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module for image projections."""
  18. from typing import Tuple, Union
  19. import torch
  20. from torch.linalg import qr as linalg_qr
  21. from kornia.core import Tensor, concatenate, pad, stack
  22. from kornia.core.check import KORNIA_CHECK_SHAPE
  23. from kornia.utils import eye_like, vec_like
  24. from kornia.utils.helpers import _torch_svd_cast
  25. from .numeric import cross_product_matrix
  26. def intrinsics_like(focal: float, input: Tensor) -> Tensor:
  27. r"""Return a 3x3 intrinsics matrix, with same size as the input.
  28. The center of projection will be based in the input image size.
  29. Args:
  30. focal: the focal length for the camera matrix.
  31. input: image tensor that will determine the batch size and image height
  32. and width. It is assumed to be a tensor in the shape of :math:`(B, C, H, W)`.
  33. Returns:
  34. The camera matrix with the shape of :math:`(B, 3, 3)`.
  35. """
  36. if len(input.shape) != 4:
  37. raise AssertionError(input.shape)
  38. if focal <= 0:
  39. raise AssertionError(focal)
  40. _, _, H, W = input.shape
  41. intrinsics = eye_like(3, input)
  42. intrinsics[..., 0, 0] *= focal
  43. intrinsics[..., 1, 1] *= focal
  44. intrinsics[..., 0, 2] += 1.0 * W / 2
  45. intrinsics[..., 1, 2] += 1.0 * H / 2
  46. return intrinsics
  47. def random_intrinsics(low: Union[float, Tensor], high: Union[float, Tensor]) -> Tensor:
  48. r"""Generate a random camera matrix based on a given uniform distribution.
  49. Args:
  50. low: lower range (inclusive).
  51. high: upper range (exclusive).
  52. Returns:
  53. the random camera matrix with the shape of :math:`(1, 3, 3)`.
  54. """
  55. sampler = torch.distributions.Uniform(low, high)
  56. params = sampler.sample((4,))
  57. fx, fy, cx, cy = params[0], params[1], params[2], params[3]
  58. camera_matrix = torch.tensor([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=fx.dtype, device=fx.device)
  59. return camera_matrix.unsqueeze(0)
  60. def scale_intrinsics(camera_matrix: Tensor, scale_factor: Union[float, Tensor]) -> Tensor:
  61. r"""Scale a camera matrix containing the intrinsics.
  62. Applies the scaling factor to the focal length and center of projection.
  63. Args:
  64. camera_matrix: the camera calibration matrix containing the intrinsic
  65. parameters. The expected shape for the tensor is :math:`(B, 3, 3)`.
  66. scale_factor: the scaling factor to be applied.
  67. Returns:
  68. The scaled camera matrix with shame shape as input :math:`(B, 3, 3)`.
  69. """
  70. K_scale = camera_matrix.clone()
  71. K_scale[..., 0, 0] *= scale_factor
  72. K_scale[..., 1, 1] *= scale_factor
  73. K_scale[..., 0, 2] *= scale_factor
  74. K_scale[..., 1, 2] *= scale_factor
  75. return K_scale
  76. def projection_from_KRt(K: Tensor, R: Tensor, t: Tensor) -> Tensor:
  77. r"""Get the projection matrix P from K, R and t.
  78. This function estimate the projection matrix by solving the following equation: :math:`P = K * [R|t]`.
  79. Args:
  80. K: the camera matrix with the intrinsics with shape :math:`(B, 3, 3)`.
  81. R: The rotation matrix with shape :math:`(B, 3, 3)`.
  82. t: The translation vector with shape :math:`(B, 3, 1)`.
  83. Returns:
  84. The projection matrix P with shape :math:`(B, 4, 4)`.
  85. """
  86. KORNIA_CHECK_SHAPE(K, ["*", "3", "3"])
  87. KORNIA_CHECK_SHAPE(R, ["*", "3", "3"])
  88. KORNIA_CHECK_SHAPE(t, ["*", "3", "1"])
  89. if not len(K.shape) == len(R.shape) == len(t.shape):
  90. raise AssertionError
  91. Rt = concatenate([R, t], dim=-1) # 3x4
  92. Rt_h = pad(Rt, [0, 0, 0, 1], "constant", 0.0) # 4x4
  93. Rt_h[..., -1, -1] += 1.0
  94. K_h = pad(K, [0, 1, 0, 1], "constant", 0.0) # 4x4
  95. K_h[..., -1, -1] += 1.0
  96. return K @ Rt
  97. def KRt_from_projection(P: Tensor, eps: float = 1e-6) -> Tuple[Tensor, Tensor, Tensor]:
  98. r"""Decompose the Projection matrix into Camera-Matrix, Rotation Matrix and Translation vector.
  99. Args:
  100. P: the projection matrix with shape :math:`(B, 3, 4)`.
  101. eps: epsilon for numerical stability.
  102. Returns:
  103. - The Camera matrix with shape :math:`(B, 3, 3)`.
  104. - The Rotation matrix with shape :math:`(B, 3, 3)`.
  105. - The Translation vector with shape :math:`(B, 3)`.
  106. """
  107. KORNIA_CHECK_SHAPE(P, ["*", "3", "4"])
  108. submat_3x3 = P[:, 0:3, 0:3]
  109. last_column = P[:, 0:3, 3].unsqueeze(-1)
  110. # Trick to turn QR-decomposition into RQ-decomposition
  111. reverse = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=P.device, dtype=P.dtype).unsqueeze(0)
  112. submat_3x3 = torch.matmul(reverse, submat_3x3).permute(0, 2, 1)
  113. ortho_mat, upper_mat = linalg_qr(submat_3x3)
  114. ortho_mat = torch.matmul(reverse, ortho_mat.permute(0, 2, 1))
  115. upper_mat = torch.matmul(reverse, torch.matmul(upper_mat.permute(0, 2, 1), reverse))
  116. # Turning the `upper_mat's` diagonal elements to positive.
  117. diagonals = torch.diagonal(upper_mat, dim1=-2, dim2=-1) + eps
  118. signs = torch.sign(diagonals)
  119. signs_mat = torch.diag_embed(signs)
  120. K = torch.matmul(upper_mat, signs_mat)
  121. R = torch.matmul(signs_mat, ortho_mat)
  122. t = torch.matmul(torch.inverse(K), last_column)
  123. return K, R, t
  124. def depth_from_point(R: Tensor, t: Tensor, X: Tensor) -> Tensor:
  125. r"""Return the depth of a point transformed by a rigid transform.
  126. Args:
  127. R: The rotation matrix with shape :math:`(*, 3, 3)`.
  128. t: The translation vector with shape :math:`(*, 3, 1)`.
  129. X: The 3d points with shape :math:`(*, 3)`.
  130. Returns:
  131. The depth value per point with shape :math:`(*, 1)`.
  132. """
  133. X_tmp = R @ X.transpose(-2, -1)
  134. X_out = X_tmp[..., 2, :] + t[..., 2, :]
  135. return X_out
  136. # adapted from:
  137. # https://github.com/opencv/opencv_contrib/blob/master/modules/sfm/src/fundamental.cpp#L61
  138. # https://github.com/mapillary/OpenSfM/blob/master/opensfm/multiview.py#L14
  139. def _nullspace(A: Tensor) -> Tuple[Tensor, Tensor]:
  140. """Compute the null space of A.
  141. Return the smallest singular value and the corresponding vector.
  142. """
  143. _, s, v = _torch_svd_cast(A)
  144. return s[..., -1], v[..., -1]
  145. def projections_from_fundamental(F_mat: Tensor) -> Tensor:
  146. r"""Get the projection matrices from the Fundamental Matrix.
  147. Args:
  148. F_mat: the fundamental matrix with the shape :math:`(B, 3, 3)`.
  149. Returns:
  150. The projection matrices with shape :math:`(B, 3, 4, 2)`.
  151. """
  152. KORNIA_CHECK_SHAPE(F_mat, ["*", "3", "3"])
  153. R1 = eye_like(3, F_mat) # Bx3x3
  154. t1 = vec_like(3, F_mat) # Bx3
  155. Ft_mat = F_mat.transpose(-2, -1)
  156. _, e2 = _nullspace(Ft_mat)
  157. R2 = cross_product_matrix(e2) @ F_mat # Bx3x3
  158. t2 = e2[..., :, None] # Bx3x1
  159. P1 = torch.cat([R1, t1], dim=-1) # Bx3x4
  160. P2 = torch.cat([R2, t2], dim=-1) # Bx3x4
  161. return stack([P1, P2], dim=-1)