| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import math
- from typing import Dict, List, Tuple
- import torch
- from kornia.core import Device, Tensor, tensor
- from kornia.geometry.camera import PinholeCamera
- from kornia.nerf.camera_utils import cameras_for_ids
- from kornia.utils._compat import torch_meshgrid
- from kornia.utils.helpers import _torch_inverse_cast
class RaySampler:
    r"""Class to manage spatial ray sampling.

    Args:
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert ray parameters to normalized device coordinates: bool
        device: device for ray tensors: Union[str, torch.device]
        dtype: dtype for ray tensors: torch.dtype
    """

    # The attributes below are only declared here; they are assigned by ``_calc_ray_params``.
    _origins: Tensor  # Ray origins in world coordinates (*, 3)
    _directions: Tensor  # Ray directions in world coordinates (*, 3)
    _directions_cam: Tensor  # Ray directions in camera coordinates (*, 3)
    _origins_cam: Tensor  # Ray origins in camera coordinates (*, 3)
    _camera_ids: Tensor  # Ray camera ID
    _points_2d: Tensor  # Ray intersection with image plane in camera coordinates

    def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
        self._min_depth = min_depth
        self._max_depth = max_depth
        self._ndc = ndc
        self._device = device
        self._dtype = dtype

    @property
    def origins(self) -> Tensor:
        """Ray origins, as stored by the last ray parameter calculation."""
        return self._origins

    @property
    def directions(self) -> Tensor:
        """Ray directions, as stored by the last ray parameter calculation."""
        return self._directions

    @property
    def camera_ids(self) -> Tensor:
        """Id of the camera that casts each ray."""
        return self._camera_ids

    @property
    def points_2d(self) -> Tensor:
        """Integer pixel coordinates of each ray."""
        return self._points_2d

    def __len__(self) -> int:
        """Return the number of stored rays, or 0 when no rays have been calculated yet."""
        # ``_origins`` does not exist on the instance until rays are calculated, so it must be read defensively;
        # going through the ``origins`` property here would raise AttributeError instead of reporting 0.
        origins = getattr(self, "_origins", None)
        if origins is None:
            return 0
        return origins.shape[0]

    def _calc_ray_directions_cam(self, cameras: PinholeCamera, points_2d: Tensor) -> Tensor:
        r"""Calculate ray directions in camera coordinates from pixel coordinates.

        Args:
            cameras: cameras providing ``fx``, ``fy``, ``cx`` and ``cy``, one entry per camera of ``points_2d``.
            points_2d: pixel coordinates; x is read from index 0 and y from index 1 of the last dimension.

        Returns:
            Directions with unit z component, flattened to :math:`(-1, 3)`.
        """
        # FIXME: This function should call perspective.unproject_points or, implement in PinholeCamera unproject to
        # camera coordinates that will call perspective.unproject_points
        fx = cameras.fx
        fy = cameras.fy
        cx = cameras.cx
        cy = cameras.cy
        # Trailing ``None`` lets the per-camera intrinsics broadcast over the points of that camera
        directions_x = (points_2d[..., 0] - cx[..., None]) / fx[..., None]
        directions_y = (points_2d[..., 1] - cy[..., None]) / fy[..., None]
        directions_z = torch.ones_like(directions_x)
        directions_cam = torch.stack([directions_x, directions_y, directions_z], dim=-1)
        return directions_cam.reshape(-1, 3)

    class Points2D:
        r"""A class to hold ray 2d pixel coordinates and a camera id for each.

        Args:
            points_2d: tensor with ray pixel coordinates (the coordinates in the image plane that correspond to the
              ray), grouped per camera.
            camera_ids: list of camera ids, one for each camera group in ``points_2d``: List[int]
        """

        def __init__(self, points_2d: Tensor, camera_ids: List[int]) -> None:
            self._points_2d = points_2d  # (*, N, 2)
            self._camera_ids = camera_ids

        @property
        def points_2d(self) -> Tensor:
            return self._points_2d

        @property
        def camera_ids(self) -> List[int]:
            return self._camera_ids

    def _calc_ray_params(self, cameras: PinholeCamera, points_2d_camera: Dict[int, Points2D]) -> None:
        r"""Calculate ray parameters: origins, directions.

        Also stored are camera ids for each ray, and its pixel coordinates.

        Args:
            cameras: scene cameras: PinholeCamera
            points_2d_camera: a dictionary that groups Point2D objects by total number of casted rays
        """
        # Unproject 2d points in image plane to 3d world for two depths
        origins = []
        directions = []
        directions_cam = []
        origins_cam = []
        camera_ids = []
        points_2d = []
        for obj in points_2d_camera.values():
            # FIXME: Below both world and camera ray directions are calculated. It could be that world ray directions
            # will not be necessary and can be removed here
            num_cams_group, num_points_per_cam_group = obj._points_2d.shape[:2]

            # Every pixel is unprojected twice in one call: the first half of the depth tensor holds the minimal
            # depth and the second half the maximal depth.
            depths = (
                torch.ones(num_cams_group, 2 * num_points_per_cam_group, 3, device=self._device, dtype=self._dtype)
                * self._min_depth
            )
            depths[:, num_points_per_cam_group:] = self._max_depth
            cams = cameras_for_ids(cameras, obj.camera_ids)
            points_3d = cams.unproject(obj._points_2d.repeat(1, 2, 1), depths)

            points_near = points_3d[..., :num_points_per_cam_group, :]
            points_far = points_3d[..., num_points_per_cam_group:, :]
            origins.append(points_near.reshape(-1, 3))
            # Directions are not normalized: they span from the minimal depth point to the maximal depth point
            directions.append((points_far - points_near).reshape(-1, 3))

            directions_cam.append(self._calc_ray_directions_cam(cams, obj._points_2d))
            origins_cam.append(directions_cam[-1] * self._min_depth)

            # Repeat every camera id once per point of that camera, flattened in camera-major order
            camera_ids.append(
                tensor(obj.camera_ids).repeat(num_points_per_cam_group, 1).permute(1, 0).reshape(1, -1).squeeze(0)
            )
            points_2d.append(obj._points_2d.reshape(-1, 2).int())

        self._origins = torch.cat(origins)
        self._directions = torch.cat(directions)
        self._directions_cam = torch.cat(directions_cam)
        self._origins_cam = torch.cat(origins_cam)
        self._camera_ids = torch.cat(camera_ids)
        if self._ndc:  # Transform ray parameters to NDC, if defined
            self._origins, self._directions = self.transform_ray_params_world_to_ndc(cameras)
        self._points_2d = torch.cat(points_2d)

    def transform_ray_params_world_to_ndc(self, cameras: PinholeCamera) -> Tuple[Tensor, Tensor]:
        r"""Transform ray parameters to normalized coordinate device (camera) system (NDC).

        Reads the stored ``_origins``, ``_directions_cam`` and ``_camera_ids``; nothing is written back to the
        sampler by this method.

        Args:
            cameras: scene cameras: PinholeCamera

        Returns:
            Tuple of ray origins and ray directions in NDC, each with a last dimension of size 3.
        """
        cams = cameras_for_ids(cameras, self._camera_ids)
        fx = cams.fx
        fy = cams.fy
        widths = cams.width
        heights = cams.height

        fx_widths = 2.0 * fx / (widths - 1.0)
        fy_heights = 2.0 * fy / (heights - 1.0)

        # NOTE: the stored origins (``_origins``) are used here; ``_origins_cam`` is calculated by
        # ``_calc_ray_params`` but is not read by this transform.
        oxoz = self._origins[..., 0] / self._origins[..., 2]
        oyoz = self._origins[..., 1] / self._origins[..., 2]

        origins_ndc_x = fx_widths * oxoz
        origins_ndc_y = fy_heights * oyoz
        origins_ndc_z = 1 - 2 * self._min_depth / self._origins[..., 2]
        origins_ndc = torch.stack([origins_ndc_x, origins_ndc_y, origins_ndc_z], dim=-1)

        # Camera-frame directions are multiplied by the inverse of the camera rotation matrix before the
        # perspective division.
        Rt_inv = _torch_inverse_cast(cams.rotation_matrix)
        directions_rotated_world = (Rt_inv @ self._directions_cam[..., None]).squeeze(dim=-1)
        dxdz = directions_rotated_world[..., 0] / directions_rotated_world[..., 2]
        dydz = directions_rotated_world[..., 1] / directions_rotated_world[..., 2]

        directions_ndc_x = fx_widths * dxdz - origins_ndc_x
        directions_ndc_y = fy_heights * dydz - origins_ndc_y
        directions_ndc_z = 1 - origins_ndc_z
        directions_ndc = torch.stack([directions_ndc_x, directions_ndc_y, directions_ndc_z], dim=-1)

        # The NDC parameters are returned as calculated, without any further rotation.
        origins_ndc_world = origins_ndc
        directions_ndc_world = directions_ndc
        return origins_ndc_world, directions_ndc_world

    class Points2D_FlatTensors:
        r"""Class to hold x/y pixel coordinates for each ray, and its scene camera id."""

        def __init__(self) -> None:
            # ``_x`` and ``_y`` are only declared here; they are assigned when the first camera is added
            self._x: Tensor
            self._y: Tensor
            self._camera_ids: List[int] = []

    @staticmethod
    def _add_points2d_as_flat_tensors_to_num_ray_dict(
        n: int, x: Tensor, y: Tensor, camera_id: int, points2d_as_flat_tensors: Dict[int, Points2D_FlatTensors]
    ) -> None:
        r"""Add the x/y pixel coordinates of one scene camera to the group of cameras casting ``n`` rays.

        Args:
            n: dictionary key of the group, the total number of rays casted by the camera.
            x: x pixel coordinates of the camera rays; flattened before being stored.
            y: y pixel coordinates of the camera rays; flattened before being stored.
            camera_id: id of the scene camera that casts the rays.
            points2d_as_flat_tensors: dictionary of pixel coordinates grouped by total number of rays; updated in
              place.
        """
        if n not in points2d_as_flat_tensors:
            points2d_as_flat_tensors[n] = RaySampler.Points2D_FlatTensors()
            points2d_as_flat_tensors[n]._x = x.flatten()
            points2d_as_flat_tensors[n]._y = y.flatten()
        else:
            points2d_as_flat_tensors[n]._x = torch.cat((points2d_as_flat_tensors[n]._x, x.flatten()))
            points2d_as_flat_tensors[n]._y = torch.cat((points2d_as_flat_tensors[n]._y, y.flatten()))
        points2d_as_flat_tensors[n]._camera_ids.append(camera_id)

    @staticmethod
    def _build_num_ray_dict_of_points2d(
        points2d_as_flat_tensors: Dict[int, Points2D_FlatTensors],
    ) -> Dict[int, Points2D]:
        r"""Build a dictionary of ray pixel points, by total number of rays as key.

        The dictionary groups rays by the total amount of rays, which allows the case of casting different number
        of rays from each scene camera.

        Args:
            points2d_as_flat_tensors: dictionary of pixel coordinates grouped by total number of rays:
              Dict[int, Points2D_FlatTensors]

        Returns:
            dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
              id it was casted by: Dict[int, Points2D]
        """
        num_ray_dict_of_points2d: Dict[int, RaySampler.Points2D] = {}
        for n, points2d_as_flat_tensor in points2d_as_flat_tensors.items():
            num_cams = len(points2d_as_flat_tensor._camera_ids)
            # (2, total) -> (total, 2) -> (num_cams, points per camera, 2)
            points_2d = (
                torch.stack((points2d_as_flat_tensor._x, points2d_as_flat_tensor._y))
                .permute(1, 0)
                .reshape(num_cams, -1, 2)
            )
            num_ray_dict_of_points2d[n] = RaySampler.Points2D(points_2d, points2d_as_flat_tensor._camera_ids)
        return num_ray_dict_of_points2d
