test_match_modes.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import torch
  2. from PIL import Image
  3. from romatch.models.matcher import RegressionMatcher
  4. def test_bs_one_tensor_inputs(model: RegressionMatcher, device, coarse_res: int, upsample_res: int):
  5. model.match(
  6. torch.randn(1, 3, coarse_res, coarse_res).to(device),
  7. torch.randn(1, 3, coarse_res, coarse_res).to(device),
  8. im_A_high_res=torch.randn(1, 3, upsample_res, upsample_res).to(device),
  9. im_B_high_res=torch.randn(1, 3, upsample_res, upsample_res).to(device),
  10. )
  11. def test_bs_8_tensor_inputs(model: RegressionMatcher, device, coarse_res: int, upsample_res: int):
  12. model.match(
  13. torch.randn(8, 3, coarse_res, coarse_res).to(device),
  14. torch.randn(8, 3, coarse_res, coarse_res).to(device),
  15. im_A_high_res=torch.randn(8, 3, upsample_res, upsample_res).to(device),
  16. im_B_high_res=torch.randn(8, 3, upsample_res, upsample_res).to(device),
  17. )
  18. def test_pil_inputs(model: RegressionMatcher):
  19. model.match(Image.open("assets/toronto_A.jpg"), Image.open("assets/toronto_B.jpg"))
  20. def test_str_inputs(model: RegressionMatcher):
  21. model.match("assets/toronto_A.jpg", "assets/toronto_B.jpg")
  22. if __name__ == "__main__":
  23. from romatch import roma_outdoor
  24. coarse_res = 560
  25. upsample_res = 1152
  26. for device in [torch.device("cuda")]:
  27. model = roma_outdoor(
  28. device=device,
  29. coarse_res=coarse_res,
  30. upsample_res=upsample_res,
  31. use_custom_corr=True,
  32. symmetric=True,
  33. upsample_preds=True,
  34. )
  35. for is_symmetric in [True, False]:
  36. for upsample_preds in [True, False]:
  37. for batched in [True, False]:
  38. model.symmetric = is_symmetric
  39. model.upsample_preds = upsample_preds
  40. model.batched = batched
  41. test_bs_one_tensor_inputs(model, device, coarse_res, upsample_res)
  42. test_bs_8_tensor_inputs(model, device, coarse_res, upsample_res)
  43. test_pil_inputs(model)
  44. test_str_inputs(model)
  45. print(f"Done with {is_symmetric=}, {upsample_preds=}, {batched=}, {device=}")
  46. for device in [torch.device("cpu")]:
  47. model = roma_outdoor(
  48. device=device,
  49. coarse_res=coarse_res,
  50. upsample_res=upsample_res,
  51. use_custom_corr=True,
  52. symmetric=True,
  53. upsample_preds=True,
  54. )
  55. model.symmetric = is_symmetric
  56. model.upsample_preds = upsample_preds
  57. model.batched = batched
  58. model.device = device
  59. test_bs_one_tensor_inputs(model, device, coarse_res, upsample_res)
  60. test_bs_8_tensor_inputs(model, device, coarse_res, upsample_res)
  61. test_pil_inputs(model)
  62. test_str_inputs(model)
  63. print(f"Done with {is_symmetric=}, {upsample_preds=}, {batched=}, {device=}")