scannet_benchmark.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os.path as osp
  2. import numpy as np
  3. import torch
  4. from romatch.utils import *
  5. from PIL import Image
  6. from tqdm import tqdm
  7. class ScanNetBenchmark:
  8. def __init__(self, data_root="data/scannet") -> None:
  9. self.data_root = data_root
  10. def benchmark(self, model, model_name = None):
  11. model.train(False)
  12. with torch.no_grad():
  13. data_root = self.data_root
  14. tmp = np.load(osp.join(data_root, "test.npz"))
  15. pairs, rel_pose = tmp["name"], tmp["rel_pose"]
  16. tot_e_t, tot_e_R, tot_e_pose = [], [], []
  17. pair_inds = np.random.choice(
  18. range(len(pairs)), size=len(pairs), replace=False
  19. )
  20. for pairind in tqdm(pair_inds, smoothing=0.9):
  21. scene = pairs[pairind]
  22. scene_name = f"scene0{scene[0]}_00"
  23. im_A_path = osp.join(
  24. self.data_root,
  25. "scans_test",
  26. scene_name,
  27. "color",
  28. f"{scene[2]}.jpg",
  29. )
  30. im_A = Image.open(im_A_path)
  31. im_B_path = osp.join(
  32. self.data_root,
  33. "scans_test",
  34. scene_name,
  35. "color",
  36. f"{scene[3]}.jpg",
  37. )
  38. im_B = Image.open(im_B_path)
  39. T_gt = rel_pose[pairind].reshape(3, 4)
  40. R, t = T_gt[:3, :3], T_gt[:3, 3]
  41. K = np.stack(
  42. [
  43. np.array([float(i) for i in r.split()])
  44. for r in open(
  45. osp.join(
  46. self.data_root,
  47. "scans_test",
  48. scene_name,
  49. "intrinsic",
  50. "intrinsic_color.txt",
  51. ),
  52. "r",
  53. )
  54. .read()
  55. .split("\n")
  56. if r
  57. ]
  58. )
  59. w1, h1 = im_A.size
  60. w2, h2 = im_B.size
  61. K1 = K.copy()
  62. K2 = K.copy()
  63. dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
  64. sparse_matches, sparse_certainty = model.sample(
  65. dense_matches, dense_certainty, 5000
  66. )
  67. scale1 = 480 / min(w1, h1)
  68. scale2 = 480 / min(w2, h2)
  69. w1, h1 = scale1 * w1, scale1 * h1
  70. w2, h2 = scale2 * w2, scale2 * h2
  71. K1 = K1 * scale1
  72. K2 = K2 * scale2
  73. offset = 0.5
  74. kpts1 = sparse_matches[:, :2]
  75. kpts1 = (
  76. np.stack(
  77. (
  78. w1 * (kpts1[:, 0] + 1) / 2 - offset,
  79. h1 * (kpts1[:, 1] + 1) / 2 - offset,
  80. ),
  81. axis=-1,
  82. )
  83. )
  84. kpts2 = sparse_matches[:, 2:]
  85. kpts2 = (
  86. np.stack(
  87. (
  88. w2 * (kpts2[:, 0] + 1) / 2 - offset,
  89. h2 * (kpts2[:, 1] + 1) / 2 - offset,
  90. ),
  91. axis=-1,
  92. )
  93. )
  94. for _ in range(5):
  95. shuffling = np.random.permutation(np.arange(len(kpts1)))
  96. kpts1 = kpts1[shuffling]
  97. kpts2 = kpts2[shuffling]
  98. try:
  99. norm_threshold = 0.5 / (
  100. np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
  101. R_est, t_est, mask = estimate_pose(
  102. kpts1,
  103. kpts2,
  104. K1,
  105. K2,
  106. norm_threshold,
  107. conf=0.99999,
  108. )
  109. T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
  110. e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
  111. e_pose = max(e_t, e_R)
  112. except Exception as e:
  113. print(repr(e))
  114. e_t, e_R = 90, 90
  115. e_pose = max(e_t, e_R)
  116. tot_e_t.append(e_t)
  117. tot_e_R.append(e_R)
  118. tot_e_pose.append(e_pose)
  119. tot_e_t.append(e_t)
  120. tot_e_R.append(e_R)
  121. tot_e_pose.append(e_pose)
  122. tot_e_pose = np.array(tot_e_pose)
  123. thresholds = [5, 10, 20]
  124. auc = pose_auc(tot_e_pose, thresholds)
  125. acc_5 = (tot_e_pose < 5).mean()
  126. acc_10 = (tot_e_pose < 10).mean()
  127. acc_15 = (tot_e_pose < 15).mean()
  128. acc_20 = (tot_e_pose < 20).mean()
  129. map_5 = acc_5
  130. map_10 = np.mean([acc_5, acc_10])
  131. map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
  132. return {
  133. "auc_5": auc[0],
  134. "auc_10": auc[1],
  135. "auc_20": auc[2],
  136. "map_5": map_5,
  137. "map_10": map_10,
  138. "map_20": map_20,
  139. }