| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- import logging
- from datetime import datetime
- from typing import cast
- import torch
- import torch.nn.functional as F
- from torch import optim
- from kornia.core import Module, Tensor, tensor
- from kornia.core.check import KORNIA_CHECK
- from kornia.geometry.camera import PinholeCamera
- from kornia.metrics import psnr
- from kornia.nerf.core import Images
- from kornia.nerf.data_utils import RayDataset, instantiate_ray_dataloader
- from kornia.nerf.nerf_model import NerfModel
- from kornia.utils import deprecated
logger = logging.getLogger(__name__)


class NerfSolver:
    r"""NeRF solver class.

    Holds the training configuration (cameras, images, ray sampling settings), the NeRF model and its
    optimizer, and trains the model epoch by epoch on randomly cast rays.

    Args:
        device: device for class tensors (model parameters and ray data).
        dtype: type for all floating point calculations.
    """

    def __init__(self, device: torch.device, dtype: torch.dtype) -> None:
        # TODO: add support for the new CameraModel class
        # cameras used for training
        self._cameras: PinholeCamera | None = None
        # rays depth range
        self._min_depth: float = 0.0
        self._max_depth: float = 0.0
        # whether to convert ray parameters to normalized device coordinates
        self._ndc: bool = True
        # images used for training
        self._imgs: Images | None = None
        # number of rays to randomly cast from each camera
        self._num_img_rays: Tensor | int | None = None
        # number of rays to sample in a batch
        self._batch_size: int = 0
        # number of points to sample along rays
        self._num_ray_points: int = 0
        # the model and optimizer; both stay None until setup_solver() is called
        self._nerf_model: Module | None = None
        self._nerf_optimizer: optim.Optimizer | None = None
        self._device = device
        self._dtype = dtype

    def setup_solver(
        self,
        cameras: PinholeCamera,
        min_depth: float,
        max_depth: float,
        ndc: bool,
        imgs: Images,
        num_img_rays: Tensor | int,
        batch_size: int,
        num_ray_points: int,
        irregular_ray_sampling: bool = True,
        log_space_encoding: bool = True,
        lr: float = 1.0e-3,
    ) -> None:
        """Initialize training settings and model.

        All arguments are validated before any solver state is modified, so a rejected call leaves the
        solver exactly as it was.

        Args:
            cameras: Scene cameras in the order of input images.
            min_depth: sampled rays minimal depth from cameras.
            max_depth: sampled rays maximal depth from cameras.
            ndc: convert ray parameters to normalized device coordinates.
            imgs: Scene 2D images (one for each camera).
            num_img_rays: Number of rays to randomly cast from each camera: math: `(B)`. An int is
                replicated once per camera (``cameras.batch_size`` entries).
            batch_size: Number of rays to sample in a batch.
            num_ray_points: Number of points to sample along rays.
            irregular_ray_sampling: Whether to sample ray points irregularly.
            log_space_encoding: Whether frequency sampling should be log spaced.
            lr: Learning rate.

        Raises:
            TypeError: if ``num_img_rays`` is neither an int nor a Tensor.
        """
        KORNIA_CHECK(
            isinstance(batch_size, int) and batch_size > 0,
            "batch_size must be a positive integer",
        )
        KORNIA_CHECK(
            isinstance(num_ray_points, int) and num_ray_points > 0,
            "num_ray_points must be a positive integer",
        )
        KORNIA_CHECK(num_img_rays is not None, "num_img_rays must be specified")
        if isinstance(num_img_rays, int):
            # one identical ray count per camera
            num_img_rays_per_camera: Tensor = tensor([num_img_rays] * cameras.batch_size)
        elif torch.is_tensor(num_img_rays):
            num_img_rays_per_camera = num_img_rays
        else:
            raise TypeError("num_img_rays can be either an int or a Tensor")

        # Validation passed: commit the configuration.
        self._cameras = cameras
        self._min_depth = min_depth
        self._max_depth = max_depth
        self._ndc = ndc
        self._imgs = imgs
        self._num_img_rays = num_img_rays_per_camera
        self._batch_size = batch_size
        self._num_ray_points = num_ray_points

        self._nerf_model = NerfModel(
            num_ray_points, irregular_ray_sampling=irregular_ray_sampling, log_space_encoding=log_space_encoding
        )
        self._nerf_model.to(device=self._device, dtype=self._dtype)
        self._nerf_optimizer = optim.Adam(self._nerf_model.parameters(), lr=lr)

    @deprecated(replace_with="setup_solver", version="0.7.0")
    def init_training(
        self,
        cameras: PinholeCamera,
        min_depth: float,
        max_depth: float,
        ndc: bool,
        imgs: Images,
        num_img_rays: Tensor | int,
        batch_size: int,
        num_ray_points: int,
        irregular_ray_sampling: bool = True,
        log_space_encoding: bool = True,
        lr: float = 1.0e-3,
    ) -> None:
        """Deprecated alias of :meth:`setup_solver`; forwards all arguments unchanged."""
        self.setup_solver(
            cameras,
            min_depth,
            max_depth,
            ndc,
            imgs,
            num_img_rays,
            batch_size,
            num_ray_points,
            irregular_ray_sampling,
            log_space_encoding,
            lr,
        )

    @property
    def nerf_model(self) -> Module | None:
        """Returns the NeRF model, or None if :meth:`setup_solver` has not been called yet."""
        return self._nerf_model

    def _train_one_epoch(self) -> float:
        r"""Trains the NeRF model one epoch.

        1) A dataset of rays is initialized, and sent over to a data loader.
        2) The data loader samples batches of rays randomly, and runs them through the NeRF model
           to predict the rgb values associated with each ray.
        3) The model rgb is compared with the image pixel rgb, and the MSE loss between the two is
           back propagated to update the model weights.

        Implemented steps:
            - Create an object of class RayDataset
            - Initialize the ray dataset with the training images and the number of rays to randomly sample
            - Initialize a data loader with batch size info
            - Iterate over data loader
                -- Run ray batch through Nerf model
                -- Find loss
                -- Reset optimizer
                -- Back propagate loss
                -- Optimizer step

        Returns:
            Mean of the per-batch PSNR values over the epoch (0.0 if the loader yields no batches).
        """
        KORNIA_CHECK(self._nerf_model is not None, "The model should be a NeRF model.")
        KORNIA_CHECK(self._nerf_optimizer is not None, "The optimizer should be an Adam optimizer.")
        KORNIA_CHECK(self._cameras is not None, "The camera should be a PinholeCamera.")
        KORNIA_CHECK(self._imgs is not None, "The images should be a list of tensors.")
        KORNIA_CHECK(self._num_img_rays is not None, "The number of images of Ray should be a tensor.")

        # TODO: refactor and so that the constructor receives the correct types
        cameras: PinholeCamera = cast(PinholeCamera, self._cameras)
        num_img_rays: Tensor = cast(Tensor, self._num_img_rays)
        images = cast(Images, self._imgs)
        nerf_model: NerfModel = cast(NerfModel, self._nerf_model)
        nerf_optimizer: optim.Optimizer = cast(optim.Optimizer, self._nerf_optimizer)

        # create the dataset and data loader
        ray_dataset = RayDataset(
            cameras, self._min_depth, self._max_depth, self._ndc, device=self._device, dtype=self._dtype
        )
        ray_dataset.init_ray_dataset(num_img_rays)
        ray_dataset.init_images_for_training(images)  # FIXME: Do we need to load the same images on each Epoch?
        ray_data_loader = instantiate_ray_dataloader(ray_dataset, self._batch_size, shuffle=True)

        total_psnr: Tensor = torch.tensor(0.0, device=self._device, dtype=self._dtype)
        num_batches: int = 0
        for origins, directions, rgbs in ray_data_loader:
            rgbs_model = nerf_model(origins, directions)
            loss = F.mse_loss(rgbs_model, rgbs)

            # PSNR is only a reported metric: compute it on detached predictions so the running sum
            # does not retain (and keep growing) the autograd graph of every batch in the epoch.
            total_psnr += psnr(rgbs_model.detach(), rgbs, 1.0)

            nerf_optimizer.zero_grad()
            loss.backward()
            nerf_optimizer.step()
            num_batches += 1

        # Divide by the number of batches actually processed; guard the empty-loader case.
        return float(total_psnr / max(num_batches, 1))

    def run(self, num_epochs: int = 1) -> None:
        r"""Run training epochs.

        The epoch PSNR is logged at INFO level every 10th epoch (epochs 0, 10, 20, ...).

        Args:
            num_epochs: number of epochs to run. Default: 1.
        """
        for i_epoch in range(num_epochs):
            # train one epoch
            epoch_psnr: float = self._train_one_epoch()

            if i_epoch % 10 == 0:
                current_time = datetime.now().strftime("%H:%M:%S")  # noqa: DTZ005
                logger.info("Epoch: %d: epoch_psnr = %f; time: %s", i_epoch, epoch_psnr, current_time)
|