# demo_match.py (1.8 KB)
  1. import os
  2. os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
  3. import torch
  4. from PIL import Image
  5. import torch.nn.functional as F
  6. import numpy as np
  7. from romatch.utils.utils import tensor_to_pil
  8. from romatch import roma_outdoor
  9. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  10. if torch.backends.mps.is_available():
  11. device = torch.device('mps')
  12. if __name__ == "__main__":
  13. from argparse import ArgumentParser
  14. parser = ArgumentParser()
  15. parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
  16. parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
  17. parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
  18. args, _ = parser.parse_known_args()
  19. im1_path = args.im_A_path
  20. im2_path = args.im_B_path
  21. save_path = args.save_path
  22. # Create model
  23. roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
  24. H, W = roma_model.get_output_resolution()
  25. im1 = Image.open(im1_path).resize((W, H))
  26. im2 = Image.open(im2_path).resize((W, H))
  27. # Match
  28. warp, certainty = roma_model.match(im1_path, im2_path, device=device)
  29. # Sampling not needed, but can be done with model.sample(warp, certainty)
  30. x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
  31. x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
  32. im2_transfer_rgb = F.grid_sample(
  33. x2[None], warp[:, :, :W, 2:], mode="bilinear", align_corners=False
  34. )[0]
  35. im1_transfer_rgb = F.grid_sample(
  36. x1[None], warp[:, :, W:, :2], mode="bilinear", align_corners=False
  37. )[0]
  38. warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
  39. white_im = torch.ones((H,2*W),device=device)
  40. vis_im = certainty * warp_im + (1 - certainty) * white_im
  41. tensor_to_pil(vis_im, unnormalize=False).save(save_path)