stereo.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any
  18. import torch
  19. from kornia.core import Tensor, stack, zeros
  20. from kornia.geometry.linalg import transform_points
  21. from kornia.utils.grid import create_meshgrid
  class StereoException(Exception):
      def __init__(self, msg: str, *args: Any, **kwargs: Any) -> None:
          r"""Create an exception for the :module:`~kornia.geometry.camera.stereo` module.

          The given message is extended with a link to the stereo documentation page so users
          know where to look for usage details and examples.

          Args:
              msg: Module-specific message; the documentation hint is appended to it.
              *args: Additional positional arguments forwarded to ``Exception``.
              **kwargs: Additional keyword arguments forwarded to ``Exception``.
                  NOTE(review): the built-in ``Exception.__init__`` appears to reject keyword
                  arguments, so passing any here likely raises ``TypeError`` — confirm.
          """
          help_suffix = (
              "\n Please check documents here: "
              "https://kornia.readthedocs.io/en/latest/geometry.camera.stereo.html for further information and examples."
          )
          # Forwarding the extra arguments is intentional; mypy reports
          # 'Too many arguments for "__init__" of "BaseException"', hence the ignore below.
          super().__init__(msg + help_suffix, *args, **kwargs)  # type: ignore
  class StereoCamera:
      def __init__(self, rectified_left_camera: Tensor, rectified_right_camera: Tensor) -> None:
          r"""Class representing a horizontal stereo camera setup.

          Args:
              rectified_left_camera: The rectified left camera projection matrix
                  of shape :math:`(B, 3, 4)`
              rectified_right_camera: The rectified right camera projection matrix
                  of shape :math:`(B, 3, 4)`

          Raises:
              StereoException: If the matrices do not form a valid rectified stereo pair
                  (see ``_check_stereo_camera``).
          """
          self._check_stereo_camera(rectified_left_camera, rectified_right_camera)
          self.rectified_left_camera: Tensor = rectified_left_camera
          self.rectified_right_camera: Tensor = rectified_right_camera

          self.device = self.rectified_left_camera.device
          self.dtype = self.rectified_left_camera.dtype

          # Q is computed once here; the ``Q`` property returns this cached tensor, so later
          # reassignment of the camera attributes does not refresh it.
          self._Q_matrix = self._init_Q_matrix()

      @staticmethod
      def _check_stereo_camera(rectified_left_camera: Tensor, rectified_right_camera: Tensor) -> None:
          r"""Ensure user specified correct camera matrices.

          Args:
              rectified_left_camera: The rectified left camera projection matrix
                  of shape :math:`(B, 3, 4)`
              rectified_right_camera: The rectified right camera projection matrix
                  of shape :math:`(B, 3, 4)`

          Raises:
              StereoException: If either matrix is not of shape :math:`(B, 3, 4)`, if they are on
                  different devices or have different dtypes, if their first three columns differ,
                  or if :math:`T_x * f_x` (right camera entry ``[0, 3]``) is not strictly negative
                  for every batch element.
          """
          # Ensure correct number of dimensions.
          if len(rectified_left_camera.shape) != 3:
              raise StereoException(
                  f"Expected 'rectified_left_camera' to have 3 dimensions. Got {rectified_left_camera.shape}."
              )
          if len(rectified_right_camera.shape) != 3:
              raise StereoException(
                  f"Expected 'rectified_right_camera' to have 3 dimensions. Got {rectified_right_camera.shape}."
              )

          # Each batch element must be a 3x4 projection matrix: compare the trailing two dimensions.
          if rectified_left_camera.shape[-2:] != (3, 4):
              raise StereoException(
                  "Expected each 'rectified_left_camera' to be of shape (3, 4). "
                  f"Got {rectified_left_camera.shape[-2:]}."
              )
          if rectified_right_camera.shape[-2:] != (3, 4):
              raise StereoException(
                  "Expected each 'rectified_right_camera' to be of shape (3, 4). "
                  f"Got {rectified_right_camera.shape[-2:]}."
              )

          # Ensure same devices for cameras.
          if rectified_left_camera.device != rectified_right_camera.device:
              raise StereoException(
                  "Expected 'rectified_left_camera' and 'rectified_right_camera' "
                  "to be on the same devices. "
                  f"Got {rectified_left_camera.device} and {rectified_right_camera.device}."
              )

          # Ensure same dtypes for cameras.
          if rectified_left_camera.dtype != rectified_right_camera.dtype:
              raise StereoException(
                  "Expected 'rectified_left_camera' and 'rectified_right_camera' to "
                  "have same dtype. "
                  f"Got {rectified_left_camera.dtype} and {rectified_right_camera.dtype}."
              )

          # Ensure all intrinsics parameters (fx, fy, cx, cy) are the same in both cameras:
          # the first three columns must match exactly; only the last column may differ.
          if not torch.all(torch.eq(rectified_left_camera[..., :, :3], rectified_right_camera[..., :, :3])):
              raise StereoException(
                  "Expected 'rectified_left_camera' and 'rectified_right_camera' to have "
                  "same parameters except for the last column. "
                  f"Got {rectified_left_camera[..., :, :3]} and {rectified_right_camera[..., :, :3]}."
              )

          # Ensure that tx * fx is negative (and therefore non-zero) for EVERY batch element.
          # The ``tx`` property negates this value and divides by fx, so with a positive fx the
          # resulting baseline is positive.
          tx_fx = rectified_right_camera[..., 0, 3]
          if torch.any(torch.ge(tx_fx, 0)):
              raise StereoException(f"Expected :math:`T_x * f_x` to be negative. Got {tx_fx}.")

      @property
      def batch_size(self) -> int:
          r"""Return the batch size of the storage.

          Returns:
              scalar with the batch size
          """
          return self.rectified_left_camera.shape[0]

      @property
      def fx(self) -> Tensor:
          r"""Return the focal length in the x-direction.

          Note that the focal lengths of the rectified left and right
          camera are assumed to be equal (enforced by ``_check_stereo_camera``).

          Returns:
              tensor of shape :math:`(B)`
          """
          return self.rectified_left_camera[..., 0, 0]

      @property
      def fy(self) -> Tensor:
          r"""Return the focal length in the y-direction.

          Note that the focal lengths of the rectified left and right
          camera are assumed to be equal (enforced by ``_check_stereo_camera``).

          Returns:
              tensor of shape :math:`(B)`
          """
          return self.rectified_left_camera[..., 1, 1]

      @property
      def cx_left(self) -> Tensor:
          r"""Return the x-coordinate of the principal point for the left camera.

          Returns:
              tensor of shape :math:`(B)`
          """
          return self.rectified_left_camera[..., 0, 2]

      @property
      def cx_right(self) -> Tensor:
          r"""Return the x-coordinate of the principal point for the right camera.

          Returns:
              tensor of shape :math:`(B)`
          """
          return self.rectified_right_camera[..., 0, 2]

      @property
      def cy(self) -> Tensor:
          r"""Return the y-coordinate of the principal point.

          Note that the y-coordinate of the principal points
          is assumed to be equal for the left and right camera.

          Returns:
              tensor of shape :math:`(B)`
          """
          return self.rectified_left_camera[..., 1, 2]

      @property
      def tx(self) -> Tensor:
          r"""The horizontal baseline between the two cameras.

          Computed as :math:`-(T_x f_x) / f_x` from the right camera's entry ``[0, 3]``.

          Returns:
              Tensor of shape :math:`(B)`
          """
          return -self.rectified_right_camera[..., 0, 3] / self.fx

      @property
      def Q(self) -> Tensor:
          r"""The Q matrix of the horizontal stereo setup.

          This matrix is used for reprojecting a disparity tensor to
          the corresponding point cloud. Note that this is in a general form that allows different focal
          lengths in the x and y direction.

          Return:
              The Q matrix of shape :math:`(B, 4, 4)`.
          """
          return self._Q_matrix

      def _init_Q_matrix(self) -> Tensor:
          r"""Initialize the Q matrix of the horizontal stereo setup. See the Q property.

          The matrix maps homogeneous pixel/disparity vectors :math:`(u, v, d, 1)` to homogeneous
          3D points. When :math:`c_{x,left} = c_{x,right}`, dividing by the last coordinate yields
          :math:`Z = f_x \cdot t_x / d`.

          Returns:
              The Q matrix of shape :math:`(B, 4, 4)`.
          """
          Q = zeros((self.batch_size, 4, 4), device=self.device, dtype=self.dtype)
          baseline: Tensor = -self.tx
          Q[:, 0, 0] = self.fy * baseline
          Q[:, 0, 3] = -self.fy * self.cx_left * baseline
          Q[:, 1, 1] = self.fx * baseline
          Q[:, 1, 3] = -self.fx * self.cy * baseline
          Q[:, 2, 3] = self.fx * self.fy * baseline
          Q[:, 3, 2] = -self.fy
          # NOTE: ``_check_stereo_camera`` requires identical first three columns, hence
          # cx_left == cx_right and this entry is zero for cameras that passed validation.
          Q[:, 3, 3] = self.fy * (self.cx_left - self.cx_right)
          return Q

      def reproject_disparity_to_3D(self, disparity_tensor: Tensor) -> Tensor:
          r"""Reproject the disparity tensor to a 3D point cloud.

          Args:
              disparity_tensor: Disparity tensor of shape :math:`(B, H, W, 1)`.

          Returns:
              The 3D point cloud of shape :math:`(B, H, W, 3)`
          """
          return reproject_disparity_to_3D(disparity_tensor, self.Q)
  193. def _check_disparity_tensor(disparity_tensor: Tensor) -> None:
  194. r"""Ensure correct user provided correct disparity tensor.
  195. Args:
  196. disparity_tensor: The disparity tensor of shape :math:`(B, 1, H, W)`.
  197. """
  198. if not isinstance(disparity_tensor, Tensor):
  199. raise StereoException(
  200. f"Expected 'disparity_tensor' to be an instance of Tensor but got {type(disparity_tensor)}."
  201. )
  202. if len(disparity_tensor.shape) != 4:
  203. raise StereoException(f"Expected 'disparity_tensor' to have 4 dimensions. Got {disparity_tensor.shape}.")
  204. if disparity_tensor.shape[-1] != 1:
  205. raise StereoException(
  206. "Expected dimension 1 of 'disparity_tensor' to be 1 for as single channeled disparity map."
  207. f"Got {disparity_tensor.shape}."
  208. )
  209. if disparity_tensor.dtype not in (torch.float16, torch.float32, torch.float64):
  210. raise StereoException(
  211. "Expected 'disparity_tensor' to have dtype torch.float16, torch.float32 or torch.float64."
  212. f"Got {disparity_tensor.dtype}"
  213. )
  214. def _check_Q_matrix(Q_matrix: Tensor) -> None:
  215. r"""Ensure Q matrix is of correct form.
  216. Args:
  217. Q_matrix: The Q matrix for reprojecting disparity to a point cloud of shape :math:`(B, 4, 4)`
  218. """
  219. if not isinstance(Q_matrix, Tensor):
  220. raise StereoException(f"Expected 'Q_matrix' to be an instance of Tensor but got {type(Q_matrix)}.")
  221. if not len(Q_matrix.shape) == 3:
  222. raise StereoException(f"Expected 'Q_matrix' to have 3 dimensions. Got {Q_matrix.shape}")
  223. if not Q_matrix.shape[1:] == (4, 4):
  224. raise StereoException(f"Expected last two dimensions of 'Q_matrix' to be of shape (4, 4). Got {Q_matrix.shape}")
  225. if Q_matrix.dtype not in (torch.float16, torch.float32, torch.float64):
  226. raise StereoException(
  227. f"Expected 'Q_matrix' to be of type torch.float16, torch.float32 or torch.float64. Got {Q_matrix.dtype}"
  228. )
  def reproject_disparity_to_3D(disparity_tensor: Tensor, Q_matrix: Tensor) -> Tensor:
      r"""Reproject the disparity tensor to a 3D point cloud.

      Every pixel at column ``u`` and row ``v`` with disparity ``d`` is lifted to the vector
      ``(u, v, d)`` and transformed by ``Q_matrix`` with ``transform_points``.

      Args:
          disparity_tensor: Disparity tensor of shape :math:`(B, H, W, 1)`.
          Q_matrix: Tensor of Q matrices of shapes :math:`(B, 4, 4)`.

      Returns:
          The 3D point cloud of shape :math:`(B, H, W, 3)`

      Raises:
          StereoException: If ``Q_matrix`` or ``disparity_tensor`` fails validation.
      """
      _check_Q_matrix(Q_matrix)
      _check_disparity_tensor(disparity_tensor)

      batch_size, rows, cols, _ = disparity_tensor.shape
      dtype = disparity_tensor.dtype
      device = disparity_tensor.device

      # kornia's create_meshgrid returns a (1, H, W, 2) grid whose last dimension is ordered
      # (x, y) = (column, row). Unpack in that order so ``u`` is horizontal and ``v`` vertical,
      # matching the Q layout (its first row is paired with cx, its second row with cy).
      uv = create_meshgrid(rows, cols, normalized_coordinates=False, device=device, dtype=dtype)
      uv = uv.expand(batch_size, -1, -1, -1)
      u, v = torch.unbind(uv, dim=-1)
      u, v = torch.unsqueeze(u, -1), torch.unsqueeze(v, -1)

      # (B, 3, H, W, 1) -> (B, 3, H*W) -> (B, H*W, 3): one (u, v, d) row per pixel.
      uvd = stack((u, v, disparity_tensor), 1).reshape(batch_size, 3, -1).permute(0, 2, 1)

      # reshape either yields exactly (B, H, W, 3) or raises, so no extra shape check is needed.
      points = transform_points(Q_matrix, uvd)
      return points.reshape(batch_size, rows, cols, 3)