# demo_3D_effect.py
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from romatch import roma_outdoor
from romatch.utils.utils import tensor_to_pil
  7. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  8. if __name__ == "__main__":
  9. from argparse import ArgumentParser
  10. parser = ArgumentParser()
  11. parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
  12. parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
  13. parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
  14. args, _ = parser.parse_known_args()
  15. im1_path = args.im_A_path
  16. im2_path = args.im_B_path
  17. save_path = args.save_path
  18. # Create model
  19. roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
  20. roma_model.symmetric = False
  21. H, W = roma_model.get_output_resolution()
  22. im1 = Image.open(im1_path).resize((W, H))
  23. im2 = Image.open(im2_path).resize((W, H))
  24. # Match
  25. warp, certainty = roma_model.match(im1_path, im2_path, device=device)
  26. # Sampling not needed, but can be done with model.sample(warp, certainty)
  27. x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
  28. x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
  29. coords_A, coords_B = warp[...,:2], warp[...,2:]
  30. for i, x in enumerate(np.linspace(0,2*np.pi,200)):
  31. t = (1 + np.cos(x))/2
  32. interp_warp = (1-t)*coords_A + t*coords_B
  33. im2_transfer_rgb = F.grid_sample(
  34. x2[None], interp_warp[None], mode="bilinear", align_corners=False
  35. )[0]
  36. tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")