megadepth.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. import os
  2. from PIL import Image
  3. import h5py
  4. import numpy as np
  5. import torch
  6. import torchvision.transforms.functional as tvf
  7. import kornia.augmentation as K
  8. from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
  9. import romatch
  10. from romatch.utils import *
  11. import math
  12. class MegadepthScene:
  13. def __init__(
  14. self,
  15. data_root,
  16. scene_info,
  17. ht=384,
  18. wt=512,
  19. min_overlap=0.0,
  20. max_overlap=1.0,
  21. shake_t=0,
  22. rot_prob=0.0,
  23. normalize=True,
  24. max_num_pairs = 100_000,
  25. scene_name = None,
  26. use_horizontal_flip_aug = False,
  27. use_single_horizontal_flip_aug = False,
  28. colorjiggle_params = None,
  29. random_eraser = None,
  30. use_randaug = False,
  31. randaug_params = None,
  32. randomize_size = False,
  33. ) -> None:
  34. self.data_root = data_root
  35. self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
  36. self.image_paths = scene_info["image_paths"]
  37. self.depth_paths = scene_info["depth_paths"]
  38. self.intrinsics = scene_info["intrinsics"]
  39. self.poses = scene_info["poses"]
  40. self.pairs = scene_info["pairs"]
  41. self.overlaps = scene_info["overlaps"]
  42. threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
  43. self.pairs = self.pairs[threshold]
  44. self.overlaps = self.overlaps[threshold]
  45. if len(self.pairs) > max_num_pairs:
  46. pairinds = np.random.choice(
  47. np.arange(0, len(self.pairs)), max_num_pairs, replace=False
  48. )
  49. self.pairs = self.pairs[pairinds]
  50. self.overlaps = self.overlaps[pairinds]
  51. if randomize_size:
  52. area = ht * wt
  53. s = int(16 * (math.sqrt(area)//16))
  54. sizes = ((ht,wt), (s,s), (wt,ht))
  55. choice = romatch.RANK % 3
  56. ht, wt = sizes[choice]
  57. # counts, bins = np.histogram(self.overlaps,20)
  58. # print(counts)
  59. self.im_transform_ops = get_tuple_transform_ops(
  60. resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
  61. )
  62. self.depth_transform_ops = get_depth_tuple_transform_ops(
  63. resize=(ht, wt)
  64. )
  65. self.wt, self.ht = wt, ht
  66. self.shake_t = shake_t
  67. self.random_eraser = random_eraser
  68. if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
  69. raise ValueError("Can't both flip both images and only flip one")
  70. self.use_horizontal_flip_aug = use_horizontal_flip_aug
  71. self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
  72. self.use_randaug = use_randaug
  73. def load_im(self, im_path):
  74. im = Image.open(im_path)
  75. return im
  76. def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
  77. im_A = im_A.flip(-1)
  78. im_B = im_B.flip(-1)
  79. depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
  80. flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
  81. K_A = flip_mat@K_A
  82. K_B = flip_mat@K_B
  83. return im_A, im_B, depth_A, depth_B, K_A, K_B
  84. def load_depth(self, depth_ref, crop=None):
  85. depth = np.array(h5py.File(depth_ref, "r")["depth"])
  86. return torch.from_numpy(depth)
  87. def __len__(self):
  88. return len(self.pairs)
  89. def scale_intrinsic(self, K, wi, hi):
  90. sx, sy = self.wt / wi, self.ht / hi
  91. sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
  92. return sK @ K
  93. def rand_shake(self, *things):
  94. t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
  95. return [
  96. tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
  97. for thing in things
  98. ], t
  99. def __getitem__(self, pair_idx):
  100. # read intrinsics of original size
  101. idx1, idx2 = self.pairs[pair_idx]
  102. K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
  103. K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
  104. # read and compute relative poses
  105. T1 = self.poses[idx1]
  106. T2 = self.poses[idx2]
  107. T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
  108. :4, :4
  109. ] # (4, 4)
  110. # Load positive pair data
  111. im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
  112. depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
  113. im_A_ref = os.path.join(self.data_root, im_A)
  114. im_B_ref = os.path.join(self.data_root, im_B)
  115. depth_A_ref = os.path.join(self.data_root, depth1)
  116. depth_B_ref = os.path.join(self.data_root, depth2)
  117. im_A = self.load_im(im_A_ref)
  118. im_B = self.load_im(im_B_ref)
  119. K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
  120. K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
  121. if self.use_randaug:
  122. im_A, im_B = self.rand_augment(im_A, im_B)
  123. depth_A = self.load_depth(depth_A_ref)
  124. depth_B = self.load_depth(depth_B_ref)
  125. # Process images
  126. im_A, im_B = self.im_transform_ops((im_A, im_B))
  127. depth_A, depth_B = self.depth_transform_ops(
  128. (depth_A[None, None], depth_B[None, None])
  129. )
  130. [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
  131. K1[:2, 2] += t
  132. K2[:2, 2] += t
  133. im_A, im_B = im_A[None], im_B[None]
  134. if self.random_eraser is not None:
  135. im_A, depth_A = self.random_eraser(im_A, depth_A)
  136. im_B, depth_B = self.random_eraser(im_B, depth_B)
  137. if self.use_horizontal_flip_aug:
  138. if np.random.rand() > 0.5:
  139. im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
  140. if self.use_single_horizontal_flip_aug:
  141. if np.random.rand() > 0.5:
  142. im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
  143. if romatch.DEBUG_MODE:
  144. tensor_to_pil(im_A[0], unnormalize=True).save(
  145. f"vis/im_A.jpg")
  146. tensor_to_pil(im_B[0], unnormalize=True).save(
  147. f"vis/im_B.jpg")
  148. data_dict = {
  149. "im_A": im_A[0],
  150. "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
  151. "im_B": im_B[0],
  152. "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
  153. "im_A_depth": depth_A[0, 0],
  154. "im_B_depth": depth_B[0, 0],
  155. "K1": K1,
  156. "K2": K2,
  157. "T_1to2": T_1to2,
  158. "im_A_path": im_A_ref,
  159. "im_B_path": im_B_ref,
  160. }
  161. return data_dict
  162. class MegadepthBuilder:
  163. def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
  164. self.data_root = data_root
  165. self.scene_info_root = os.path.join(data_root, "prep_scene_info")
  166. self.all_scenes = os.listdir(self.scene_info_root)
  167. self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
  168. # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
  169. self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
  170. self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
  171. self.test_scenes_loftr = ["0015.npy", "0022.npy"]
  172. self.loftr_ignore = loftr_ignore
  173. self.imc21_ignore = imc21_ignore
  174. def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
  175. if split == "train":
  176. scene_names = set(self.all_scenes) - set(self.test_scenes)
  177. elif split == "train_loftr":
  178. scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
  179. elif split == "test":
  180. scene_names = self.test_scenes
  181. elif split == "test_loftr":
  182. scene_names = self.test_scenes_loftr
  183. elif split == "custom":
  184. scene_names = scene_names
  185. else:
  186. raise ValueError(f"Split {split} not available")
  187. scenes = []
  188. for scene_name in scene_names:
  189. if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
  190. continue
  191. if self.imc21_ignore and scene_name in self.imc21_scenes:
  192. continue
  193. if ".npy" not in scene_name:
  194. continue
  195. scene_info = np.load(
  196. os.path.join(self.scene_info_root, scene_name), allow_pickle=True
  197. ).item()
  198. scenes.append(
  199. MegadepthScene(
  200. self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
  201. )
  202. )
  203. return scenes
  204. def weight_scenes(self, concat_dataset, alpha=0.5):
  205. ns = []
  206. for d in concat_dataset.datasets:
  207. ns.append(len(d))
  208. ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
  209. return ws