eval_roma_outdoor.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import os
  2. import torch
  3. from argparse import ArgumentParser
  4. from torch import nn
  5. from torch.utils.data import ConcatDataset
  6. import torch.distributed as dist
  7. from torch.nn.parallel import DistributedDataParallel as DDP
  8. import json
  9. import wandb
  10. from romatch.benchmarks import MegadepthDenseBenchmark
  11. from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
  12. from romatch.benchmarks import Mega1500PoseLibBenchmark
  13. def test_mega_8_scenes(model, name):
  14. mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
  15. scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
  16. 'mega_8_scenes_0025_0.1_0.3.npz',
  17. 'mega_8_scenes_0021_0.1_0.3.npz',
  18. 'mega_8_scenes_0008_0.1_0.3.npz',
  19. 'mega_8_scenes_0032_0.1_0.3.npz',
  20. 'mega_8_scenes_1589_0.1_0.3.npz',
  21. 'mega_8_scenes_0063_0.1_0.3.npz',
  22. 'mega_8_scenes_0024_0.1_0.3.npz',
  23. 'mega_8_scenes_0019_0.3_0.5.npz',
  24. 'mega_8_scenes_0025_0.3_0.5.npz',
  25. 'mega_8_scenes_0021_0.3_0.5.npz',
  26. 'mega_8_scenes_0008_0.3_0.5.npz',
  27. 'mega_8_scenes_0032_0.3_0.5.npz',
  28. 'mega_8_scenes_1589_0.3_0.5.npz',
  29. 'mega_8_scenes_0063_0.3_0.5.npz',
  30. 'mega_8_scenes_0024_0.3_0.5.npz'])
  31. mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
  32. print(mega_8_scenes_results)
  33. json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
  34. def test_mega1500(model, name):
  35. mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
  36. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  37. json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
  38. def test_mega1500_poselib(model, name):
  39. mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth")
  40. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  41. json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
  42. def test_mega_dense(model, name):
  43. megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
  44. megadense_results = megadense_benchmark.benchmark(model)
  45. json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
  46. def test_hpatches(model, name):
  47. hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
  48. hpatches_results = hpatches_benchmark.benchmark(model)
  49. json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
  50. if __name__ == "__main__":
  51. from romatch import roma_outdoor
  52. model = roma_outdoor(device = "cuda", coarse_res = 672, upsample_res = 1344)
  53. experiment_name = "roma_latest"
  54. #test_mega1500(model, experiment_name)
  55. test_mega1500_poselib(model, experiment_name)