# eval_tiny_roma_v1_outdoor.py
  1. import torch
  2. import os
  3. from pathlib import Path
  4. import json
  5. from romatch.benchmarks import ScanNetBenchmark
  6. from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
  7. from romatch.benchmarks import MegaDepthPoseEstimationBenchmark
  8. def test_mega_8_scenes(model, name):
  9. mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
  10. scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
  11. 'mega_8_scenes_0025_0.1_0.3.npz',
  12. 'mega_8_scenes_0021_0.1_0.3.npz',
  13. 'mega_8_scenes_0008_0.1_0.3.npz',
  14. 'mega_8_scenes_0032_0.1_0.3.npz',
  15. 'mega_8_scenes_1589_0.1_0.3.npz',
  16. 'mega_8_scenes_0063_0.1_0.3.npz',
  17. 'mega_8_scenes_0024_0.1_0.3.npz',
  18. 'mega_8_scenes_0019_0.3_0.5.npz',
  19. 'mega_8_scenes_0025_0.3_0.5.npz',
  20. 'mega_8_scenes_0021_0.3_0.5.npz',
  21. 'mega_8_scenes_0008_0.3_0.5.npz',
  22. 'mega_8_scenes_0032_0.3_0.5.npz',
  23. 'mega_8_scenes_1589_0.3_0.5.npz',
  24. 'mega_8_scenes_0063_0.3_0.5.npz',
  25. 'mega_8_scenes_0024_0.3_0.5.npz'])
  26. mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
  27. print(mega_8_scenes_results)
  28. json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
  29. def test_mega1500(model, name):
  30. mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
  31. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  32. json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
  33. def test_mega1500_poselib(model, name):
  34. #model.exact_softmax = True
  35. mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
  36. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  37. json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
  38. def test_mega_8_scenes_poselib(model, name):
  39. mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
  40. scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
  41. 'mega_8_scenes_0025_0.1_0.3.npz',
  42. 'mega_8_scenes_0021_0.1_0.3.npz',
  43. 'mega_8_scenes_0008_0.1_0.3.npz',
  44. 'mega_8_scenes_0032_0.1_0.3.npz',
  45. 'mega_8_scenes_1589_0.1_0.3.npz',
  46. 'mega_8_scenes_0063_0.1_0.3.npz',
  47. 'mega_8_scenes_0024_0.1_0.3.npz',
  48. 'mega_8_scenes_0019_0.3_0.5.npz',
  49. 'mega_8_scenes_0025_0.3_0.5.npz',
  50. 'mega_8_scenes_0021_0.3_0.5.npz',
  51. 'mega_8_scenes_0008_0.3_0.5.npz',
  52. 'mega_8_scenes_0032_0.3_0.5.npz',
  53. 'mega_8_scenes_1589_0.3_0.5.npz',
  54. 'mega_8_scenes_0063_0.3_0.5.npz',
  55. 'mega_8_scenes_0024_0.3_0.5.npz'])
  56. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  57. json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
  58. def test_scannet_poselib(model, name):
  59. scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
  60. scannet_results = scannet_benchmark.benchmark(model)
  61. json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
  62. def test_scannet(model, name):
  63. scannet_benchmark = ScanNetBenchmark("data/scannet")
  64. scannet_results = scannet_benchmark.benchmark(model)
  65. json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
  66. if __name__ == "__main__":
  67. os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
  68. os.environ["OMP_NUM_THREADS"] = "16"
  69. torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
  70. from romatch import tiny_roma_v1_outdoor
  71. experiment_name = Path(__file__).stem
  72. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  73. model = tiny_roma_v1_outdoor(device)
  74. #test_mega1500_poselib(model, experiment_name)
  75. test_mega_8_scenes_poselib(model, experiment_name)