class RandomRaySampler(RaySampler):
    r"""Ray sampler that casts rays through randomly drawn pixels of each scene camera.

    Args:
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert to normalized device coordinates: bool
        device: device for ray tensors: Union[str, torch.device]
        dtype: dtype for ray tensors: torch.dtype
    """

    def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
        super().__init__(min_depth, max_depth, ndc, device, dtype)

    def sample_points_2d(self, heights: Tensor, widths: Tensor, num_img_rays: Tensor) -> Dict[int, RaySampler.Points2D]:
        r"""Draw random pixel coordinates for every scene camera.

        Args:
            heights: tensor that holds scene camera image heights (can vary between cameras) :math:`(B)`.
            widths: tensor that holds scene camera image widths (can vary between cameras) :math:`(B)`.
            num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera :math:`(B)`.

        Returns:
            dictionary of Points2D objects, keyed by number of rays per camera, that holds the pixel 2d coordinates
              of each ray and the id of the camera it was casted by: Dict[int, Points2D]
        """
        ray_counts = num_img_rays.int().tolist()
        per_camera = zip(heights.tolist(), widths.tolist(), ray_counts)
        grouped: Dict[int, RaySampler.Points2D_FlatTensors] = {}
        for cam_idx, (img_height, img_width, count) in enumerate(per_camera):
            # Rows are drawn before columns; coordinates are truncated to whole pixels but keep the sampler dtype
            rows = (torch.rand(count, device=self._device, dtype=self._dtype) * img_height).trunc()
            cols = (torch.rand(count, device=self._device, dtype=self._dtype) * img_width).trunc()
            RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(count, cols, rows, cam_idx, grouped)
        return RaySampler._build_num_ray_dict_of_points2d(grouped)

    def calc_ray_params(self, cameras: PinholeCamera, num_img_rays: Tensor) -> None:
        r"""Sample random pixels for every camera and store the resulting ray parameters.

        Stored are ray origins and directions, the camera id of each ray, and its pixel coordinates.

        Args:
            cameras: scene cameras: PinholeCamera
            num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera :math:`(B)`.

        Raises:
            ValueError: if the number of cameras differs from the number of entries in ``num_img_rays``.
        """
        num_cams = cameras.batch_size
        num_counts = num_img_rays.shape[0]
        if num_cams != num_counts:
            raise ValueError(
                f"Number of cameras {num_cams} does not match size of tensor to define number of rays to march from "
                f"each camera {num_counts}"
            )
        sampled_pixels = self.sample_points_2d(cameras.height, cameras.width, num_img_rays)
        self._calc_ray_params(cameras, sampled_pixels)
