# test_roma_coarse_inference_time.py (818 B)
  1. from romatch import roma_outdoor
  2. import torch
  3. from tqdm import tqdm
  4. import time
  5. def test_inference_time(model, name):
  6. T = 100
  7. im_A = torch.randn(8, 3, 560, 560).to(device)
  8. im_B = torch.randn(8, 3, 560, 560).to(device)
  9. # burn in
  10. for i in range(10):
  11. model.match(im_A, im_B, batched=True)
  12. start_time = time.time()
  13. for t in tqdm(range(T)):
  14. model.match(im_A, im_B, batched=True)
  15. end_time = time.time()
  16. return (end_time - start_time) / T
  17. if __name__ == "__main__":
  18. device = "cuda"
  19. model = roma_outdoor(
  20. device=device,
  21. coarse_res=560,
  22. upsample_res=None,
  23. symmetric=True,
  24. upsample_preds=False,
  25. use_custom_corr=True,
  26. )
  27. experiment_name = "roma_latest"
  28. results = test_inference_time(model, experiment_name)