data_utils.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
  18. import torch
  19. from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler, SequentialSampler
  20. from typing_extensions import TypeGuard
  21. from kornia.core import Device, Tensor, stack
  22. from kornia.geometry.camera import PinholeCamera
  23. from kornia.io import ImageLoadType, load_image
  24. from kornia.nerf.core import Images, ImageTensors
  25. from kornia.nerf.samplers import RandomRaySampler, RaySampler, UniformRaySampler
  26. RayGroup = Tuple[Tensor, Tensor, Optional[Tensor]]
  27. def _is_list_of_str(lst: Sequence[object]) -> TypeGuard[List[str]]:
  28. return isinstance(lst, list) and all(isinstance(x, str) for x in lst)
  29. def _is_list_of_tensors(lst: Sequence[object]) -> TypeGuard[List[Tensor]]:
  30. return isinstance(lst, list) and all(isinstance(x, Tensor) for x in lst)
  31. class RayDataset(Dataset[RayGroup]):
  32. r"""Class to represent a dataset of rays.
  33. Args:
  34. cameras: scene cameras: PinholeCamera
  35. min_depth: sampled rays minimal depth from cameras: float
  36. max_depth: sampled rays maximal depth from cameras: float
  37. ndc: convert ray parameters to normalized device coordinates: bool
  38. device: device for ray tensors: Union[str, torch.device]
  39. dtype: type of ray tensors: torch.dtype
  40. """
  41. def __init__(
  42. self, cameras: PinholeCamera, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype
  43. ) -> None:
  44. super().__init__()
  45. self._ray_sampler: Optional[RaySampler] = None
  46. self._imgs: Optional[List[Tensor]] = None
  47. self._cameras = cameras
  48. self._min_depth = min_depth
  49. self._max_depth = max_depth
  50. self._ndc = ndc
  51. self._device = device
  52. self._dtype = dtype
  53. def init_ray_dataset(self, num_img_rays: Optional[Tensor] = None) -> None:
  54. r"""Initialize a ray dataset.
  55. Args:
  56. num_img_rays: If not None, number of rays to randomly cast from each camera: math: `(B)`.
  57. """
  58. if num_img_rays is None:
  59. self._init_uniform_ray_dataset()
  60. else:
  61. self._init_random_ray_dataset(num_img_rays)
  62. def init_images_for_training(self, imgs: Images) -> None:
  63. r"""Initialize images for training.
  64. Images can be either a list of tensors, or a list of paths to image disk locations.
  65. Args:
  66. imgs: List of image tensors or image paths: Images
  67. """
  68. self._check_image_type_consistency(imgs)
  69. if _is_list_of_str(imgs): # Load images from disk
  70. images = self._load_images(imgs)
  71. elif _is_list_of_tensors(imgs):
  72. images = imgs # Take images provided on input
  73. else:
  74. raise TypeError(f"Expected a list of image tensors or image paths. Gotcha {type(imgs)}.")
  75. self._check_dimensions(images)
  76. # Move images to defined device
  77. self._imgs = [img.to(self._device) for img in images]
  78. def _init_random_ray_dataset(self, num_img_rays: Tensor) -> None:
  79. r"""Initialize a random ray sampler and calculates dataset ray parameters.
  80. Args:
  81. num_img_rays: If not None, number of rays to randomly cast from each camers: math: `(B)`.
  82. """
  83. self._ray_sampler = RandomRaySampler(
  84. self._min_depth, self._max_depth, self._ndc, device=self._device, dtype=self._dtype
  85. )
  86. self._ray_sampler.calc_ray_params(self._cameras, num_img_rays)
  87. def _init_uniform_ray_dataset(self) -> None:
  88. r"""Initialize a uniform ray sampler and calculates dataset ray parameters."""
  89. self._ray_sampler = UniformRaySampler(
  90. self._min_depth, self._max_depth, self._ndc, device=self._device, dtype=self._dtype
  91. )
  92. self._ray_sampler.calc_ray_params(self._cameras)
  93. def _check_image_type_consistency(self, imgs: Images) -> None:
  94. if not all(isinstance(img, str) for img in imgs) and not all(isinstance(img, Tensor) for img in imgs):
  95. raise ValueError("The list of input images can only be all paths or tensors")
  96. def _check_dimensions(self, imgs: ImageTensors) -> None:
  97. if len(imgs) != self._cameras.batch_size:
  98. raise ValueError(
  99. f"Number of images {len(imgs)} does not match number of cameras {self._cameras.batch_size}"
  100. )
  101. if not all(img.shape[0] == 3 for img in imgs):
  102. raise ValueError("Not all input images have 3 channels")
  103. for i, (img, height, width) in enumerate(zip(imgs, self._cameras.height, self._cameras.width)):
  104. if img.shape[1:] != (height, width):
  105. raise ValueError(
  106. f"Image index {i} dimensions {(img.shape[1], img.shape[2])} are inconsistent with equivalent "
  107. f"camera dimensions {(height.item(), width.item())}"
  108. )
  109. @staticmethod
  110. def _load_images(img_paths: List[str]) -> List[Tensor]:
  111. imgs: List[Tensor] = []
  112. for img_path in img_paths:
  113. imgs.append(load_image(img_path, ImageLoadType.UNCHANGED))
  114. return imgs
  115. def __len__(self) -> int:
  116. if isinstance(self._ray_sampler, RaySampler):
  117. return len(self._ray_sampler)
  118. return 0
  119. def __getitem__(self, idxs: Union[int, List[int]]) -> RayGroup:
  120. r"""Get a dataset item.
  121. Args:
  122. idxs: An index or group of indices of ray parameter object: Union[int, List[int]]
  123. Return:
  124. A ray parameter object that includes ray origins, directions, and rgb values at the ray 2d pixel
  125. coordinates: RayGroup
  126. """
  127. if not isinstance(self._ray_sampler, RaySampler):
  128. raise TypeError("Ray sampler is not initiate yet, please run self.init_ray_dataset() before use it.")
  129. origins = self._ray_sampler.origins[idxs]
  130. directions = self._ray_sampler.directions[idxs]
  131. if self._imgs is None:
  132. return origins, directions, None
  133. camerd_ids = self._ray_sampler.camera_ids[idxs]
  134. points_2d = self._ray_sampler.points_2d[idxs]
  135. rgbs = None
  136. imgs_for_ids = [self._imgs[i] for i in camerd_ids]
  137. rgbs = stack([img[:, point2d[1].item(), point2d[0].item()] for img, point2d in zip(imgs_for_ids, points_2d)])
  138. rgbs = rgbs.to(dtype=self._dtype) / 255.0
  139. return origins, directions, rgbs
  140. def instantiate_ray_dataloader(dataset: RayDataset, batch_size: int = 1, shuffle: bool = True) -> DataLoader[RayGroup]:
  141. r"""Initialize a dataloader to manage a ray dataset.
  142. Args:
  143. dataset: A ray dataset: RayDataset
  144. batch_size: Number of rays to sample in a batch: int
  145. shuffle: Whether to shuffle rays or sample then sequentially: bool
  146. """
  147. def collate_rays(items: List[RayGroup]) -> RayGroup:
  148. return items[0]
  149. if TYPE_CHECKING:
  150. # TODO: remove the type ignore when kornia relies on kornia 1.10
  151. return DataLoader(dataset)
  152. else:
  153. return DataLoader(
  154. dataset,
  155. sampler=BatchSampler(
  156. RandomSampler(dataset) if shuffle else SequentialSampler(dataset), batch_size, drop_last=False
  157. ),
  158. collate_fn=collate_rays,
  159. )