Bladeren bron

add simple visualization to model

Johan Edstedt 2 jaren geleden
bovenliggende
commit
d25d21ef7e
1 gewijzigde bestanden met toevoegingen van 26 en 0 verwijderingen
  1. 26 0
      roma/models/matcher.py

+ 26 - 0
roma/models/matcher.py

@@ -659,4 +659,30 @@ class RegressionMatcher(nn.Module):
                     warp[0],
                     certainty[0, 0],
                 )
+                
+    def visualize_warp(self, warp, certainty, im_A = None, im_B = None, im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None):
+        assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
+        H,W2,_ = warp.shape
+        W = W2//2 if symmetric else W2
+        if im_A is None:
+            from PIL import Image
+            im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
+        im_A = im_A.resize((W,H))
+        im_B = im_B.resize((W,H))
+            
+        x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
+        x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
 
+        im_A_transfer_rgb = F.grid_sample(
+        x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+        )[0]
+        im_B_transfer_rgb = F.grid_sample(
+        x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+        )[0]
+        warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
+        white_im = torch.ones((H,2*W),device=device)
+        vis_im = certainty * warp_im + (1 - certainty) * white_im
+        if save_path is not None:
+            from roma.utils import tensor_to_pil
+            tensor_to_pil(vis_im, unnormalize=False).save(save_path)
+        return vis_im