nerf_solver.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import logging
  19. from datetime import datetime
  20. from typing import cast
  21. import torch
  22. import torch.nn.functional as F
  23. from torch import optim
  24. from kornia.core import Module, Tensor, tensor
  25. from kornia.core.check import KORNIA_CHECK
  26. from kornia.geometry.camera import PinholeCamera
  27. from kornia.metrics import psnr
  28. from kornia.nerf.core import Images
  29. from kornia.nerf.data_utils import RayDataset, instantiate_ray_dataloader
  30. from kornia.nerf.nerf_model import NerfModel
  31. from kornia.utils import deprecated
  32. logger = logging.getLogger(__name__)
  33. class NerfSolver:
  34. r"""NeRF solver class.
  35. Args:
  36. device: device for class tensors: Union[str, Device]
  37. dtype: type for all floating point calculations: torch.dtype
  38. """
  39. def __init__(self, device: torch.device, dtype: torch.dtype) -> None:
  40. # TODO: add support for the new CameraModel class
  41. # cameras used for training
  42. self._cameras: PinholeCamera | None = None
  43. # rays depth range
  44. self._min_depth: float = 0.0
  45. self._max_depth: float = 0.0
  46. # whether to convert ray parameters to normalized device coordinates
  47. self._ndc: bool = True
  48. # images used for training
  49. self._imgs: Images | None = None
  50. # number of rays to randomly cast from each camera
  51. self._num_img_rays: Tensor | int | None = None
  52. # number of rays to sample in a batch
  53. self._batch_size: int = 0
  54. # number of points to sample along rays
  55. self._num_ray_points: int = 0
  56. # the model and optimizer
  57. self._nerf_model: Module | None = None
  58. self._nerf_optimizer: optim.Optimizer | None = None
  59. self._device = device
  60. self._dtype = dtype
  61. def setup_solver(
  62. self,
  63. cameras: PinholeCamera,
  64. min_depth: float,
  65. max_depth: float,
  66. ndc: bool,
  67. imgs: Images,
  68. num_img_rays: Tensor | int,
  69. batch_size: int,
  70. num_ray_points: int,
  71. irregular_ray_sampling: bool = True,
  72. log_space_encoding: bool = True,
  73. lr: float = 1.0e-3,
  74. ) -> None:
  75. """Initialize training settings and model.
  76. Args:
  77. cameras: Scene cameras in the order of input images.
  78. min_depth: sampled rays minimal depth from cameras.
  79. max_depth: sampled rays maximal depth from cameras.
  80. ndc: convert ray parameters to normalized device coordinates.
  81. imgs: Scene 2D images (one for each camera).
  82. num_img_rays: Number of rays to randomly cast from each camera: math: `(B)`.
  83. batch_size: Number of rays to sample in a batch.
  84. num_ray_points: Number of points to sample along rays.
  85. irregular_ray_sampling: Whether to sample ray points irregularly.
  86. log_space_encoding: Whether frequency sampling should be log spaced.
  87. lr: Learning rate.
  88. """
  89. self._cameras = cameras
  90. self._min_depth = min_depth
  91. self._max_depth = max_depth
  92. self._ndc = ndc
  93. self._imgs = imgs
  94. KORNIA_CHECK(
  95. isinstance(batch_size, int) and batch_size > 0,
  96. "batch_size must be a positive integer",
  97. )
  98. KORNIA_CHECK(
  99. isinstance(num_ray_points, int) and num_ray_points > 0,
  100. "num_ray_points must be a positive integer",
  101. )
  102. KORNIA_CHECK(num_img_rays is not None, "num_img_rays must be specified")
  103. if isinstance(num_img_rays, int):
  104. self._num_img_rays = tensor([num_img_rays] * cameras.batch_size)
  105. elif torch.is_tensor(num_img_rays):
  106. self._num_img_rays = num_img_rays
  107. else:
  108. raise TypeError("num_img_rays can be either an int or a Tensor")
  109. self._batch_size = batch_size
  110. self._nerf_model = NerfModel(
  111. num_ray_points, irregular_ray_sampling=irregular_ray_sampling, log_space_encoding=log_space_encoding
  112. )
  113. self._nerf_model.to(device=self._device, dtype=self._dtype)
  114. self._nerf_optimizer = optim.Adam(self._nerf_model.parameters(), lr=lr)
  115. @deprecated(replace_with="setup_solver", version="0.7.0")
  116. def init_training(
  117. self,
  118. cameras: PinholeCamera,
  119. min_depth: float,
  120. max_depth: float,
  121. ndc: bool,
  122. imgs: Images,
  123. num_img_rays: Tensor | int,
  124. batch_size: int,
  125. num_ray_points: int,
  126. irregular_ray_sampling: bool = True,
  127. log_space_encoding: bool = True,
  128. lr: float = 1.0e-3,
  129. ) -> None:
  130. self.setup_solver(
  131. cameras,
  132. min_depth,
  133. max_depth,
  134. ndc,
  135. imgs,
  136. num_img_rays,
  137. batch_size,
  138. num_ray_points,
  139. irregular_ray_sampling,
  140. log_space_encoding,
  141. lr,
  142. )
  143. @property
  144. def nerf_model(self) -> Module | None:
  145. """Returns the NeRF model."""
  146. return self._nerf_model
  147. def _train_one_epoch(self) -> float:
  148. r"""Trains the NeRF model one epoch.
  149. 1) A dataset of rays is initialized, and sent over to a data loader.
  150. 2) The data loader sample a batch of rays randomly, and runs them through the NeRF model,
  151. to predict ray associated rgb model values.
  152. 3) The model rgb is compared with the image pixel rgb, and the loss between the two is back
  153. propagated to update the model weights.
  154. Implemented steps:
  155. - Create an object of class RayDataset
  156. - Initialize ray dataset with group of images on disk, and number of rays to randomly sample
  157. - Initialize a data loader with batch size info
  158. - Iterate over data loader
  159. -- Reset optimizer
  160. -- Run ray batch through Nerf model
  161. -- Find loss
  162. -- Back propagate loss
  163. -- Optimizer step
  164. Returns:
  165. Average psnr over all epoch rays.
  166. """
  167. KORNIA_CHECK(self._nerf_model is not None, "The model should be a NeRF model.")
  168. KORNIA_CHECK(self._nerf_optimizer is not None, "The optimizer should be an Adam optimizer.")
  169. KORNIA_CHECK(self._cameras is not None, "The camera should be a PinholeCamera.")
  170. KORNIA_CHECK(self._imgs is not None, "The images should be a list of tensors.")
  171. KORNIA_CHECK(self._num_img_rays is not None, "The number of images of Ray should be a tensor.")
  172. # TODO: refactor and so that the constructor receives the correct types
  173. cameras: PinholeCamera = cast(PinholeCamera, self._cameras)
  174. num_img_rays: Tensor = cast(Tensor, self._num_img_rays)
  175. images = cast(Images, self._imgs)
  176. nerf_model: NerfModel = cast(NerfModel, self._nerf_model)
  177. nerf_optimizer: optim.Optimizer = cast(optim.Optimizer, self._nerf_optimizer)
  178. # create the dataset and data loader
  179. ray_dataset = RayDataset(
  180. cameras, self._min_depth, self._max_depth, self._ndc, device=self._device, dtype=self._dtype
  181. )
  182. ray_dataset.init_ray_dataset(num_img_rays)
  183. ray_dataset.init_images_for_training(images) # FIXME: Do we need to load the same images on each Epoch?
  184. # data loader
  185. ray_data_loader = instantiate_ray_dataloader(ray_dataset, self._batch_size, shuffle=True)
  186. total_psnr: Tensor = torch.tensor(0.0, device=self._device, dtype=self._dtype)
  187. i_batch: float = 0
  188. for origins, directions, rgbs in ray_data_loader:
  189. rgbs_model = nerf_model(origins, directions)
  190. loss = F.mse_loss(rgbs_model, rgbs)
  191. total_psnr += psnr(rgbs_model, rgbs, 1.0)
  192. nerf_optimizer.zero_grad()
  193. loss.backward()
  194. nerf_optimizer.step()
  195. i_batch += 1
  196. return float(total_psnr / (i_batch + 1))
  197. def run(self, num_epochs: int = 1) -> None:
  198. r"""Run training epochs.
  199. Args:
  200. num_epochs: number of epochs to run. Default: 1.
  201. """
  202. for i_epoch in range(num_epochs):
  203. # train one epoch
  204. epoch_psnr: float = self._train_one_epoch()
  205. if i_epoch % 10 == 0:
  206. current_time = datetime.now().strftime("%H:%M:%S") # noqa: DTZ005
  207. logger.info("Epoch: %d: epoch_psnr = %f; time: %s", i_epoch, epoch_psnr, current_time)