# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch

from kornia.core import Tensor
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.nerf.samplers import calc_ray_t_vals


class VolumeRenderer(torch.nn.Module):
    r"""Base class for volume rendering.

    Implementation follows Ben Mildenhall et al. (2020) at https://arxiv.org/abs/2003.08934.
    """

    # Stand-in for an "infinite" segment length (used for the last sample of irregular rays).
    _huge = 1.0e10
    # Added to ``1 - alpha`` inside the cumulative product so it never multiplies by an exact zero.
    _eps = 1.0e-10

    def __init__(self, shift: int = 1) -> None:
        r"""Initialize the renderer.

        Args:
            shift: Number of positions the accumulated transmittance is shifted along the sample
                axis; the first ``shift`` samples of each ray get a transmittance of 1. The default
                of 1 gives the exclusive product :math:`T_i = \prod_{j<i} (1 - \alpha_j + \epsilon)`.
        """
        super().__init__()
        self._shift = shift

    @staticmethod
    def _densities_to_column(densities: Tensor, rgbs: Tensor) -> Tensor:
        r"""Return densities shaped :math:`(*, N, 1)` so they broadcast against ``rgbs`` :math:`(*, N, C)`.

        Densities given as :math:`(*, N)` (exactly one dimension fewer than ``rgbs``) get a trailing
        singleton axis. Any other shape (e.g. already :math:`(*, N, 1)`) is returned unchanged.
        """
        if densities.dim() == rgbs.dim() - 1:
            return densities.unsqueeze(-1)
        return densities

    def _render(self, alpha: Tensor, rgbs: Tensor) -> Tensor:
        r"""Alpha-composite per-sample colors along each ray.

        Args:
            alpha: Per-sample opacities :math:`(*, N, 1)`.
            rgbs: Per-sample colors :math:`(*, N, C)`.

        Returns:
            Composited color for each ray :math:`(*, C)`.
        """
        # Transmittance accumulated up to and including each sample.
        trans = torch.cumprod(1 - alpha + self._eps, dim=-2)  # (*, N, 1)
        # Shift along the sample axis so each sample only sees the samples in front of it;
        # the entries that wrapped around to the front are reset to full transmittance.
        trans = torch.roll(trans, shifts=self._shift, dims=-2)  # (*, N, 1)
        trans[..., : self._shift, :] = 1  # (*, N, 1)
        weights = trans * alpha  # (*, N, 1)
        return torch.sum(weights * rgbs, dim=-2)  # (*, C)

    def forward(self, rgbs: Tensor, densities: Tensor, points_3d: Tensor) -> Tensor:
        raise NotImplementedError


class IrregularRenderer(VolumeRenderer):
    """Renders 3D irregularly sampled points along rays."""

    def forward(self, rgbs: Tensor, densities: Tensor, points_3d: Tensor) -> Tensor:
        r"""Render 3D irregularly sampled points along rays.

        Args:
            rgbs: RGB values of points along rays :math:`(*, N, 3)`
            densities: Volume densities of points along rays :math:`(*, N)` or :math:`(*, N, 1)`
            points_3d: 3D points along rays :math:`(*, N, 3)`

        Returns:
            Rendered RGB values for each ray :math:`(*, 3)`

        """
        KORNIA_CHECK_SHAPE(rgbs, ["*", "N", "3"])
        KORNIA_CHECK_SHAPE(densities, ["*", "N"])
        KORNIA_CHECK_SHAPE(points_3d, ["*", "N", "3"])

        densities = self._densities_to_column(densities, rgbs)  # (*, N, 1)

        t_vals = calc_ray_t_vals(points_3d)  # (*, N)
        deltas = t_vals[..., 1:] - t_vals[..., :-1]  # (*, N - 1)
        # The last sample has no successor: treat its segment as extending to "infinity".
        far = torch.full_like(t_vals[..., :1], self._huge)  # (*, 1)
        deltas = torch.cat([deltas, far], dim=-1)  # (*, N)

        alpha = 1 - torch.exp(-1.0 * densities * deltas[..., None])  # (*, N, 1)

        return self._render(alpha, rgbs)


class RegularRenderer(VolumeRenderer):
    """Renders 3D regularly sampled points along rays."""

    def forward(self, rgbs: Tensor, densities: Tensor, points_3d: Tensor) -> Tensor:
        r"""Render 3D regularly sampled points along rays.

        Args:
            rgbs: RGB values of points along rays :math:`(*, N, 3)`
            densities: Volume densities of points along rays :math:`(*, N)` or :math:`(*, N, 1)`
            points_3d: 3D points along rays :math:`(*, N, 3)`

        Returns:
            Rendered RGB values for each ray :math:`(*, 3)`

        """
        KORNIA_CHECK_SHAPE(rgbs, ["*", "N", "3"])
        KORNIA_CHECK_SHAPE(densities, ["*", "N"])
        KORNIA_CHECK_SHAPE(points_3d, ["*", "N", "3"])

        densities = self._densities_to_column(densities, rgbs)  # (*, N, 1)

        num_ray_points: int = points_3d.shape[-2]
        # Regular sampling: the step between consecutive samples is assumed identical on every
        # ray, so it is measured once, between the first two samples of the first ray.
        first_ray = points_3d.reshape(-1, num_ray_points, 3)[0]  # (N, 3)
        delta = torch.linalg.norm(first_ray[1] - first_ray[0], dim=-1)  # scalar

        alpha = 1 - torch.exp(-1.0 * densities * delta)  # (*, N, 1)

        return self._render(alpha, rgbs)