class RandomGridRaySampler(RandomRaySampler):
    r"""Class to manage random ray spatial sampling.

    Sampling is done on a regular grid of pixels by randomizing
    column and row values, and casting rays for all pixels along the selected ones.

    Args:
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert to normalized device coordinates: bool
        device: device for ray tensors: Union[str, torch.device]
        dtype: dtype for ray tensors: torch.dtype
    """

    def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
        super().__init__(min_depth, max_depth, ndc, device, dtype)

    def sample_points_2d(self, heights: Tensor, widths: Tensor, num_img_rays: Tensor) -> Dict[int, RaySampler.Points2D]:
        r"""Randomly sample pixel points in 2d over a regular row-column grid.

        Args:
            heights: tensor that holds scene camera image heights (can vary between cameras) :math:`(B)`.
            widths: tensor that holds scene camera image widths (can vary between cameras) :math:`(B)`.
            num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera. Number of rows
              and columns is the integer square root of this value, clamped to the image height and width
              respectively :math:`(B)`.

        Returns:
            dictionary of Points2D objects, keyed by the number of rays actually casted per camera, that holds
              information on pixel 2d coordinates of each ray and the camera id it was casted by:
              Dict[int, Points2D]
        """
        num_img_rays = num_img_rays.int()
        points2d_as_flat_tensors: Dict[int, RaySampler.Points2D_FlatTensors] = {}
        for camera_id, (height, width, n) in enumerate(zip(heights.tolist(), widths.tolist(), num_img_rays.tolist())):
            n_sqrt = int(math.sqrt(n))
            # A random permutation guarantees distinct rows / columns; at most all of them can be selected
            y_rand = torch.randperm(int(height), device=self._device, dtype=self._dtype)[: min(int(height), n_sqrt)]
            x_rand = torch.randperm(int(width), device=self._device, dtype=self._dtype)[: min(int(width), n_sqrt)]
            y_grid, x_grid = torch_meshgrid([y_rand, x_rand], indexing="ij")
            # Group by the number of points really sampled. It equals n_sqrt * n_sqrt unless the grid was clamped
            # to the image size; keying clamped and unclamped cameras together would mix groups of unequal size.
            RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(
                x_grid.numel(), x_grid, y_grid, camera_id, points2d_as_flat_tensors
            )
        return RaySampler._build_num_ray_dict_of_points2d(points2d_as_flat_tensors)
