| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import torch
- from kornia.core import Tensor
- from kornia.core.check import KORNIA_CHECK_SHAPE
- from kornia.nerf.samplers import calc_ray_t_vals
class VolumeRenderer(torch.nn.Module):
    r"""Base class for volume rendering.

    Implementation follows Ben Mildenhall et al. (2020) at https://arxiv.org/abs/2003.08934.

    Subclasses implement :meth:`forward`: they turn per-sample densities into per-sample
    opacities (``alpha``) and hand the compositing step to :meth:`_render`.
    """

    # Large value subclasses can use as an effectively infinite distance.
    _huge = 1.0e10
    # Small constant added to every (1 - alpha) factor before the cumulative product.
    _eps = 1.0e-10

    def __init__(self, shift: int = 1) -> None:
        """Initialize the renderer.

        Args:
            shift: number of samples by which the cumulative transmittance is rolled along each
                ray; the first ``shift`` samples get a transmittance of exactly 1. With the default
                of 1, sample ``i`` is attenuated only by samples ``0 .. i-1``.
        """
        super().__init__()
        self._shift = shift

    def _render(self, alpha: Tensor, rgbs: Tensor) -> Tensor:
        """Alpha-composite per-sample colors along each ray.

        Args:
            alpha: per-sample opacities with samples along dim ``-2``, expected :math:`(*, N, 1)`.
            rgbs: per-sample colors :math:`(*, N, 3)`.

        Returns:
            Composited colors :math:`(*, 3)`.
        """
        survival = 1 - alpha + self._eps  # (*, N, 1)
        # Exclusive-style transmittance: roll the running product forward by `shift` samples and
        # reset the leading samples, which nothing lies in front of, to full transmittance.
        transmittance = torch.roll(torch.cumprod(survival, dim=-2), shifts=self._shift, dims=-2)
        transmittance[..., : self._shift, :] = 1
        return (transmittance * alpha * rgbs).sum(dim=-2)  # (*, 3)

    def forward(self, rgbs: Tensor, densities: Tensor, points_3d: Tensor) -> Tensor:
        """Render rays; concrete renderers must override this."""
        raise NotImplementedError
class IrregularRenderer(VolumeRenderer):
    """Renders 3D irregularly sampled points along rays."""

    def forward(self, rgbs: Tensor, densities: Tensor, points_3d: Tensor) -> Tensor:
        r"""Render 3D irregularly sampled points along rays.

        The interval assigned to each sample is the difference between its ray parameter ``t`` and
        the next sample's; the last sample on each ray gets an effectively infinite interval
        (``self._huge``).

        Args:
            rgbs: RGB values of points along rays :math:`(*, N, 3)`
            densities: Volume densities of points along rays :math:`(*, N)` or :math:`(*, N, 1)`
            points_3d: 3D points along rays :math:`(*, N, 3)`

        Returns:
            Rendered RGB values for each ray :math:`(*, 3)`
        """
        KORNIA_CHECK_SHAPE(rgbs, ["*", "N", "3"])
        KORNIA_CHECK_SHAPE(densities, ["*", "N"])
        KORNIA_CHECK_SHAPE(points_3d, ["*", "N", "3"])

        # Compositing expects samples along dim -2 with a trailing singleton. A density tensor laid
        # out exactly as (*, N) would otherwise broadcast wrongly against (*, N, 1) intervals.
        if densities.shape == points_3d.shape[:-1]:
            densities = densities[..., None]  # (*, N, 1)

        t_vals = calc_ray_t_vals(points_3d)
        deltas = t_vals[..., 1:] - t_vals[..., :-1]  # (*, N - 1)
        # Same dtype/device as t_vals; stands in for the unbounded interval after the last sample.
        far = torch.full_like(t_vals[..., :1], self._huge)  # (*, 1)
        deltas = torch.cat([deltas, far], dim=-1)  # (*, N)

        alpha = 1 - torch.exp(-1.0 * densities * deltas[..., None])  # (*, N, 1)
        return self._render(alpha, rgbs)
class RegularRenderer(VolumeRenderer):
    """Renders 3D regularly sampled points along rays."""

    def forward(self, rgbs: Tensor, densities: Tensor, points_3d: Tensor) -> Tensor:
        r"""Render 3D regularly sampled points along rays.

        Each ray is assumed to be sampled at a constant step. The step length is measured per ray as
        the distance between its first two points and applied to all of that ray's samples,
        including the last one.

        Args:
            rgbs: RGB values of points along rays :math:`(*, N, 3)`
            densities: Volume densities of points along rays :math:`(*, N)` or :math:`(*, N, 1)`
            points_3d: 3D points along rays :math:`(*, N, 3)`, with :math:`N \geq 2`

        Returns:
            Rendered RGB values for each ray :math:`(*, 3)`
        """
        KORNIA_CHECK_SHAPE(rgbs, ["*", "N", "3"])
        KORNIA_CHECK_SHAPE(densities, ["*", "N"])
        KORNIA_CHECK_SHAPE(points_3d, ["*", "N", "3"])

        # Compositing cumprods over dim -2, so a density tensor laid out exactly as (*, N) must be
        # lifted to (*, N, 1); otherwise the product would run across rays instead of samples.
        if densities.shape == points_3d.shape[:-1]:
            densities = densities[..., None]  # (*, N, 1)

        # Per-ray step, so rays with different spacings are each attenuated correctly.
        delta_3d = points_3d[..., 1, :] - points_3d[..., 0, :]  # (*, 3)
        delta = torch.linalg.norm(delta_3d, dim=-1)  # (*,)

        # Broadcast each ray's step over its N samples and the trailing singleton of `densities`.
        alpha = 1 - torch.exp(-1.0 * densities * delta[..., None, None])  # (*, N, 1)
        return self._render(alpha, rgbs)
|