| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- import torch
- from torch import nn
- from torch.nn import functional as F
- from kornia.core import Module, Tensor
- from kornia.geometry.camera import PinholeCamera
- from kornia.geometry.ray import Ray
- from kornia.nerf.positional_encoder import PositionalEncoder
- from kornia.nerf.samplers import sample_lengths, sample_ray_points
- from kornia.nerf.volume_renderer import IrregularRenderer, RegularRenderer
- from kornia.utils._compat import torch_inference_mode
- from kornia.utils.grid import create_meshgrid
class MLP(Module):
    r"""Multi-layer perceptron made of sub-units joined by skip connections.

    The network stacks ``num_units`` sub-units, each consisting of ``num_unit_layers`` fully connected layers,
    every one followed by a ReLU. Every sub-unit after the first receives the previous sub-unit's output
    concatenated with the original network input (skip connection).

    The model follows: Ben Mildenhall et al. (2020) at https://arxiv.org/abs/2003.08934.
    """

    def __init__(self, num_dims: int, num_units: int = 2, num_unit_layers: int = 4, num_hidden: int = 128) -> None:
        """Construct MLP.

        Args:
            num_dims: Number of input dimensions (channels).
            num_units: Number of sub-units.
            num_unit_layers: Number of fully connected layers in each sub-unit.
            num_hidden: Layer hidden dimensions.
        """
        super().__init__()
        self._num_unit_layers = num_unit_layers

        def _layer_in_dims(unit: int, layer: int) -> int:
            # Only the first layer of a sub-unit sees the sub-unit input; later layers see hidden features.
            if layer > 0:
                return num_hidden
            # The first sub-unit gets the raw input; later ones get hidden features + skip-connected input.
            return num_dims if unit == 0 else num_hidden + num_dims

        # Layers are created unit by unit, then layer by layer, i.e. in the flat order indexed by forward().
        self._mlp = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(_layer_in_dims(unit, layer), num_hidden), nn.ReLU())
                for unit in range(num_units)
                for layer in range(num_unit_layers)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the input through all sub-units.

        Args:
            x: Input features; the last dimension must have size ``num_dims``.

        Returns:
            Features whose last dimension has size ``num_hidden`` (``x`` itself when there are no layers).
        """
        features = x
        for idx, block in enumerate(self._mlp):
            # A new sub-unit starts every ``num_unit_layers`` layers; re-inject the network input there.
            if idx and not idx % self._num_unit_layers:
                features = torch.cat((features, x), dim=-1)
            features = block(features)
        return features
class NerfModel(Module):
    r"""Class to represent NeRF model.

    Args:
        num_ray_points: Number of points to sample along rays.
        irregular_ray_sampling: Whether to sample ray points irregularly.
        num_pos_freqs: Number of frequencies for positional encoding.
        num_dir_freqs: Number of frequencies for directional encoding.
        num_units: Number of sub-units.
        num_unit_layers: Number of fully connected layers in each sub-unit.
        num_hidden: Layer hidden dimensions.
        log_space_encoding: Whether to apply log spacing for encoding.
        density_noise_std: Standard deviation of the Gaussian noise added to the raw densities while the module is
            in training mode (a regularizer). No noise is added in evaluation mode. Set to 0 to disable it.

    Raises:
        ValueError: If ``density_noise_std`` is negative.
    """

    def __init__(
        self,
        num_ray_points: int,
        irregular_ray_sampling: bool = True,
        num_pos_freqs: int = 10,
        num_dir_freqs: int = 4,
        num_units: int = 2,
        num_unit_layers: int = 4,
        num_hidden: int = 128,
        log_space_encoding: bool = True,
        density_noise_std: float = 0.1,
    ) -> None:
        super().__init__()
        if density_noise_std < 0:
            raise ValueError(f"density_noise_std must be non-negative, got {density_noise_std}")
        self._num_ray_points = num_ray_points
        self._irregular_ray_sampling = irregular_ray_sampling
        self._density_noise_std = density_noise_std
        self._renderer = IrregularRenderer() if self._irregular_ray_sampling else RegularRenderer()

        # Encoders for 3D positions and 3D view directions.
        self._pos_encoder = PositionalEncoder(3, num_pos_freqs, log_space=log_space_encoding)
        self._dir_encoder = PositionalEncoder(3, num_dir_freqs, log_space=log_space_encoding)

        self._mlp = MLP(self._pos_encoder.num_encoded_dims, num_units, num_unit_layers, num_hidden)
        self._fc1 = nn.Linear(num_hidden, num_hidden)
        # Color branch: latent features concatenated with the direction encoding.
        self._fc2 = nn.Sequential(
            nn.Linear(num_hidden + self._dir_encoder.num_encoded_dims, num_hidden // 2), nn.ReLU()
        )

        # Density head. Initializers run in place, so the parameters keep their dtype/device, unlike assigning
        # a fresh ``torch.tensor(...).float()`` to ``.data``, which forces float32.
        self._sigma = nn.Linear(num_hidden, 1, bias=True)
        nn.init.xavier_uniform_(self._sigma.weight)
        nn.init.constant_(self._sigma.bias, 0.1)

        # Color head, squashed to [0, 1] by the sigmoid.
        self._rgb = nn.Sequential(nn.Linear(num_hidden // 2, 3), nn.Sigmoid())
        nn.init.constant_(self._rgb[0].bias, 0.02)

    def forward(self, origins: Tensor, directions: Tensor) -> Tensor:
        """Forward method.

        Args:
            origins: Ray origins with shape :math:`(B, 3)`.
            directions: Ray directions with shape :math:`(B, 3)`.

        Returns:
            Rendered image pixels :math:`(B, 3)`.

        Note:
            Gaussian noise (std ``density_noise_std``) is added to the raw densities only while ``self.training``
            is True; call ``eval()`` before rendering to get noise-free densities.
        """
        # Sample points along each ray.
        batch_size = origins.shape[0]
        lengths = sample_lengths(
            batch_size,
            self._num_ray_points,
            device=origins.device,
            dtype=origins.dtype,
            irregular=self._irregular_ray_sampling,
        )  # FIXME: handle the case of hierarchical sampling
        points_3d = sample_ray_points(origins, directions, lengths)

        # Encode positions and unit-normalized view directions.
        points_3d_encoded = self._pos_encoder(points_3d)
        directions_encoded = self._dir_encoder(F.normalize(directions, dim=-1))

        # Map positional encodings to latent features (MLP with skip connections).
        y = self._mlp(points_3d_encoded)
        y = self._fc1(y)

        # Per-point densities. The noise is a training-time regularizer only, so eval-mode densities are
        # deterministic.
        densities_ray_points = self._sigma(y)
        if self.training and self._density_noise_std > 0.0:
            noise = torch.randn_like(densities_ray_points) * self._density_noise_std
            densities_ray_points = densities_ray_points + noise
        densities_ray_points = torch.relu(densities_ray_points)  # FIXME: Revise this

        # Per-point colors conditioned on view direction: each ray's direction encoding is broadcast over its
        # ``num_ray_points`` samples before concatenation with the latent features.
        directions_expanded = directions_encoded[..., None, :].expand(-1, self._num_ray_points, -1)
        y = torch.cat((y, directions_expanded), dim=-1)
        y = self._fc2(y)
        rgbs_ray_points = self._rgb(y)

        # Render per-point colors and densities along the rays into one color per ray.
        return self._renderer(rgbs_ray_points, densities_ray_points, points_3d)
class NerfModelRenderer:
    """Render novel views of a trained NeRF model as seen from a given camera."""

    def __init__(
        self, nerf_model: NerfModel, image_size: tuple[int, int], device: torch.device | None, dtype: torch.dtype | None
    ) -> None:
        """Construct NerfModelRenderer.

        Args:
            nerf_model: trained NeRF model queried with ray origins and directions.
            image_size: output image size as ``(height, width)``.
            device: device on which the pixel grid is created.
            dtype: dtype of the pixel grid.
        """
        self._nerf_model = nerf_model
        self._image_size = image_size
        self._device = device
        self._dtype = dtype
        # Computed once per renderer: flattened pixel coordinates (H*W)x2 and unit depths (H*W)x1.
        self._pixels_grid, self._ones = self._create_pixels_grid()

    def _create_pixels_grid(self) -> tuple[Tensor, Tensor]:
        """Build the flattened pixel grid that is unprojected at depth 1.

        Returns:
            - Pixel coordinates (unnormalized) with shape :math:`(H*W, 2)`.
            - Ones (unit depths) with shape :math:`(H*W, 1)`.
        """
        height, width = self._image_size
        grid: Tensor = create_meshgrid(
            height, width, normalized_coordinates=False, device=self._device, dtype=self._dtype
        )
        pixels = grid.reshape(-1, 2)
        depths = pixels.new_ones(pixels.shape[0], 1)
        return pixels, depths

    def render_view(self, camera: PinholeCamera) -> Tensor:
        """Render the view of the trained NeRF model seen by ``camera``.

        Args:
            camera: camera used for rendering.

        Returns:
            Rendered image with shape :math:`(H, W, 3)`.
        """
        height, width = self._image_size
        rays = self._create_rays(camera)
        # Only inference is needed here; the model is queried once with every pixel ray.
        with torch_inference_mode():
            colors = self._nerf_model(rays.origin, rays.direction)
            image = colors.view(height, width, 3)
        return image

    def _create_rays(self, camera: PinholeCamera) -> Ray:
        """Build one ray per pixel for ``camera``.

        Args:
            camera: camera used for rendering.

        Returns:
            Rays built by ``Ray.through`` from the per-pixel origins to the unprojected pixels.
        """
        height, width = self._image_size
        # The translation column of the extrinsics is used as the origin of every pixel ray.
        # NOTE(review): if ``extrinsics`` maps world to camera, the camera centre in world coordinates is
        # ``-R^T t`` rather than ``t``; confirm this matches the convention used for the training rays.
        # NOTE(review): assumes a single camera (batch size 1); repeat() tiles the batch rather than pairing
        # cameras with pixels.
        origins = camera.extrinsics[..., :3, -1].repeat(height * width, 1)
        # Unprojecting each pixel at unit depth yields a point on that pixel's ray.
        targets = camera.unproject(self._pixels_grid, self._ones)
        return Ray.through(origins, targets)
|