| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- from romatch import roma_outdoor
- import torch
- from tqdm import tqdm
- import time
- def test_inference_time(model, name):
- T = 1000
- im_A = torch.randn(8, 3, 560, 560).to(device)
- im_B = torch.randn(8, 3, 560, 560).to(device)
- im_A_high_res = torch.randn(8, 3, 864, 864).to(device)
- im_B_high_res = torch.randn(8, 3, 864, 864).to(device)
- # burn in
- for i in range(10):
- model.match(
- im_A,
- im_B,
- im_A_high_res=im_A_high_res,
- im_B_high_res=im_B_high_res,
- batched=True,
- )
- start_time = time.time()
- for t in tqdm(range(T)):
- model.match(
- im_A,
- im_B,
- im_A_high_res=im_A_high_res,
- im_B_high_res=im_B_high_res,
- batched=True,
- )
- end_time = time.time()
- return (end_time - start_time) / T
- if __name__ == "__main__":
- device = "cuda"
- model = roma_outdoor(
- device=device,
- coarse_res=560,
- upsample_res=864,
- symmetric=True,
- upsample_preds=True,
- use_custom_corr=True,
- amp_dtype=torch.bfloat16,
- )
- experiment_name = "roma_latest"
- results = test_inference_time(model, experiment_name)
|