# demo_match_tiny.py
  1. import os
  2. os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
  3. import torch
  4. from PIL import Image
  5. import torch.nn.functional as F
  6. import numpy as np
  7. from romatch.utils.utils import tensor_to_pil
  8. from romatch import tiny_roma_v1_outdoor
  9. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  10. if torch.backends.mps.is_available():
  11. device = torch.device('mps')
  12. if __name__ == "__main__":
  13. from argparse import ArgumentParser
  14. parser = ArgumentParser()
  15. parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
  16. parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
  17. parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str)
  18. parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str)
  19. args, _ = parser.parse_known_args()
  20. im1_path = args.im_A_path
  21. im2_path = args.im_B_path
  22. # Create model
  23. roma_model = tiny_roma_v1_outdoor(device=device)
  24. # Match
  25. warp, certainty1 = roma_model.match(im1_path, im2_path)
  26. h1, w1 = warp.shape[:2]
  27. # maybe im1.size != im2.size
  28. im1 = Image.open(im1_path).resize((w1, h1))
  29. im2 = Image.open(im2_path)
  30. x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
  31. x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
  32. h2, w2 = x2.shape[1:]
  33. g1_p2x = w2 / 2 * (warp[..., 2] + 1)
  34. g1_p2y = h2 / 2 * (warp[..., 3] + 1)
  35. g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
  36. g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
  37. x, y = torch.meshgrid(
  38. torch.arange(w1, device=device),
  39. torch.arange(h1, device=device),
  40. indexing="xy",
  41. )
  42. g2x = torch.round(g1_p2x[y, x]).long()
  43. g2y = torch.round(g1_p2y[y, x]).long()
  44. idx_x = torch.bitwise_and(0 <= g2x, g2x < w2)
  45. idx_y = torch.bitwise_and(0 <= g2y, g2y < h2)
  46. idx = torch.bitwise_and(idx_x, idx_y)
  47. g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1
  48. g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1
  49. certainty2 = F.grid_sample(
  50. certainty1[None][None],
  51. torch.stack([g2_p1x, g2_p1y], dim=2)[None],
  52. mode="bilinear",
  53. align_corners=False,
  54. )[0]
  55. white_im1 = torch.ones((h1, w1), device = device)
  56. white_im2 = torch.ones((h2, w2), device = device)
  57. certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0]
  58. certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0]
  59. vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1
  60. vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2
  61. tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path)
  62. tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path)