# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler, SequentialSampler
from typing_extensions import TypeGuard

from kornia.core import Device, Tensor, stack
from kornia.geometry.camera import PinholeCamera
from kornia.io import ImageLoadType, load_image
from kornia.nerf.core import Images, ImageTensors
from kornia.nerf.samplers import RandomRaySampler, RaySampler, UniformRaySampler

# A dataset item: (ray origins, ray directions, optional rgb values sampled at the rays' 2d pixels).
RayGroup = Tuple[Tensor, Tensor, Optional[Tensor]]


def _is_list_of_str(lst: Sequence[object]) -> TypeGuard[List[str]]:
    """Return True if ``lst`` is a ``list`` whose elements are all strings."""
    return isinstance(lst, list) and all(isinstance(x, str) for x in lst)


def _is_list_of_tensors(lst: Sequence[object]) -> TypeGuard[List[Tensor]]:
    """Return True if ``lst`` is a ``list`` whose elements are all tensors."""
    return isinstance(lst, list) and all(isinstance(x, Tensor) for x in lst)


class RayDataset(Dataset[RayGroup]):
    r"""Class to represent a dataset of rays.

    Args:
        cameras: scene cameras: PinholeCamera
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert ray parameters to normalized device coordinates: bool
        device: device for ray tensors: Device
        dtype: type of ray tensors: torch.dtype
    """

    def __init__(
        self, cameras: PinholeCamera, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype
    ) -> None:
        super().__init__()
        # Both are populated lazily by init_ray_dataset() / init_images_for_training().
        self._ray_sampler: Optional[RaySampler] = None
        self._imgs: Optional[List[Tensor]] = None
        self._cameras = cameras
        self._min_depth = min_depth
        self._max_depth = max_depth
        self._ndc = ndc
        self._device = device
        self._dtype = dtype

    def init_ray_dataset(self, num_img_rays: Optional[Tensor] = None) -> None:
        r"""Initialize a ray dataset.

        Args:
            num_img_rays: If not None, number of rays to randomly cast from each camera: math: `(B)`.
                If None, rays are cast with a uniform ray sampler instead.
        """
        if num_img_rays is None:
            self._init_uniform_ray_dataset()
        else:
            self._init_random_ray_dataset(num_img_rays)

    def init_images_for_training(self, imgs: Images) -> None:
        r"""Initialize images for training.

        Images can be either a list of tensors, or a list of paths to image disk locations.

        Args:
            imgs: List of image tensors or image paths: Images

        Raises:
            ValueError: if the list mixes paths and tensors, or if the images do not match the cameras.
            TypeError: if ``imgs`` is neither a list of paths nor a list of tensors.
        """
        self._check_image_type_consistency(imgs)
        if _is_list_of_str(imgs):
            images = self._load_images(imgs)  # Load images from disk
        elif _is_list_of_tensors(imgs):
            images = imgs  # Take images provided on input
        else:
            raise TypeError(f"Expected a list of image tensors or image paths. Got {type(imgs)}.")
        self._check_dimensions(images)

        # Move images to defined device
        self._imgs = [img.to(self._device) for img in images]

    def _init_random_ray_dataset(self, num_img_rays: Tensor) -> None:
        r"""Initialize a random ray sampler and calculate dataset ray parameters.

        Args:
            num_img_rays: number of rays to randomly cast from each camera: math: `(B)`.
        """
        self._ray_sampler = RandomRaySampler(
            self._min_depth, self._max_depth, self._ndc, device=self._device, dtype=self._dtype
        )
        self._ray_sampler.calc_ray_params(self._cameras, num_img_rays)

    def _init_uniform_ray_dataset(self) -> None:
        r"""Initialize a uniform ray sampler and calculate dataset ray parameters."""
        self._ray_sampler = UniformRaySampler(
            self._min_depth, self._max_depth, self._ndc, device=self._device, dtype=self._dtype
        )
        self._ray_sampler.calc_ray_params(self._cameras)

    def _check_image_type_consistency(self, imgs: Images) -> None:
        """Raise ValueError unless ``imgs`` are all paths or all tensors."""
        if not all(isinstance(img, str) for img in imgs) and not all(isinstance(img, Tensor) for img in imgs):
            raise ValueError("The list of input images can only be all paths or tensors")

    def _check_dimensions(self, imgs: ImageTensors) -> None:
        """Validate image count, channel count and per-camera spatial size.

        Raises:
            ValueError: if the number of images differs from the camera batch size, an image does not
                have 3 channels in its first dimension, or its remaining dimensions differ from the
                matching camera's (height, width).
        """
        if len(imgs) != self._cameras.batch_size:
            raise ValueError(
                f"Number of images {len(imgs)} does not match number of cameras {self._cameras.batch_size}"
            )
        if not all(img.shape[0] == 3 for img in imgs):
            raise ValueError("Not all input images have 3 channels")
        for i, (img, height, width) in enumerate(zip(imgs, self._cameras.height, self._cameras.width)):
            if img.shape[1:] != (height, width):
                raise ValueError(
                    f"Image index {i} dimensions {(img.shape[1], img.shape[2])} are inconsistent with equivalent "
                    f"camera dimensions {(height.item(), width.item())}"
                )

    @staticmethod
    def _load_images(img_paths: List[str]) -> List[Tensor]:
        """Load every image path from disk with ``ImageLoadType.UNCHANGED``."""
        return [load_image(img_path, ImageLoadType.UNCHANGED) for img_path in img_paths]

    def __len__(self) -> int:
        """Return the number of rays, or 0 before the ray sampler is initialized."""
        if isinstance(self._ray_sampler, RaySampler):
            return len(self._ray_sampler)
        return 0

    def __getitem__(self, idxs: Union[int, List[int]]) -> RayGroup:
        r"""Get a dataset item.

        Args:
            idxs: An index or group of indices of ray parameter object: Union[int, List[int]]

        Return:
            A ray parameter object that includes ray origins, directions, and rgb values at the ray 2d pixel
            coordinates: RayGroup. The rgb values are cast to the dataset dtype and divided by 255; they are
            None when no training images were set.

        Raises:
            TypeError: if the ray sampler has not been initialized.
        """
        if not isinstance(self._ray_sampler, RaySampler):
            raise TypeError("Ray sampler is not initialized yet, please run self.init_ray_dataset() before using it.")
        origins = self._ray_sampler.origins[idxs]
        directions = self._ray_sampler.directions[idxs]
        if self._imgs is None:
            return origins, directions, None

        camera_ids = self._ray_sampler.camera_ids[idxs]
        # A single integer index yields a 0-d tensor, which cannot be iterated; promote it to a batch of one
        # so both index kinds share a single code path, and unwrap the result at the end.
        single_ray = camera_ids.ndim == 0
        camera_ids = camera_ids.reshape(-1)
        # Cast pixel coordinates to integers so they are valid tensor indices even if the sampler
        # stores them as floating point values.
        points_2d = self._ray_sampler.points_2d[idxs].long()
        if single_ray:
            points_2d = points_2d.unsqueeze(0)

        # points_2d rows are (x, y): index the image as [:, row=y, col=x].
        rgbs = stack(
            [self._imgs[cam_id][:, point[1], point[0]] for cam_id, point in zip(camera_ids.tolist(), points_2d.tolist())]
        )
        if single_ray:
            rgbs = rgbs[0]
        rgbs = rgbs.to(dtype=self._dtype) / 255.0
        return origins, directions, rgbs


def instantiate_ray_dataloader(dataset: RayDataset, batch_size: int = 1, shuffle: bool = True) -> DataLoader[RayGroup]:
    r"""Initialize a dataloader to manage a ray dataset.

    Args:
        dataset: A ray dataset: RayDataset
        batch_size: Number of rays to sample in a batch: int
        shuffle: Whether to shuffle rays or sample them sequentially: bool

    Returns:
        A dataloader whose items are RayGroups of up to ``batch_size`` rays each.
    """

    def collate_rays(items: List[RayGroup]) -> RayGroup:
        # The BatchSampler is used as the *sampler*, so each sampled element is a list of indices that
        # RayDataset.__getitem__ batches itself; the DataLoader therefore hands over a one-element list.
        return items[0]

    if TYPE_CHECKING:
        # TODO: remove this workaround once kornia relies on a torch version (presumably 1.10+) whose
        # DataLoader typing accepts a BatchSampler as sampler.
        return DataLoader(dataset)
    else:
        return DataLoader(
            dataset,
            sampler=BatchSampler(
                RandomSampler(dataset) if shuffle else SequentialSampler(dataset), batch_size, drop_last=False
            ),
            collate_fn=collate_rays,
        )