samplers.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. import math
  18. from typing import Dict, List, Tuple
  19. import torch
  20. from kornia.core import Device, Tensor, tensor
  21. from kornia.geometry.camera import PinholeCamera
  22. from kornia.nerf.camera_utils import cameras_for_ids
  23. from kornia.utils._compat import torch_meshgrid
  24. from kornia.utils.helpers import _torch_inverse_cast
  25. class RaySampler:
  26. r"""Class to manage spatial ray sampling.
  27. Args:
  28. min_depth: sampled rays minimal depth from cameras: float
  29. max_depth: sampled rays maximal depth from cameras: float
  30. ndc: convert ray parameters to normalized device coordinates: bool
  31. device: device for ray tensors: Union[str, torch.device]
  32. """
  33. _origins: Tensor # Ray origins in world coordinates (*, 3)
  34. _directions: Tensor # Ray directions in world coordinates (*, 3)
  35. _directions_cam: Tensor # Ray directions in camera coordinates (*, 3)
  36. _origins_cam: Tensor # Ray origins in camera coordinates (*, 3)
  37. _camera_ids: Tensor # Ray camera ID
  38. _points_2d: Tensor # Ray intersection with image plane in camera coordinates
  39. def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
  40. self._min_depth = min_depth
  41. self._max_depth = max_depth
  42. self._ndc = ndc
  43. self._device = device
  44. self._dtype = dtype
  45. @property
  46. def origins(self) -> Tensor:
  47. return self._origins
  48. @property
  49. def directions(self) -> Tensor:
  50. return self._directions
  51. @property
  52. def camera_ids(self) -> Tensor:
  53. return self._camera_ids
  54. @property
  55. def points_2d(self) -> Tensor:
  56. return self._points_2d
  57. def __len__(self) -> int:
  58. if self.origins is None:
  59. return 0
  60. return self.origins.shape[0]
  61. def _calc_ray_directions_cam(self, cameras: PinholeCamera, points_2d: Tensor) -> Tensor:
  62. # FIXME: This function should call perspective.unproject_points or, implement in PinholeCamera unproject to
  63. # camera coordinates that will call perspective.unproject_points
  64. fx = cameras.fx
  65. fy = cameras.fy
  66. cx = cameras.cx
  67. cy = cameras.cy
  68. directions_x = (points_2d[..., 0] - cx[..., None]) / fx[..., None]
  69. directions_y = (points_2d[..., 1] - cy[..., None]) / fy[..., None]
  70. directions_z = torch.ones_like(directions_x)
  71. directions_cam = torch.stack([directions_x, directions_y, directions_z], dim=-1)
  72. return directions_cam.reshape(-1, 3)
  73. class Points2D:
  74. r"""A class to hold ray 2d pixel coordinates and a camera id for each.
  75. Args:
  76. points_2d: tensor with ray pixel coordinates (the coordinates in the image plane that correspond to the
  77. ray):math:`(B, 2)`
  78. camera_ids: list of camera ids for each pixel coordinates: List[int]
  79. """
  80. def __init__(self, points_2d: Tensor, camera_ids: List[int]) -> None:
  81. self._points_2d = points_2d # (*, N, 2)
  82. self._camera_ids = camera_ids
  83. @property
  84. def points_2d(self) -> Tensor:
  85. return self._points_2d
  86. @property
  87. def camera_ids(self) -> List[int]:
  88. return self._camera_ids
  89. def _calc_ray_params(self, cameras: PinholeCamera, points_2d_camera: Dict[int, Points2D]) -> None:
  90. r"""Calculate ray parameters: origins, directions.
  91. Also stored are camera ids for each ray, and its pixel coordinates.
  92. Args:
  93. cameras: scene cameras: PinholeCamera
  94. points_2d_camera: a dictionary that groups Point2D objects by total number of casted rays
  95. """
  96. # Unproject 2d points in image plane to 3d world for two depths
  97. origins = []
  98. directions = []
  99. directions_cam = []
  100. origins_cam = []
  101. camera_ids = []
  102. points_2d = []
  103. for obj in points_2d_camera.values():
  104. # FIXME: Below both world and camera ray directions are calculated. It could be that world ray directions
  105. # will not be necessary and can be removed here
  106. num_cams_group, num_points_per_cam_group = obj._points_2d.shape[:2]
  107. depths = (
  108. torch.ones(num_cams_group, 2 * num_points_per_cam_group, 3, device=self._device, dtype=self._dtype)
  109. * self._min_depth
  110. )
  111. depths[:, num_points_per_cam_group:] = self._max_depth
  112. cams = cameras_for_ids(cameras, obj.camera_ids)
  113. points_3d = cams.unproject(obj._points_2d.repeat(1, 2, 1), depths)
  114. origins.append(points_3d[..., :num_points_per_cam_group, :].reshape(-1, 3))
  115. directions.append(
  116. (points_3d[..., num_points_per_cam_group:, :] - points_3d[..., :num_points_per_cam_group, :]).reshape(
  117. -1, 3
  118. )
  119. )
  120. directions_cam.append(self._calc_ray_directions_cam(cams, obj._points_2d))
  121. origins_cam.append(directions_cam[-1] * self._min_depth)
  122. camera_ids.append(
  123. tensor(obj.camera_ids).repeat(num_points_per_cam_group, 1).permute(1, 0).reshape(1, -1).squeeze(0)
  124. )
  125. points_2d.append(obj._points_2d.reshape(-1, 2).int())
  126. self._origins = torch.cat(origins)
  127. self._directions = torch.cat(directions)
  128. self._directions_cam = torch.cat(directions_cam)
  129. self._origins_cam = torch.cat(origins_cam)
  130. self._camera_ids = torch.cat(camera_ids)
  131. if self._ndc: # Transform ray parameters to NDC, if defined
  132. self._origins, self._directions = self.transform_ray_params_world_to_ndc(cameras)
  133. self._points_2d = torch.cat(points_2d)
  134. def transform_ray_params_world_to_ndc(self, cameras: PinholeCamera) -> Tuple[Tensor, Tensor]:
  135. r"""Transform ray parameters to normalized coordinate device (camera) system (NDC).
  136. Args:
  137. cameras: scene cameras: PinholeCamera
  138. """
  139. cams = cameras_for_ids(cameras, self._camera_ids)
  140. fx = cams.fx
  141. fy = cams.fy
  142. widths = cams.width
  143. heights = cams.height
  144. fx_widths = 2.0 * fx / (widths - 1.0)
  145. fy_heights = 2.0 * fy / (heights - 1.0)
  146. # oxoz = self._origins_cam[..., 0] / self._origins_cam[..., 2]
  147. # oyoz = self._origins_cam[..., 1] / self._origins_cam[..., 2]
  148. oxoz = self._origins[..., 0] / self._origins[..., 2]
  149. oyoz = self._origins[..., 1] / self._origins[..., 2]
  150. origins_ndc_x = fx_widths * oxoz
  151. origins_ndc_y = fy_heights * oyoz
  152. # origins_ndc_z = 1 - 2 * self._min_depth / self._origins_cam[..., 2]
  153. origins_ndc_z = 1 - 2 * self._min_depth / self._origins[..., 2]
  154. origins_ndc = torch.stack([origins_ndc_x, origins_ndc_y, origins_ndc_z], dim=-1)
  155. # dxdz = self._directions_cam[..., 0] / self._directions_cam[..., 2]
  156. # dydz = self._directions_cam[..., 1] / self._directions_cam[..., 2]
  157. Rt_inv = _torch_inverse_cast(cams.rotation_matrix)
  158. directions_rotated_world = (Rt_inv @ self._directions_cam[..., None]).squeeze(dim=-1)
  159. dxdz = directions_rotated_world[..., 0] / directions_rotated_world[..., 2]
  160. dydz = directions_rotated_world[..., 1] / directions_rotated_world[..., 2]
  161. directions_ndc_x = fx_widths * dxdz - origins_ndc_x
  162. directions_ndc_y = fy_heights * dydz - origins_ndc_y
  163. directions_ndc_z = 1 - origins_ndc_z
  164. directions_ndc = torch.stack([directions_ndc_x, directions_ndc_y, directions_ndc_z], dim=-1)
  165. # Rt_inv = _torch_inverse_cast(cams.rotation_matrix)
  166. # origins_ndc_world = (Rt_inv @ origins_ndc[..., None]).squeeze(dim=-1)
  167. # directions_ndc_world = (Rt_inv @ directions_ndc[..., None]).squeeze(dim=-1)
  168. origins_ndc_world = origins_ndc
  169. directions_ndc_world = directions_ndc
  170. return origins_ndc_world, directions_ndc_world
  171. class Points2D_FlatTensors:
  172. r"""Class to hold x/y pixel coordinates for each ray, and its scene camera id."""
  173. def __init__(self) -> None:
  174. self._x: Tensor
  175. self._y: Tensor
  176. self._camera_ids: List[int] = []
  177. @staticmethod
  178. def _add_points2d_as_flat_tensors_to_num_ray_dict(
  179. n: int, x: Tensor, y: Tensor, camera_id: int, points2d_as_flat_tensors: Dict[int, Points2D_FlatTensors]
  180. ) -> None:
  181. r"""Add x/y pixel coordinates for all rays casted by a scene camera to dictionary of pixel coordinates
  182. grouped by total number of rays.
  183. """ # noqa: D205
  184. if n not in points2d_as_flat_tensors:
  185. points2d_as_flat_tensors[n] = RaySampler.Points2D_FlatTensors()
  186. points2d_as_flat_tensors[n]._x = x.flatten()
  187. points2d_as_flat_tensors[n]._y = y.flatten()
  188. else:
  189. points2d_as_flat_tensors[n]._x = torch.cat((points2d_as_flat_tensors[n]._x, x.flatten()))
  190. points2d_as_flat_tensors[n]._y = torch.cat((points2d_as_flat_tensors[n]._y, y.flatten()))
  191. points2d_as_flat_tensors[n]._camera_ids.append(camera_id)
  192. @staticmethod
  193. def _build_num_ray_dict_of_points2d(
  194. points2d_as_flat_tensors: Dict[int, Points2D_FlatTensors],
  195. ) -> Dict[int, Points2D]:
  196. r"""Build a dictionary of ray pixel points, by total number of rays as key.
  197. The dictionary groups rays by the total amount of rays, which allows the case of casting different number
  198. of rays from each scene camera.
  199. Args:
  200. points2d_as_flat_tensors: dictionary of pixel coordinates grouped by total number of rays:
  201. Dict[int, Points2D_FlatTensors]
  202. Returns:
  203. dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
  204. id it was casted by: Dict[int, Points2D]
  205. """
  206. num_ray_dict_of_points2d: Dict[int, RaySampler.Points2D] = {}
  207. for n, points2d_as_flat_tensor in points2d_as_flat_tensors.items():
  208. num_cams = len(points2d_as_flat_tensor._camera_ids)
  209. points_2d = (
  210. torch.stack((points2d_as_flat_tensor._x, points2d_as_flat_tensor._y))
  211. .permute(1, 0)
  212. .reshape(num_cams, -1, 2)
  213. )
  214. num_ray_dict_of_points2d[n] = RaySampler.Points2D(points_2d, points2d_as_flat_tensor._camera_ids)
  215. return num_ray_dict_of_points2d
  216. class RandomRaySampler(RaySampler):
  217. r"""Class to manage random ray spatial sampling.
  218. Args:
  219. min_depth: sampled rays minimal depth from cameras: float
  220. max_depth: sampled rays maximal depth from cameras: float
  221. ndc: convert to normalized device coordinates: bool
  222. device: device for ray tensors: Union[str, torch.device]
  223. """
  224. def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
  225. super().__init__(min_depth, max_depth, ndc, device, dtype)
  226. def sample_points_2d(self, heights: Tensor, widths: Tensor, num_img_rays: Tensor) -> Dict[int, RaySampler.Points2D]:
  227. r"""Randomly sample pixel points in 2d.
  228. Args:
  229. heights: tensor that holds scene camera image heights (can vary between cameras): math: `(B)`.
  230. widths: tensor that holds scene camera image widths (can vary between cameras): math: `(B)`.
  231. num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera: math: `(B)`.
  232. Returns:
  233. dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
  234. id it was casted by: Dict[int, Points2D]
  235. """
  236. num_img_rays = num_img_rays.int()
  237. points2d_as_flat_tensors: Dict[int, RaySampler.Points2D_FlatTensors] = {}
  238. for camera_id, (height, width, n) in enumerate(zip(heights.tolist(), widths.tolist(), num_img_rays.tolist())):
  239. y_rand = torch.trunc(torch.rand(n, device=self._device, dtype=self._dtype) * height)
  240. x_rand = torch.trunc(torch.rand(n, device=self._device, dtype=self._dtype) * width)
  241. RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(
  242. n, x_rand, y_rand, camera_id, points2d_as_flat_tensors
  243. )
  244. return RaySampler._build_num_ray_dict_of_points2d(points2d_as_flat_tensors)
  245. def calc_ray_params(self, cameras: PinholeCamera, num_img_rays: Tensor) -> None:
  246. r"""Calculate ray parameters: origins, directions.
  247. Also stored are camera ids for each ray, and its pixel coordinates.
  248. Args:
  249. cameras: scene cameras: PinholeCamera
  250. num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera: int math: `(B)`.
  251. """
  252. num_cams = cameras.batch_size
  253. if num_cams != num_img_rays.shape[0]:
  254. raise ValueError(
  255. f"Number of cameras {num_cams} does not match size of tensor to define number of rays to march from "
  256. f"each camera {num_img_rays.shape[0]}"
  257. )
  258. points_2d_camera = self.sample_points_2d(cameras.height, cameras.width, num_img_rays)
  259. self._calc_ray_params(cameras, points_2d_camera)
  260. class RandomGridRaySampler(RandomRaySampler):
  261. r"""Class to manage random ray spatial sampling.
  262. Sampling is done on a regular grid of pixels by randomizing
  263. column and row values, and casting rays for all pixels along the selected ones.
  264. Args:
  265. min_depth: sampled rays minimal depth from cameras: float
  266. max_depth: sampled rays maximal depth from cameras: float
  267. ndc: convert to normalized device coordinates: bool
  268. device: device for ray tensors: Union[str, torch.device]
  269. """
  270. def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
  271. super().__init__(min_depth, max_depth, ndc, device, dtype)
  272. def sample_points_2d(self, heights: Tensor, widths: Tensor, num_img_rays: Tensor) -> Dict[int, RaySampler.Points2D]:
  273. r"""Randomly sample pixel points in 2d over a regular row-column grid.
  274. Args:
  275. heights: tensor that holds scene camera image heights (can vary between cameras): math: `(B)`.
  276. widths: tensor that holds scene camera image widths (can vary between cameras): math: `(B)`.
  277. num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera. Number of rows
  278. and columns is the square root of this value: int math: `(B)`.
  279. Returns:
  280. dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
  281. id it was casted by: Dict[int, Points2D]
  282. """
  283. num_img_rays = num_img_rays.int()
  284. points2d_as_flat_tensors: Dict[int, RaySampler.Points2D_FlatTensors] = {}
  285. for camera_id, (height, width, n) in enumerate(zip(heights.tolist(), widths.tolist(), num_img_rays.tolist())):
  286. n_sqrt = int(math.sqrt(n))
  287. y_rand = torch.randperm(int(height), device=self._device, dtype=self._dtype)[: min(int(height), n_sqrt)]
  288. x_rand = torch.randperm(int(width), device=self._device, dtype=self._dtype)[: min(int(width), n_sqrt)]
  289. y_grid, x_grid = torch_meshgrid([y_rand, x_rand], indexing="ij")
  290. RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(
  291. n_sqrt * n_sqrt, x_grid, y_grid, camera_id, points2d_as_flat_tensors
  292. )
  293. return RaySampler._build_num_ray_dict_of_points2d(points2d_as_flat_tensors)
  294. class UniformRaySampler(RaySampler):
  295. r"""Class to manage uniform ray spatial sampling for all camera scene pixels.
  296. Args:
  297. min_depth: sampled rays minimal depth from cameras: float
  298. max_depth: sampled rays maximal depth from cameras: float
  299. ndc: convert to normalized device coordinates: bool
  300. device: device for ray tensors: Union[str, torch.device]
  301. """
  302. def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
  303. super().__init__(min_depth, max_depth, ndc, device, dtype)
  304. def sample_points_2d(
  305. self, heights: Tensor, widths: Tensor, sampling_step: int = 1
  306. ) -> Dict[int, RaySampler.Points2D]:
  307. r"""Uniformly sample pixel points in 2d for all scene camera pixels.
  308. Args:
  309. heights: tensor that holds scene camera image heights (can vary between cameras): math: `(B)`.
  310. widths: tensor that holds scene camera image widths (can vary between cameras): math: `(B)`.
  311. sampling_step: defines uniform strides between rows and columns: int.
  312. Returns:
  313. dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
  314. id it was casted by: Dict[int, Points2D]
  315. """
  316. heights = heights.int()
  317. widths = widths.int()
  318. points2d_as_flat_tensors: Dict[int, RaySampler.Points2D_FlatTensors] = {}
  319. for camera_id, (height, width) in enumerate(zip(heights.tolist(), widths.tolist())):
  320. n = height * width
  321. y_grid, x_grid = torch_meshgrid(
  322. [
  323. torch.arange(0, height, sampling_step, device=self._device, dtype=self._dtype),
  324. torch.arange(0, width, sampling_step, device=self._device, dtype=self._dtype),
  325. ],
  326. indexing="ij",
  327. )
  328. RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(
  329. n, x_grid, y_grid, camera_id, points2d_as_flat_tensors
  330. )
  331. return RaySampler._build_num_ray_dict_of_points2d(points2d_as_flat_tensors)
  332. def calc_ray_params(self, cameras: PinholeCamera) -> None:
  333. points_2d_camera = self.sample_points_2d(cameras.height, cameras.width)
  334. self._calc_ray_params(cameras, points_2d_camera)
  335. def sample_lengths(
  336. num_rays: int, num_ray_points: int, device: Device, dtype: torch.dtype, irregular: bool = False
  337. ) -> Tensor:
  338. """Sample points along the length of rays."""
  339. if num_ray_points <= 1:
  340. raise ValueError("Number of ray points must be greater than 1")
  341. if not irregular:
  342. zero_to_one = torch.linspace(0.0, 1.0, num_ray_points, device=device, dtype=dtype)
  343. lengths = zero_to_one.repeat(num_rays, 1) # FIXME: Expand instead of repeat maybe?
  344. else:
  345. zero_to_one = torch.linspace(0.0, 1.0, num_ray_points + 1, device=device, dtype=dtype)
  346. lengths = torch.rand(num_rays, num_ray_points, device=device) / num_ray_points + zero_to_one[:-1]
  347. return lengths