class UniformRaySampler(RaySampler):
    r"""Class to manage uniform ray spatial sampling for all camera scene pixels.

    Args:
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert to normalized device coordinates: bool
        device: device for ray tensors: Union[str, torch.device]
        dtype: dtype for ray tensors: torch.dtype
    """

    def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
        super().__init__(min_depth, max_depth, ndc, device, dtype)

    def sample_points_2d(
        self, heights: Tensor, widths: Tensor, sampling_step: int = 1
    ) -> Dict[int, RaySampler.Points2D]:
        r"""Uniformly sample pixel points in 2d for all scene camera pixels.

        Args:
            heights: tensor that holds scene camera image heights (can vary between cameras) :math:`(B)`.
            widths: tensor that holds scene camera image widths (can vary between cameras) :math:`(B)`.
            sampling_step: defines uniform strides between rows and columns: int.

        Returns:
            dictionary of Points2D objects, keyed by the number of rays actually casted per camera, that holds
              information on pixel 2d coordinates of each ray and the camera id it was casted by:
              Dict[int, Points2D]
        """
        heights = heights.int()
        widths = widths.int()
        points2d_as_flat_tensors: Dict[int, RaySampler.Points2D_FlatTensors] = {}
        for camera_id, (height, width) in enumerate(zip(heights.tolist(), widths.tolist())):
            y_grid, x_grid = torch_meshgrid(
                [
                    torch.arange(0, height, sampling_step, device=self._device, dtype=self._dtype),
                    torch.arange(0, width, sampling_step, device=self._device, dtype=self._dtype),
                ],
                indexing="ij",
            )
            # Group by the number of points really sampled. This is height * width for a step of 1, but with a
            # larger step two images of equal area and different shape can yield different point counts and must
            # not share a group.
            n = x_grid.numel()
            RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(
                n, x_grid, y_grid, camera_id, points2d_as_flat_tensors
            )
        return RaySampler._build_num_ray_dict_of_points2d(points2d_as_flat_tensors)

    def calc_ray_params(self, cameras: PinholeCamera) -> None:
        r"""Calculate ray parameters for every pixel of every scene camera.

        Stored are ray origins and directions, the camera id of each ray, and its pixel coordinates.

        Args:
            cameras: scene cameras: PinholeCamera
        """
        points_2d_camera = self.sample_points_2d(cameras.height, cameras.width)
        self._calc_ray_params(cameras, points_2d_camera)
