Ver Fonte

wip recrop at upsample

Johan Edstedt há 2 anos atrás
pai
commit
4aa8c27744
2 ficheiros alterados com 31 adições e 4 exclusões
  1. 31 3
      roma/models/matcher.py
  2. 0 1
      roma/utils/local_correlation.py

+ 31 - 3
roma/models/matcher.py

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange
 import warnings
 from warnings import warn
+from PIL import Image
 
 import roma
 from roma.utils import get_tuple_transform_ops
@@ -427,6 +428,7 @@ class RegressionMatcher(nn.Module):
         symmetric = False,
         name = None,
         attenuate_cert = None,
+        recrop_upsample = False,
     ):
         super().__init__()
         self.attenuate_cert = attenuate_cert
@@ -441,6 +443,7 @@ class RegressionMatcher(nn.Module):
         self.upsample_res = (14*16*6, 14*16*6)
         self.symmetric = symmetric
         self.sample_thresh = 0.05
+        self.recrop_upsample = recrop_upsample
             
     def get_output_resolution(self):
         if not self.upsample_preds:
@@ -536,8 +539,27 @@ class RegressionMatcher(nn.Module):
             return x_A[inds_A], x_B[inds_B]
         else:
             return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
+    
+    def get_roi(self, certainty, W, H, thr = 0.025):
+        raise NotImplementedError("WIP, disable for now")
+        hs,ws = certainty.shape
+        certainty = certainty/certainty.sum(dim=(-1,-2))
+        cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
+        cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
+        print(cum_certainty_w)
+        print(torch.min(torch.nonzero(cum_certainty_w > thr)))
+        print(torch.min(torch.nonzero(cum_certainty_w < thr)))
+        left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
+        right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
+        top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
+        bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
+        print(left, right, top, bottom)
+        return left, top, right, bottom
 
-
+    def recrop(self, certainty, image_path):
+        roi = self.get_roi(certainty, *Image.open(image_path).size)
+        return Image.open(image_path).crop(roi)
+        
     @torch.inference_mode()
     def match(
         self,
@@ -549,7 +571,6 @@ class RegressionMatcher(nn.Module):
     ):
         if device is None:
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        from PIL import Image
         if isinstance(im_A_path, (str, os.PathLike)):
             im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
         else:
@@ -603,7 +624,14 @@ class RegressionMatcher(nn.Module):
                 test_transform = get_tuple_transform_ops(
                     resize=(hs, ws), normalize=True
                 )
-                im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
+                if self.recrop_upsample:
+                    certainty = corresps[finest_scale]["certainty"]
+                    print(certainty.shape)
+                    im_A = self.recrop(certainty[0,0], im_A_path)
+                    im_B = self.recrop(certainty[1,0], im_B_path)
+                    #TODO: need to adjust corresps when doing this
+                else:
+                    im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
                 im_A, im_B = test_transform((im_A, im_B))
                 im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                 scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))

+ 0 - 1
roma/utils/local_correlation.py

@@ -43,5 +43,4 @@ def local_correlation(
             )
             window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
         corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
-    torch.cuda.empty_cache()
     return corr