megadepth_dense_benchmark.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import torch
  2. import numpy as np
  3. import tqdm
  4. from romatch.datasets import MegadepthBuilder
  5. from romatch.utils import warp_kpts
  6. from torch.utils.data import ConcatDataset
  7. import romatch
  8. class MegadepthDenseBenchmark:
  9. def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
  10. mega = MegadepthBuilder(data_root=data_root)
  11. self.dataset = ConcatDataset(
  12. mega.build_scenes(split="test_loftr", ht=h, wt=w)
  13. ) # fixed resolution of 384,512
  14. self.num_samples = num_samples
  15. def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
  16. b, h1, w1, d = dense_matches.shape
  17. with torch.no_grad():
  18. x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
  19. mask, x2 = warp_kpts(
  20. x1.double(),
  21. depth1.double(),
  22. depth2.double(),
  23. T_1to2.double(),
  24. K1.double(),
  25. K2.double(),
  26. )
  27. x2 = torch.stack(
  28. (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
  29. )
  30. prob = mask.float().reshape(b, h1, w1)
  31. x2_hat = dense_matches[..., 2:]
  32. x2_hat = torch.stack(
  33. (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
  34. )
  35. gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
  36. gd = gd[prob == 1]
  37. pck_1 = (gd < 1.0).float().mean()
  38. pck_3 = (gd < 3.0).float().mean()
  39. pck_5 = (gd < 5.0).float().mean()
  40. return gd, pck_1, pck_3, pck_5, prob
  41. def benchmark(self, model, batch_size=8):
  42. model.train(False)
  43. with torch.no_grad():
  44. gd_tot = 0.0
  45. pck_1_tot = 0.0
  46. pck_3_tot = 0.0
  47. pck_5_tot = 0.0
  48. sampler = torch.utils.data.WeightedRandomSampler(
  49. torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
  50. )
  51. B = batch_size
  52. dataloader = torch.utils.data.DataLoader(
  53. self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
  54. )
  55. for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0):
  56. im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
  57. data["im_A"].cuda(),
  58. data["im_B"].cuda(),
  59. data["im_A_depth"].cuda(),
  60. data["im_B_depth"].cuda(),
  61. data["T_1to2"].cuda(),
  62. data["K1"].cuda(),
  63. data["K2"].cuda(),
  64. )
  65. matches, certainty = model.match(im_A, im_B, batched=True)
  66. gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
  67. depth1, depth2, T_1to2, K1, K2, matches
  68. )
  69. if romatch.DEBUG_MODE:
  70. from romatch.utils.utils import tensor_to_pil
  71. import torch.nn.functional as F
  72. path = "vis"
  73. H, W = model.get_output_resolution()
  74. white_im = torch.ones((B,1,H,W),device="cuda")
  75. im_B_transfer_rgb = F.grid_sample(
  76. im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
  77. )
  78. warp_im = im_B_transfer_rgb
  79. c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
  80. vis_im = c_b * warp_im + (1 - c_b) * white_im
  81. for b in range(B):
  82. import os
  83. os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
  84. tensor_to_pil(vis_im[b], unnormalize=True).save(
  85. f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
  86. tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
  87. f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
  88. tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
  89. f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
  90. gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
  91. gd_tot + gd.mean(),
  92. pck_1_tot + pck_1,
  93. pck_3_tot + pck_3,
  94. pck_5_tot + pck_5,
  95. )
  96. return {
  97. "epe": gd_tot.item() / len(dataloader),
  98. "mega_pck_1": pck_1_tot.item() / len(dataloader),
  99. "mega_pck_3": pck_3_tot.item() / len(dataloader),
  100. "mega_pck_5": pck_5_tot.item() / len(dataloader),
  101. }