Prechádzať zdrojové kódy

feat: handle pillow images as input (#35)

Co-authored-by: Johan Edstedt <johan.edstedt@liu.se>
Diego 1 rok pred
rodič
commit
36389eff40
2 zmenil súbory, kde vykonal 8 pridanie a 6 odobranie
  1. 2 0
      .gitignore
  2. 6 6
      roma/models/matcher.py

+ 2 - 0
.gitignore

@@ -3,6 +3,8 @@
 *__pycache__*
 vis*
 workspace*
+.venv
+.DS_Store
 jobs/*
 *ignore_me*
 *.pth

+ 6 - 6
roma/models/matcher.py

@@ -14,6 +14,7 @@ from roma.utils import get_tuple_transform_ops
 from roma.utils.local_correlation import local_correlation
 from roma.utils.utils import cls_to_flow_refine
 from roma.utils.kde import kde
+from typing import Union
 
 class ConvRefiner(nn.Module):
     def __init__(
@@ -610,8 +611,8 @@ class RegressionMatcher(nn.Module):
     @torch.inference_mode()
     def match(
         self,
-        im_A_path,
-        im_B_path,
+        im_A_path: Union[str, os.PathLike, Image.Image],
+        im_B_path: Union[str, os.PathLike, Image.Image],
         *args,
         batched=False,
         device = None,
@@ -621,8 +622,8 @@ class RegressionMatcher(nn.Module):
         if isinstance(im_A_path, (str, os.PathLike)):
             im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
         else:
-            # Assume its not a path
-            im_A, im_B = im_A_path, im_B_path
+            im_A, im_B = im_A_path, im_B_path 
+
         symmetric = self.symmetric
         self.train(False)
         with torch.no_grad():
@@ -672,13 +673,12 @@ class RegressionMatcher(nn.Module):
                     resize=(hs, ws), normalize=True
                 )
                 if self.recrop_upsample:
+                    raise NotImplementedError("recrop_upsample not implemented")
                     certainty = corresps[finest_scale]["certainty"]
                     print(certainty.shape)
                     im_A = self.recrop(certainty[0,0], im_A_path)
                     im_B = self.recrop(certainty[1,0], im_B_path)
                     #TODO: need to adjust corresps when doing this
-                else:
-                    im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
                 im_A, im_B = test_transform((im_A, im_B))
                 im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                 scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))