test_roma_upsample_inference_time.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from romatch import roma_outdoor
  2. import torch
  3. from tqdm import tqdm
  4. import time
  def test_inference_time(
      model,
      name,
      device=None,
      *,
      T=1000,
      batch_size=8,
      coarse_res=560,
      upsample_res=864,
      warmup=10,
      progress=True,
  ):
      """Return the average wall-clock seconds of one batched ``model.match`` call.

      Random image batches of shape ``(batch_size, 3, coarse_res, coarse_res)``
      and ``(batch_size, 3, upsample_res, upsample_res)`` are created on
      ``device``. The model is called ``warmup`` times untimed, then ``T`` timed
      times, and the mean time per call is returned.

      Args:
          model: Object with a ``match(im_A, im_B, im_A_high_res=...,
              im_B_high_res=..., batched=True)`` method. When ``device`` is None
              it must also expose ``parameters()`` (e.g. a ``torch.nn.Module``).
          name: Run label, shown as the progress-bar description.
          device: Device for the input tensors. Defaults to the device of the
              model's first parameter, so the helper no longer depends on a
              module-level ``device`` global.
          T: Number of timed iterations (must be >= 1).
          batch_size: Images per batch.
          coarse_res: Side length of the coarse-resolution inputs.
          upsample_res: Side length of the high-resolution inputs.
          warmup: Untimed iterations run before timing starts, so one-time
              setup costs are excluded.
          progress: Show a ``tqdm`` progress bar over the timed loop.

      Returns:
          Mean seconds per ``model.match`` call.

      Raises:
          ValueError: If ``T < 1``, or if ``device`` is None and the model has
              no parameters to infer it from.
      """
      if T < 1:
          raise ValueError(f"T must be >= 1, got {T}")
      if device is None:
          first_param = next(iter(model.parameters()), None)
          if first_param is None:
              raise ValueError(
                  "cannot infer device from a model without parameters; pass device="
              )
          device = first_param.device
      device = torch.device(device)

      def _sync():
          # CUDA kernel launches are asynchronous: wait for queued work so the
          # timer brackets the GPU computation, not just the launches.
          if device.type == "cuda":
              torch.cuda.synchronize(device)

      im_A = torch.randn(batch_size, 3, coarse_res, coarse_res, device=device)
      im_B = torch.randn(batch_size, 3, coarse_res, coarse_res, device=device)
      im_A_high_res = torch.randn(batch_size, 3, upsample_res, upsample_res, device=device)
      im_B_high_res = torch.randn(batch_size, 3, upsample_res, upsample_res, device=device)

      def _run_once():
          model.match(
              im_A,
              im_B,
              im_A_high_res=im_A_high_res,
              im_B_high_res=im_B_high_res,
              batched=True,
          )

      for _ in range(warmup):
          _run_once()

      timed = tqdm(range(T), desc=name) if progress else range(T)
      _sync()
      start = time.perf_counter()
      for _ in timed:
          _run_once()
      _sync()
      return (time.perf_counter() - start) / T


  # The ``test_`` prefix would make pytest collect this benchmark helper as a
  # test and fail with "fixture 'model' not found"; opt it out of collection.
  test_inference_time.__test__ = False
  if __name__ == "__main__":
      # Kept at module scope so helpers that read a global ``device`` still work.
      device = "cuda"
      # coarse_res / upsample_res (560 / 864) match the input sizes that
      # test_inference_time generates by default.
      model = roma_outdoor(
          device=device,
          coarse_res=560,
          upsample_res=864,
          symmetric=True,
          upsample_preds=True,
          use_custom_corr=True,
          amp_dtype=torch.bfloat16,
      )
      experiment_name = "roma_latest"
      results = test_inference_time(model, experiment_name)
      # Report the measurement; previously it was computed and then discarded.
      print(f"{experiment_name}: {results * 1000:.2f} ms per batched match call")