فهرست منبع

Batch-dimension fix for demo_match.py (#146)

Jeppe Nørregaard 3 ماه پیش
والد
کامیت
8afdca39b6
1فایلهای تغییر یافته به همراه2 افزوده شده و 2 حذف شده
  1. 2 2
      demo/demo_match.py

+ 2 - 2
demo/demo_match.py

@@ -39,10 +39,10 @@ if __name__ == "__main__":
     x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
     x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
 
 
     im2_transfer_rgb = F.grid_sample(
     im2_transfer_rgb = F.grid_sample(
-    x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+    x2[None], warp[:, :, :W, 2:], mode="bilinear", align_corners=False
     )[0]
     )[0]
     im1_transfer_rgb = F.grid_sample(
     im1_transfer_rgb = F.grid_sample(
-    x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+    x1[None], warp[:, :, W:, :2], mode="bilinear", align_corners=False
     )[0]
     )[0]
     warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
     warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
     white_im = torch.ones((H,2*W),device=device)
     white_im = torch.ones((H,2*W),device=device)