瀏覽代碼

Add option in matcher to directly accept PIL images or image paths (#73)

Gašper Spagnolo 1 年之前
父節點
當前提交
0e50b44ab1
共有 1 個文件被更改,包括 36 次插入24 次删除
  1. 36 24
      romatch/models/matcher.py

+ 36 - 24
romatch/models/matcher.py

@@ -593,18 +593,25 @@ class RegressionMatcher(nn.Module):
     @torch.inference_mode()
     def match(
         self,
-        im_A_path,
-        im_B_path,
+        im_A_input,
+        im_B_input,
         *args,
         batched=False,
-        device = None,
+        device=None,
     ):
         if device is None:
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if isinstance(im_A_path, (str, os.PathLike)):
-            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
+
+        # Check if inputs are file paths or already loaded images
+        if isinstance(im_A_input, (str, os.PathLike)):
+            im_A = Image.open(im_A_input).convert("RGB")
+        else:
+            im_A = im_A_input
+
+        if isinstance(im_B_input, (str, os.PathLike)):
+            im_B = Image.open(im_B_input).convert("RGB")
         else:
-            im_A, im_B = im_A_path, im_B_path 
+            im_B = im_B_input
 
         symmetric = self.symmetric
         self.train(False)
@@ -616,9 +623,9 @@ class RegressionMatcher(nn.Module):
                 # Get images in good format
                 ws = self.w_resized
                 hs = self.h_resized
-                
+
                 test_transform = get_tuple_transform_ops(
-                    resize=(hs, ws), normalize=True, clahe = False
+                    resize=(hs, ws), normalize=True, clahe=False
                 )
                 im_A, im_B = test_transform((im_A, im_B))
                 batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
@@ -633,20 +640,20 @@ class RegressionMatcher(nn.Module):
             finest_scale = 1
             # Run matcher
             if symmetric:
-                corresps  = self.forward_symmetric(batch)
+                corresps = self.forward_symmetric(batch)
             else:
-                corresps = self.forward(batch, batched = True)
+                corresps = self.forward(batch, batched=True)
 
             if self.upsample_preds:
                 hs, ws = self.upsample_res
-            
+
             if self.attenuate_cert:
                 low_res_certainty = F.interpolate(
-                corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+                    corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
                 )
                 cert_clamp = 0
                 factor = 0.5
-                low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+                low_res_certainty = factor * low_res_certainty * (low_res_certainty < cert_clamp)
 
             if self.upsample_preds:
                 finest_corresps = corresps[finest_scale]
@@ -654,34 +661,39 @@ class RegressionMatcher(nn.Module):
                 test_transform = get_tuple_transform_ops(
                     resize=(hs, ws), normalize=True
                 )
-                im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB')))
+                if isinstance(im_A_input, (str, os.PathLike)):
+                    im_A, im_B = test_transform(
+                        (Image.open(im_A_input).convert('RGB'), Image.open(im_B_input).convert('RGB')))
+                else:
+                    im_A, im_B = test_transform((im_A_input, im_B_input))
+
                 im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                 scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
                 batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
                 if symmetric:
-                    corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
+                    corresps = self.forward_symmetric(batch, upsample=True, batched=True, scale_factor=scale_factor)
                 else:
-                    corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
-            
-            im_A_to_im_B = corresps[finest_scale]["flow"] 
+                    corresps = self.forward(batch, batched=True, upsample=True, scale_factor=scale_factor)
+
+            im_A_to_im_B = corresps[finest_scale]["flow"]
             certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
             if finest_scale != 1:
                 im_A_to_im_B = F.interpolate(
-                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+                    im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
                 )
                 certainty = F.interpolate(
-                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
+                    certainty, size=(hs, ws), align_corners=False, mode="bilinear"
                 )
             im_A_to_im_B = im_A_to_im_B.permute(
                 0, 2, 3, 1
-                )
+            )
             # Create im_A meshgrid
             im_A_coords = torch.meshgrid(
                 (
                     torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                     torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                 ),
-                indexing = 'ij'
+                indexing='ij'
             )
             im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
             im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
@@ -689,14 +701,14 @@ class RegressionMatcher(nn.Module):
             im_A_coords = im_A_coords.permute(0, 2, 3, 1)
             if (im_A_to_im_B.abs() > 1).any() and True:
                 wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
-                certainty[wrong[:,None]] = 0
+                certainty[wrong[:, None]] = 0
             im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
             if symmetric:
                 A_to_B, B_to_A = im_A_to_im_B.chunk(2)
                 q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
                 im_B_coords = im_A_coords
                 s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
-                warp = torch.cat((q_warp, s_warp),dim=2)
+                warp = torch.cat((q_warp, s_warp), dim=2)
                 certainty = torch.cat(certainty.chunk(2), dim=3)
             else:
                 warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)