|
|
@@ -250,7 +250,7 @@ def train(args):
|
|
|
checkpointer.save(model, optimizer, lr_scheduler, roma.GLOBAL_STEP)
|
|
|
wandb.log(megadense_benchmark.benchmark(model), step = roma.GLOBAL_STEP)
|
|
|
|
|
|
-def test_mega_8_scenes(model, name, resolution, sample_mode):
|
|
|
+def test_mega_8_scenes(model, name):
|
|
|
mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
|
|
|
scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
|
|
|
'mega_8_scenes_0025_0.1_0.3.npz',
|
|
|
@@ -268,21 +268,21 @@ def test_mega_8_scenes(model, name, resolution, sample_mode):
|
|
|
'mega_8_scenes_1589_0.3_0.5.npz',
|
|
|
'mega_8_scenes_0063_0.3_0.5.npz',
|
|
|
'mega_8_scenes_0024_0.3_0.5.npz'])
|
|
|
- mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name, scale_intrinsics = False)
|
|
|
+ mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
|
|
|
print(mega_8_scenes_results)
|
|
|
json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
|
|
|
|
|
|
-def test_mega1500(model, name, resolution, sample_mode):
|
|
|
+def test_mega1500(model, name):
|
|
|
mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
|
|
|
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
|
|
|
json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
|
|
|
|
|
|
-def test_mega_dense(model, name, resolution, sample_mode):
|
|
|
+def test_mega_dense(model, name):
|
|
|
megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
|
|
|
megadense_results = megadense_benchmark.benchmark(model)
|
|
|
json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
|
|
|
|
|
|
-def test_hpatches(model, name, resolution, sample_mode):
|
|
|
+def test_hpatches(model, name):
|
|
|
hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
|
|
|
hpatches_results = hpatches_benchmark.benchmark(model)
|
|
|
json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
|
|
|
@@ -306,18 +306,3 @@ if __name__ == "__main__":
|
|
|
roma.DEBUG_MODE = args.debug_mode
|
|
|
if not args.only_test:
|
|
|
train(args)
|
|
|
- experiment_name = os.path.splitext(os.path.basename(__file__))[0]
|
|
|
- checkpoint_dir = "workspace/checkpoints/"
|
|
|
- checkpoint_name = checkpoint_dir + experiment_name + ".pth"
|
|
|
-
|
|
|
- test_resolution = "high"
|
|
|
- sample_mode = "threshold_balanced"
|
|
|
- symmetric = True
|
|
|
- upsample_preds = True
|
|
|
- attenuate_cert = True
|
|
|
-
|
|
|
- model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
|
|
|
- model = model.cuda()
|
|
|
- weights = torch.load(checkpoint_name)
|
|
|
- model.load_state_dict(weights)
|
|
|
- test_mega1500(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)
|