Explorar o código

update keypoint matching with roma

Johan Edstedt %!s(int64=2) %!d(string=hai) anos
pai
achega
154175c7d2
Modificáronse 2 ficheiros con 27 adicións e 6 borrados
  1. 24 6
      roma/models/matcher.py
  2. 3 0
      roma/utils/utils.py

+ 24 - 6
roma/models/matcher.py

@@ -524,22 +524,40 @@ class RegressionMatcher(nn.Module):
                                 scale_factor=scale_factor)
         return corresps
     
-    def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
-        kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+    def to_pixel_coordinates(self, coords, H_A, W_A, H_B, W_B):
+        if isinstance(coords, (list, tuple)):
+            kpts_A, kpts_B = coords[0], coords[1]
+        else:
+            kpts_A, kpts_B = coords[...,:2], coords[...,2:]
         kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
         kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
         return kpts_A, kpts_B
+    
+    def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
+        if isinstance(coords, (list, tuple)):
+            kpts_A, kpts_B = coords[0], coords[1]
+        else:
+            kpts_A, kpts_B = coords[...,:2], coords[...,2:]
+        kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
+        kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
+        return kpts_A, kpts_B
 
-
-    def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True):
+    def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
         x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
         cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
         D = torch.cdist(x_A_to_B, x_B)
         inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
+        
         if return_tuple:
-            return x_A[inds_A], x_B[inds_B]
+            if return_inds:
+                return inds_A, inds_B
+            else:
+                return x_A[inds_A], x_B[inds_B]
         else:
-            return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
+            if return_inds:
+                return torch.cat((inds_A, inds_B),dim=-1)
+            else:
+                return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
     
     def get_roi(self, certainty, W, H, thr = 0.025):
         raise NotImplementedError("WIP, disable for now")

+ 3 - 0
roma/utils/utils.py

@@ -520,6 +520,8 @@ def flow_to_pixel_coords(flow, h1, w1):
     )
     return flow
 
+to_pixel_coords = flow_to_pixel_coords # just an alias
+
 def flow_to_normalized_coords(flow, h1, w1):
     flow = (
         torch.stack(
@@ -532,6 +534,7 @@ def flow_to_normalized_coords(flow, h1, w1):
     )
     return flow
 
+to_normalized_coords = flow_to_normalized_coords # just an alias
 
 def warp_to_pixel_coords(warp, h1, w1, h2, w2):
     warp1 = warp[..., :2]