Explorar o código

some updates to api and add demo image pair and 3D effect demo

Johan Edstedt %!s(int64=2) %!d(string=hai) anos
pai
achega
d74389802f

BIN=BIN
assets/toronto_A.jpg


BIN=BIN
assets/toronto_B.jpg


+ 46 - 0
demo/demo_3D_effect.py

@@ -0,0 +1,46 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from roma.utils.utils import tensor_to_pil
+
+from roma import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    # Create model
+    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
+    roma_model.symmetric = False
+
+    H, W = roma_model.get_output_resolution()
+
+    im1 = Image.open(im1_path).resize((W, H))
+    im2 = Image.open(im2_path).resize((W, H))
+
+    # Match
+    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+    # Sampling not needed, but can be done with model.sample(warp, certainty)
+    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+    coords_A, coords_B = warp[...,:2], warp[...,2:]
+    for i, x in enumerate(np.linspace(0,2*np.pi,200)):
+        t = (1 + np.cos(x))/2
+        interp_warp = (1-t)*coords_A + t*coords_B
+        im2_transfer_rgb = F.grid_sample(
+        x2[None], interp_warp[None], mode="bilinear", align_corners=False
+        )[0]
+        tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")

+ 5 - 5
demo/demo_match.py

@@ -4,7 +4,7 @@ import torch.nn.functional as F
 import numpy as np
 from roma.utils.utils import tensor_to_pil
 
-from roma import roma_indoor
+from roma import roma_outdoor
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -12,9 +12,9 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == "__main__":
     from argparse import ArgumentParser
     parser = ArgumentParser()
-    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
-    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
-    parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
 
     args, _ = parser.parse_known_args()
     im1_path = args.im_A_path
@@ -22,7 +22,7 @@ if __name__ == "__main__":
     save_path = args.save_path
 
     # Create model
-    roma_model = roma_indoor(device=device)
+    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
 
     H, W = roma_model.get_output_resolution()
 

+ 43 - 0
demo/demo_match_opencv_sift.py

@@ -0,0 +1,43 @@
+from PIL import Image
+import numpy as np
+
+import numpy as np
+import cv2 as cv
+import matplotlib.pyplot as plt
+
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE)          # queryImage
+    img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
+    # Initiate SIFT detector
+    sift = cv.SIFT_create()
+    # find the keypoints and descriptors with SIFT
+    kp1, des1 = sift.detectAndCompute(img1,None)
+    kp2, des2 = sift.detectAndCompute(img2,None)
+    # BFMatcher with default params
+    bf = cv.BFMatcher()
+    matches = bf.knnMatch(des1,des2,k=2)
+    # Apply ratio test
+    good = []
+    for m,n in matches:
+        if m.distance < 0.75*n.distance:
+            good.append([m])
+    # cv.drawMatchesKnn expects list of lists as matches.
+    draw_params = dict(matchColor = (255,0,0), # draw matches in red color
+                   singlePointColor = None,
+                   flags = 2)
+
+    img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
+    Image.fromarray(img3).save("demo/sift_matches.png")

+ 2 - 0
demo/gif/.gitignore

@@ -0,0 +1,2 @@
+*
+!.gitignore

+ 27 - 4
roma/models/model_zoo/__init__.py

@@ -1,3 +1,4 @@
+from typing import Union
 import torch
 from .roma_models import roma_model
 
@@ -9,22 +10,44 @@ weight_urls = {
     "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
 }
 
-def roma_outdoor(device, weights=None, dinov2_weights=None):
+def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864):
+    if isinstance(coarse_res, int):
+        coarse_res = (coarse_res, coarse_res)
+    if isinstance(upsample_res, int):    
+        upsample_res = (upsample_res, upsample_res)
+
+    assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    
     if weights is None:
         weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"],
                                                      map_location=device)
     if dinov2_weights is None:
         dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
                                                      map_location=device)
-    return roma_model(resolution=(14*8*6,14*8*6), upsample_preds=True,
+    model = roma_model(resolution=coarse_res, upsample_preds=True,
                weights=weights,dinov2_weights = dinov2_weights,device=device)
+    model.upsample_res = upsample_res
+    print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
+    return model
+
+def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864):
+    if isinstance(coarse_res, int):
+        coarse_res = (coarse_res, coarse_res)
+    if isinstance(upsample_res, int):    
+        upsample_res = (upsample_res, upsample_res)
 
-def roma_indoor(device, weights=None, dinov2_weights=None):
+    assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    
     if weights is None:
         weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"],
                                                      map_location=device)
     if dinov2_weights is None:
         dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
                                                      map_location=device)
-    return roma_model(resolution=(14*8*5,14*8*5), upsample_preds=False,
+    model = roma_model(resolution=coarse_res, upsample_preds=True,
                weights=weights,dinov2_weights = dinov2_weights,device=device)
+    model.upsample_res = upsample_res
+    print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
+    return model