def sample_lengths(
    num_rays: int, num_ray_points: int, device: Device, dtype: torch.dtype, irregular: bool = False
) -> Tensor:
    r"""Sample points along the length of rays.

    Args:
        num_rays: number of rays to sample lengths for.
        num_ray_points: number of points to sample along each ray; must be greater than 1.
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.
        irregular: if False, every ray gets the same evenly spaced lengths from 0 to 1 (both included). If True,
          the unit interval is split into ``num_ray_points`` equal bins and one length is drawn at random inside
          each bin, independently for every ray.

    Returns:
        Sampled lengths :math:`(num\_rays, num\_ray\_points)`.

    Raises:
        ValueError: if ``num_ray_points`` is not greater than 1.
    """
    if num_ray_points <= 1:
        raise ValueError("Number of ray points must be greater than 1")
    if not irregular:
        zero_to_one = torch.linspace(0.0, 1.0, num_ray_points, device=device, dtype=dtype)
        lengths = zero_to_one.repeat(num_rays, 1)  # FIXME: Expand instead of repeat maybe?
    else:
        # num_ray_points + 1 edges delimit num_ray_points bins; the last edge (1.0) is not a bin start
        zero_to_one = torch.linspace(0.0, 1.0, num_ray_points + 1, device=device, dtype=dtype)
        # Draw the jitter in the requested dtype so the result honors ``dtype`` instead of the default dtype
        jitter = torch.rand(num_rays, num_ray_points, device=device, dtype=dtype)
        lengths = jitter / num_ray_points + zero_to_one[:-1]
    return lengths
- # TODO: Implement hierarchical ray sampling as described in Mildenhall (2020) Sec. 5.2
def sample_ray_points(
    origins: Tensor, directions: Tensor, lengths: Tensor
) -> Tensor:  # FIXME: Test by projecting to points_2d and compare with sampler 2d points
    r"""Place 3d points on each ray at the given lengths.

    Every point is ``origin + length * direction`` of its ray.

    Args:
        origins: tensor containing ray origins in 3d world coordinates. Tensor shape :math:`(*, 3)`.
        directions: tensor containing ray directions in 3d world coordinates. Tensor shape :math:`(*, 3)`.
        lengths: tensor containing sampled distances along each ray. Tensor shape :math:`(*, num_ray_points)`.

    Returns:
        points_3d: Points along rays :math:`(*, num_ray_points, 3)`
    """
    # Insert a point axis into the per-ray tensors so they broadcast against the per-point lengths
    ray_starts = origins.unsqueeze(-2)
    ray_steps = lengths.unsqueeze(-1) * directions.unsqueeze(-2)
    return ray_starts + ray_steps
def calc_ray_t_vals(points_3d: Tensor) -> Tensor:
    r"""Measure the distance of every ray point from the first point of its ray.

    Args:
        points_3d: Points along rays :math:`(*, num_ray_points, 3)`

    Returns:
        t values along rays :math:`(*, num_ray_points)`
    """
    # Slicing with ``:1`` keeps the point axis, so the first point broadcasts against all points of its ray
    first_points = points_3d[..., :1, :]
    offsets = points_3d - first_points
    return torch.linalg.norm(offsets, dim=-1)
|