# TODO: Implement hierarchical ray sampling as described in Mildenhall et al. (2020), Sec. 5.2
  349. def sample_ray_points(
  350. origins: Tensor, directions: Tensor, lengths: Tensor
  351. ) -> Tensor: # FIXME: Test by projecting to points_2d and compare with sampler 2d points
  352. r"""Sample points along ray.
  353. Args:
  354. origins: tensor containing ray origins in 3d world coordinates. Tensor shape :math:`(*, 3)`.
  355. directions: tensor containing ray directions in 3d world coordinates. Tensor shape :math:`(*, 3)`.
  356. lengths: tensor containing sampled distances along each ray. Tensor shape :math:`(*, num_ray_points)`.
  357. Returns:
  358. points_3d: Points along rays :math:`(*, num_ray_points, 3)`
  359. """
  360. points_3d = origins[..., None, :] + lengths[..., None] * directions[..., None, :]
  361. return points_3d
  362. def calc_ray_t_vals(points_3d: Tensor) -> Tensor:
  363. r"""Calculate t values along rays.
  364. Args:
  365. points_3d: Points along rays :math:`(*, num_ray_points, 3)`
  366. Returns:
  367. t values along rays :math:`(*, num_ray_points)`
  368. """
  369. t_vals = torch.linalg.norm(points_3d - points_3d[..., 0, :].unsqueeze(-2), dim=-1)
  370. return t_vals