# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as F

from kornia.core import Module, Tensor
from kornia.geometry.camera import PinholeCamera
from kornia.geometry.ray import Ray
from kornia.nerf.positional_encoder import PositionalEncoder
from kornia.nerf.samplers import sample_lengths, sample_ray_points
from kornia.nerf.volume_renderer import IrregularRenderer, RegularRenderer
from kornia.utils._compat import torch_inference_mode
from kornia.utils.grid import create_meshgrid


  29. class MLP(Module):
  30. r"""Class to represent a multi-layer perceptron.
  31. The MLP represents a deep NN of fully connected layers.
  32. The network is build of user defined sub-units, each with a user defined number of layers.
  33. Skip connections span between the sub-units.
  34. The model follows: Ben Mildenhall et el. (2020) at https://arxiv.org/abs/2003.08934.
  35. """
  36. def __init__(self, num_dims: int, num_units: int = 2, num_unit_layers: int = 4, num_hidden: int = 128) -> None:
  37. """Construct MLP.
  38. Args:
  39. num_dims: Number of input dimensions (channels).
  40. num_units: Number of sub-units.
  41. num_unit_layers: Number of fully connected layers in each sub-unit.
  42. num_hidden: Layer hidden dimensions.
  43. """
  44. super().__init__()
  45. self._num_unit_layers = num_unit_layers
  46. layers = []
  47. for i in range(num_units):
  48. num_unit_inp_dims = num_dims if i == 0 else num_hidden + num_dims
  49. for j in range(num_unit_layers):
  50. num_layer_inp_dims = num_unit_inp_dims if j == 0 else num_hidden
  51. layer = nn.Linear(num_layer_inp_dims, num_hidden)
  52. layers.append(nn.Sequential(layer, nn.ReLU()))
  53. self._mlp = nn.ModuleList(layers)
  54. def forward(self, x: Tensor) -> Tensor:
  55. out = x
  56. x_skip = x
  57. for i, layer in enumerate(self._mlp):
  58. if i > 0 and i % self._num_unit_layers == 0:
  59. out = torch.cat((out, x_skip), dim=-1)
  60. out = layer(out)
  61. return out
  62. class NerfModel(Module):
  63. r"""Class to represent NeRF model.
  64. Args:
  65. num_ray_points: Number of points to sample along rays.
  66. irregular_ray_sampling: Whether to sample ray points irregularly.
  67. num_pos_freqs: Number of frequencies for positional encoding.
  68. num_dir_freqs: Number of frequencies for directional encoding.
  69. num_units: Number of sub-units.
  70. num_unit_layers: Number of fully connected layers in each sub-unit.
  71. num_hidden: Layer hidden dimensions.
  72. log_space_encoding: Whether to apply log spacing for encoding.
  73. """
  74. def __init__(
  75. self,
  76. num_ray_points: int,
  77. irregular_ray_sampling: bool = True,
  78. num_pos_freqs: int = 10,
  79. num_dir_freqs: int = 4,
  80. num_units: int = 2,
  81. num_unit_layers: int = 4,
  82. num_hidden: int = 128, # FIXME: add as call argument
  83. log_space_encoding: bool = True,
  84. ) -> None:
  85. super().__init__()
  86. self._num_ray_points = num_ray_points
  87. self._irregular_ray_sampling = irregular_ray_sampling
  88. self._renderer = IrregularRenderer() if self._irregular_ray_sampling else RegularRenderer()
  89. self._pos_encoder = PositionalEncoder(3, num_pos_freqs, log_space=log_space_encoding)
  90. self._dir_encoder = PositionalEncoder(3, num_dir_freqs, log_space=log_space_encoding)
  91. self._mlp = MLP(self._pos_encoder.num_encoded_dims, num_units, num_unit_layers, num_hidden)
  92. self._fc1 = nn.Linear(num_hidden, num_hidden)
  93. self._fc2 = nn.Sequential(
  94. nn.Linear(num_hidden + self._dir_encoder.num_encoded_dims, num_hidden // 2), nn.ReLU()
  95. )
  96. self._sigma = nn.Linear(num_hidden, 1, bias=True)
  97. torch.nn.init.xavier_uniform_(self._sigma.weight.data)
  98. self._sigma.bias.data = torch.tensor([0.1]).float()
  99. self._rgb = nn.Sequential(nn.Linear(num_hidden // 2, 3), nn.Sigmoid())
  100. self._rgb[0].bias.data = torch.tensor([0.02, 0.02, 0.02]).float()
  101. def forward(self, origins: Tensor, directions: Tensor) -> Tensor:
  102. """Forward method.
  103. Args:
  104. origins: Ray origins with shape :math:`(B, 3)`.
  105. directions: Ray directions with shape :math:`(B, 3)`.
  106. Returns:
  107. Rendered image pixels :math:`(B, 3)`.
  108. """
  109. # Sample xyz for ray parameters
  110. batch_size = origins.shape[0]
  111. lengths = sample_lengths(
  112. batch_size,
  113. self._num_ray_points,
  114. device=origins.device,
  115. dtype=origins.dtype,
  116. irregular=self._irregular_ray_sampling,
  117. ) # FIXME: handle the case of hierarchical sampling
  118. points_3d = sample_ray_points(origins, directions, lengths)
  119. # Encode positions & directions
  120. points_3d_encoded = self._pos_encoder(points_3d)
  121. directions_encoded = self._dir_encoder(F.normalize(directions, dim=-1))
  122. # Map positional encodings to latent features (MLP with skip connections)
  123. y = self._mlp(points_3d_encoded)
  124. y = self._fc1(y)
  125. # Calculate ray point density values
  126. densities_ray_points = self._sigma(y)
  127. densities_ray_points = densities_ray_points + torch.randn_like(densities_ray_points) * 0.1
  128. densities_ray_points = torch.relu(densities_ray_points) # FIXME: Revise this
  129. # Calculate ray point rgb values
  130. y = torch.cat((y, directions_encoded[..., None, :].expand(-1, self._num_ray_points, -1)), dim=-1)
  131. y = self._fc2(y)
  132. rgbs_ray_points = self._rgb(y)
  133. # Rendering rgbs and densities along rays
  134. rgbs = self._renderer(rgbs_ray_points, densities_ray_points, points_3d)
  135. # Return pixel point rendered rgb
  136. return rgbs
  137. class NerfModelRenderer:
  138. """Renders a novel synthesis view of a trained NeRF model for given camera."""
  139. def __init__(
  140. self, nerf_model: NerfModel, image_size: tuple[int, int], device: torch.device | None, dtype: torch.dtype | None
  141. ) -> None:
  142. """Construct NerfModelRenderer.
  143. Args:
  144. nerf_model: NeRF model.
  145. image_size: image size.
  146. device: device to run the model on.
  147. dtype: dtype to run the model on.
  148. """
  149. self._nerf_model = nerf_model
  150. self._image_size = image_size
  151. self._device = device
  152. self._dtype = dtype
  153. self._pixels_grid, self._ones = self._create_pixels_grid() # 1xHxWx2 and (H*W)x1
  154. def _create_pixels_grid(self) -> tuple[Tensor, Tensor]:
  155. """Create the pixels grid to unproject to plane z=1.
  156. Args:
  157. image_size: image size: tuple[int, int]
  158. Returns:
  159. - Pixels grid: Tensor (1, H, W, 2)
  160. - Ones: Tensor (H*W, 1)
  161. """
  162. height, width = self._image_size
  163. pixels_grid: Tensor = create_meshgrid(
  164. height, width, normalized_coordinates=False, device=self._device, dtype=self._dtype
  165. ) # 1xHxWx2
  166. pixels_grid = pixels_grid.reshape(-1, 2) # (H*W)x2
  167. ones = torch.ones(pixels_grid.shape[0], 1, device=pixels_grid.device, dtype=pixels_grid.dtype) # (H*W)x1
  168. return pixels_grid, ones
  169. def render_view(self, camera: PinholeCamera) -> Tensor:
  170. """Render a novel synthesis view of a trained NeRF model for given camera.
  171. Args:
  172. camera: camera for image rendering: PinholeCamera.
  173. Returns:
  174. Rendered image with shape :math:`(H, W, 3)`.
  175. """
  176. # create ray for this camera
  177. rays: Ray = self._create_rays(camera)
  178. # render the image
  179. with torch_inference_mode():
  180. rgb_model = self._nerf_model(rays.origin, rays.direction)
  181. rgb_image = rgb_model.view(self._image_size[0], self._image_size[1], 3)
  182. return rgb_image
  183. def _create_rays(self, camera: PinholeCamera) -> Ray:
  184. """Create rays for a given camera.
  185. Args:
  186. camera: camera for image rendering: PinholeCamera.
  187. """
  188. height, width = self._image_size
  189. # convert to rays
  190. origin = camera.extrinsics[..., :3, -1] # 1x3
  191. origin = origin.repeat(height * width, 1) # (H*W)x3
  192. destination = camera.unproject(self._pixels_grid, self._ones) # (H*W)x3
  193. return Ray.through(origin, destination)