# test_roma_coarse_inference_time_cpu.py
  1. from romatch import roma_outdoor
  2. import torch
  3. from tqdm import tqdm
  4. import time
  5. def test_inference_time(model, name):
  6. T = 5
  7. im_A = torch.randn(8, 3, 560, 560).to(device)
  8. im_B = torch.randn(8, 3, 560, 560).to(device)
  9. start_time = time.time()
  10. for t in tqdm(range(T)):
  11. model.match(im_A, im_B, batched=True)
  12. end_time = time.time()
  13. return (end_time - start_time) / T
  14. if __name__ == "__main__":
  15. device = "cpu"
  16. model = roma_outdoor(
  17. device=device,
  18. coarse_res=560,
  19. upsample_res=None,
  20. symmetric=True,
  21. upsample_preds=False,
  22. use_custom_corr=True,
  23. )
  24. experiment_name = "roma_latest"
  25. results = test_inference_time(model, experiment_name)