Преглед на файлове

make sure to reread image for upsample_preds

Johan Edstedt преди 1 година
родител
ревизия
f2ec3dea08
променени са 1 файла, в които са добавени 4 реда и са изтрити 36 реда
  1. 4 36
      romatch/models/matcher.py

+ 4 - 36
romatch/models/matcher.py

@@ -9,12 +9,10 @@ import warnings
 from warnings import warn
 from PIL import Image
 
-import romatch
 from romatch.utils import get_tuple_transform_ops
 from romatch.utils.local_correlation import local_correlation
 from romatch.utils.utils import cls_to_flow_refine, get_autocast_params
 from romatch.utils.kde import kde
-from typing import Union
 
 class ConvRefiner(nn.Module):
     def __init__(
@@ -431,10 +429,8 @@ class RegressionMatcher(nn.Module):
         symmetric = False,
         name = None,
         attenuate_cert = None,
-        recrop_upsample = False,
     ):
         super().__init__()
-        raise NotImplementedError("Possibly bugged, please use commit ca2615f for now. Should be fixed soon.")
         self.attenuate_cert = attenuate_cert
         self.encoder = encoder
         self.decoder = decoder
@@ -447,7 +443,6 @@ class RegressionMatcher(nn.Module):
         self.upsample_res = (14*16*6, 14*16*6)
         self.symmetric = symmetric
         self.sample_thresh = 0.05
-        self.recrop_upsample = recrop_upsample
             
     def get_output_resolution(self):
         if not self.upsample_preds:
@@ -589,32 +584,12 @@ class RegressionMatcher(nn.Module):
                 return torch.cat((inds_A, inds_B),dim=-1)
             else:
                 return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
-    
-    def get_roi(self, certainty, W, H, thr = 0.025):
-        raise NotImplementedError("WIP, disable for now")
-        hs,ws = certainty.shape
-        certainty = certainty/certainty.sum(dim=(-1,-2))
-        cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
-        cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
-        print(cum_certainty_w)
-        print(torch.min(torch.nonzero(cum_certainty_w > thr)))
-        print(torch.min(torch.nonzero(cum_certainty_w < thr)))
-        left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
-        right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
-        top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
-        bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
-        print(left, right, top, bottom)
-        return left, top, right, bottom
-
-    def recrop(self, certainty, image_path):
-        roi = self.get_roi(certainty, *Image.open(image_path).size)
-        return Image.open(image_path).convert("RGB").crop(roi)
-        
+            
     @torch.inference_mode()
     def match(
         self,
-        im_A_path: Union[str, os.PathLike, Image.Image],
-        im_B_path: Union[str, os.PathLike, Image.Image],
+        im_A_path,
+        im_B_path,
         *args,
         batched=False,
         device = None,
@@ -674,14 +649,7 @@ class RegressionMatcher(nn.Module):
                 test_transform = get_tuple_transform_ops(
                     resize=(hs, ws), normalize=True
                 )
-                if self.recrop_upsample:
-                    raise NotImplementedError("recrop_upsample not implemented")
-                    certainty = corresps[finest_scale]["certainty"]
-                    print(certainty.shape)
-                    im_A = self.recrop(certainty[0,0], im_A_path)
-                    im_B = self.recrop(certainty[1,0], im_B_path)
-                    #TODO: need to adjust corresps when doing this
-                im_A, im_B = test_transform((im_A, im_B))
+                im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB')))
                 im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                 scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
                 batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}