# LICENSE HEADER MANAGED BY add-license-header # # Copyright 2018 Kornia Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import torch from torch import nn from torch.nn import functional as F from kornia.core import Module, Tensor from kornia.geometry.camera import PinholeCamera from kornia.geometry.ray import Ray from kornia.nerf.positional_encoder import PositionalEncoder from kornia.nerf.samplers import sample_lengths, sample_ray_points from kornia.nerf.volume_renderer import IrregularRenderer, RegularRenderer from kornia.utils._compat import torch_inference_mode from kornia.utils.grid import create_meshgrid class MLP(Module): r"""Class to represent a multi-layer perceptron. The MLP represents a deep NN of fully connected layers. The network is build of user defined sub-units, each with a user defined number of layers. Skip connections span between the sub-units. The model follows: Ben Mildenhall et el. (2020) at https://arxiv.org/abs/2003.08934. """ def __init__(self, num_dims: int, num_units: int = 2, num_unit_layers: int = 4, num_hidden: int = 128) -> None: """Construct MLP. Args: num_dims: Number of input dimensions (channels). num_units: Number of sub-units. num_unit_layers: Number of fully connected layers in each sub-unit. num_hidden: Layer hidden dimensions. """ super().__init__() self._num_unit_layers = num_unit_layers layers = [] for i in range(num_units): num_unit_inp_dims = num_dims if i == 0 else num_hidden + num_dims for j in range(num_unit_layers): num_layer_inp_dims = num_unit_inp_dims if j == 0 else num_hidden layer = nn.Linear(num_layer_inp_dims, num_hidden) layers.append(nn.Sequential(layer, nn.ReLU())) self._mlp = nn.ModuleList(layers) def forward(self, x: Tensor) -> Tensor: out = x x_skip = x for i, layer in enumerate(self._mlp): if i > 0 and i % self._num_unit_layers == 0: out = torch.cat((out, x_skip), dim=-1) out = layer(out) return out class NerfModel(Module): r"""Class to represent NeRF model. Args: num_ray_points: Number of points to sample along rays. irregular_ray_sampling: Whether to sample ray points irregularly. num_pos_freqs: Number of frequencies for positional encoding. num_dir_freqs: Number of frequencies for directional encoding. num_units: Number of sub-units. num_unit_layers: Number of fully connected layers in each sub-unit. num_hidden: Layer hidden dimensions. log_space_encoding: Whether to apply log spacing for encoding. """ def __init__( self, num_ray_points: int, irregular_ray_sampling: bool = True, num_pos_freqs: int = 10, num_dir_freqs: int = 4, num_units: int = 2, num_unit_layers: int = 4, num_hidden: int = 128, # FIXME: add as call argument log_space_encoding: bool = True, ) -> None: super().__init__() self._num_ray_points = num_ray_points self._irregular_ray_sampling = irregular_ray_sampling self._renderer = IrregularRenderer() if self._irregular_ray_sampling else RegularRenderer() self._pos_encoder = PositionalEncoder(3, num_pos_freqs, log_space=log_space_encoding) self._dir_encoder = PositionalEncoder(3, num_dir_freqs, log_space=log_space_encoding) self._mlp = MLP(self._pos_encoder.num_encoded_dims, num_units, num_unit_layers, num_hidden) self._fc1 = nn.Linear(num_hidden, num_hidden) self._fc2 = nn.Sequential( nn.Linear(num_hidden + self._dir_encoder.num_encoded_dims, num_hidden // 2), nn.ReLU() ) self._sigma = nn.Linear(num_hidden, 1, bias=True) torch.nn.init.xavier_uniform_(self._sigma.weight.data) self._sigma.bias.data = torch.tensor([0.1]).float() self._rgb = nn.Sequential(nn.Linear(num_hidden // 2, 3), nn.Sigmoid()) self._rgb[0].bias.data = torch.tensor([0.02, 0.02, 0.02]).float() def forward(self, origins: Tensor, directions: Tensor) -> Tensor: """Forward method. Args: origins: Ray origins with shape :math:`(B, 3)`. directions: Ray directions with shape :math:`(B, 3)`. Returns: Rendered image pixels :math:`(B, 3)`. """ # Sample xyz for ray parameters batch_size = origins.shape[0] lengths = sample_lengths( batch_size, self._num_ray_points, device=origins.device, dtype=origins.dtype, irregular=self._irregular_ray_sampling, ) # FIXME: handle the case of hierarchical sampling points_3d = sample_ray_points(origins, directions, lengths) # Encode positions & directions points_3d_encoded = self._pos_encoder(points_3d) directions_encoded = self._dir_encoder(F.normalize(directions, dim=-1)) # Map positional encodings to latent features (MLP with skip connections) y = self._mlp(points_3d_encoded) y = self._fc1(y) # Calculate ray point density values densities_ray_points = self._sigma(y) densities_ray_points = densities_ray_points + torch.randn_like(densities_ray_points) * 0.1 densities_ray_points = torch.relu(densities_ray_points) # FIXME: Revise this # Calculate ray point rgb values y = torch.cat((y, directions_encoded[..., None, :].expand(-1, self._num_ray_points, -1)), dim=-1) y = self._fc2(y) rgbs_ray_points = self._rgb(y) # Rendering rgbs and densities along rays rgbs = self._renderer(rgbs_ray_points, densities_ray_points, points_3d) # Return pixel point rendered rgb return rgbs class NerfModelRenderer: """Renders a novel synthesis view of a trained NeRF model for given camera.""" def __init__( self, nerf_model: NerfModel, image_size: tuple[int, int], device: torch.device | None, dtype: torch.dtype | None ) -> None: """Construct NerfModelRenderer. Args: nerf_model: NeRF model. image_size: image size. device: device to run the model on. dtype: dtype to run the model on. """ self._nerf_model = nerf_model self._image_size = image_size self._device = device self._dtype = dtype self._pixels_grid, self._ones = self._create_pixels_grid() # 1xHxWx2 and (H*W)x1 def _create_pixels_grid(self) -> tuple[Tensor, Tensor]: """Create the pixels grid to unproject to plane z=1. Args: image_size: image size: tuple[int, int] Returns: - Pixels grid: Tensor (1, H, W, 2) - Ones: Tensor (H*W, 1) """ height, width = self._image_size pixels_grid: Tensor = create_meshgrid( height, width, normalized_coordinates=False, device=self._device, dtype=self._dtype ) # 1xHxWx2 pixels_grid = pixels_grid.reshape(-1, 2) # (H*W)x2 ones = torch.ones(pixels_grid.shape[0], 1, device=pixels_grid.device, dtype=pixels_grid.dtype) # (H*W)x1 return pixels_grid, ones def render_view(self, camera: PinholeCamera) -> Tensor: """Render a novel synthesis view of a trained NeRF model for given camera. Args: camera: camera for image rendering: PinholeCamera. Returns: Rendered image with shape :math:`(H, W, 3)`. """ # create ray for this camera rays: Ray = self._create_rays(camera) # render the image with torch_inference_mode(): rgb_model = self._nerf_model(rays.origin, rays.direction) rgb_image = rgb_model.view(self._image_size[0], self._image_size[1], 3) return rgb_image def _create_rays(self, camera: PinholeCamera) -> Ray: """Create rays for a given camera. Args: camera: camera for image rendering: PinholeCamera. """ height, width = self._image_size # convert to rays origin = camera.extrinsics[..., :3, -1] # 1x3 origin = origin.repeat(height * width, 1) # (H*W)x3 destination = camera.unproject(self._pixels_grid, self._ones) # (H*W)x3 return Ray.through(origin, destination)