scannet.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import os
  2. import random
  3. from PIL import Image
  4. import cv2
  5. import h5py
  6. import numpy as np
  7. import torch
  8. from torch.utils.data import (
  9. Dataset,
  10. DataLoader,
  11. ConcatDataset)
  12. import torchvision.transforms.functional as tvf
  13. import kornia.augmentation as K
  14. import os.path as osp
  15. import matplotlib.pyplot as plt
  16. import romatch
  17. from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
  18. from romatch.utils.transforms import GeometricSequential
  19. from tqdm import tqdm
  20. class ScanNetScene:
  21. def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
  22. ) -> None:
  23. self.scene_root = osp.join(data_root,"scans","scans_train")
  24. self.data_names = scene_info['name']
  25. self.overlaps = scene_info['score']
  26. # Only sample 10s
  27. valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
  28. self.overlaps = self.overlaps[valid]
  29. self.data_names = self.data_names[valid]
  30. if len(self.data_names) > 10000:
  31. pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
  32. self.data_names = self.data_names[pairinds]
  33. self.overlaps = self.overlaps[pairinds]
  34. self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
  35. self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
  36. self.wt, self.ht = wt, ht
  37. self.shake_t = shake_t
  38. self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
  39. self.use_horizontal_flip_aug = use_horizontal_flip_aug
  40. def load_im(self, im_B, crop=None):
  41. im = Image.open(im_B)
  42. return im
  43. def load_depth(self, depth_ref, crop=None):
  44. depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
  45. depth = depth / 1000
  46. depth = torch.from_numpy(depth).float() # (h, w)
  47. return depth
  48. def __len__(self):
  49. return len(self.data_names)
  50. def scale_intrinsic(self, K, wi, hi):
  51. sx, sy = self.wt / wi, self.ht / hi
  52. sK = torch.tensor([[sx, 0, 0],
  53. [0, sy, 0],
  54. [0, 0, 1]])
  55. return sK@K
  56. def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
  57. im_A = im_A.flip(-1)
  58. im_B = im_B.flip(-1)
  59. depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
  60. flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
  61. K_A = flip_mat@K_A
  62. K_B = flip_mat@K_B
  63. return im_A, im_B, depth_A, depth_B, K_A, K_B
  64. def read_scannet_pose(self,path):
  65. """ Read ScanNet's Camera2World pose and transform it to World2Camera.
  66. Returns:
  67. pose_w2c (np.ndarray): (4, 4)
  68. """
  69. cam2world = np.loadtxt(path, delimiter=' ')
  70. world2cam = np.linalg.inv(cam2world)
  71. return world2cam
  72. def read_scannet_intrinsic(self,path):
  73. """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
  74. """
  75. intrinsic = np.loadtxt(path, delimiter=' ')
  76. return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
  77. def __getitem__(self, pair_idx):
  78. # read intrinsics of original size
  79. data_name = self.data_names[pair_idx]
  80. scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
  81. scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
  82. # read the intrinsic of depthmap
  83. K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
  84. scene_name,
  85. 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
  86. # read and compute relative poses
  87. T1 = self.read_scannet_pose(osp.join(self.scene_root,
  88. scene_name,
  89. 'pose', f'{stem_name_1}.txt'))
  90. T2 = self.read_scannet_pose(osp.join(self.scene_root,
  91. scene_name,
  92. 'pose', f'{stem_name_2}.txt'))
  93. T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
  94. # Load positive pair data
  95. im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
  96. im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
  97. depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
  98. depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
  99. im_A = self.load_im(im_A_ref)
  100. im_B = self.load_im(im_B_ref)
  101. depth_A = self.load_depth(depth_A_ref)
  102. depth_B = self.load_depth(depth_B_ref)
  103. # Recompute camera intrinsic matrix due to the resize
  104. K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
  105. K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
  106. # Process images
  107. im_A, im_B = self.im_transform_ops((im_A, im_B))
  108. depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
  109. if self.use_horizontal_flip_aug:
  110. if np.random.rand() > 0.5:
  111. im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
  112. data_dict = {'im_A': im_A,
  113. 'im_B': im_B,
  114. 'im_A_depth': depth_A[0,0],
  115. 'im_B_depth': depth_B[0,0],
  116. 'K1': K1,
  117. 'K2': K2,
  118. 'T_1to2':T_1to2,
  119. }
  120. return data_dict
  121. class ScanNetBuilder:
  122. def __init__(self, data_root = 'data/scannet') -> None:
  123. self.data_root = data_root
  124. self.scene_info_root = os.path.join(data_root,'scannet_indices')
  125. self.all_scenes = os.listdir(self.scene_info_root)
  126. def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
  127. # Note: split doesn't matter here as we always use same scannet_train scenes
  128. scene_names = self.all_scenes
  129. scenes = []
  130. for scene_name in tqdm(scene_names, disable = romatch.RANK > 0):
  131. scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
  132. scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
  133. return scenes
  134. def weight_scenes(self, concat_dataset, alpha=.5):
  135. ns = []
  136. for d in concat_dataset.datasets:
  137. ns.append(len(d))
  138. ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
  139. return ws