Bläddra i källkod

update kpt matching

Johan Edstedt 1 år sedan
förälder
incheckning
45c1129d15
1 ändrade filer med 33 tillägg och 15 borttagningar
  1. 33 15
      romatch/models/matcher.py

+ 33 - 15
romatch/models/matcher.py

@@ -573,22 +573,40 @@ class RegressionMatcher(nn.Module):
         kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
         return kpts_A, kpts_B
 
-    def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
-        x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
-        cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
-        D = torch.cdist(x_A_to_B, x_B)
-        inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
-        
-        if return_tuple:
-            if return_inds:
-                return inds_A, inds_B
-            else:
-                return x_A[inds_A], x_B[inds_B]
+def match_keypoints(
+    self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False, max_dist = 0.005, cert_th = 0,
+):
+    x_A_to_B = F.grid_sample(
+        warp[..., -2:].permute(2, 0, 1)[None],
+        x_A[None, None],
+        align_corners=False,
+        mode="bilinear",
+    )[0, :, 0].mT
+    cert_A_to_B = F.grid_sample(
+        certainty[None, None, ...],
+        x_A[None, None],
+        align_corners=False,
+        mode="bilinear",
+    )[0, 0, 0]
+    D = torch.cdist(x_A_to_B, x_B)
+    inds_A, inds_B = torch.nonzero(
+        (D == D.min(dim=-1, keepdim=True).values)
+        * (D == D.min(dim=-2, keepdim=True).values)
+        * (cert_A_to_B[:, None] > cert_th)
+        * (D < max_dist),
+        as_tuple=True,
+    )
+
+    if return_tuple:
+        if return_inds:
+            return inds_A, inds_B
         else:
-            if return_inds:
-                return torch.cat((inds_A, inds_B),dim=-1)
-            else:
-                return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
+            return x_A[inds_A], x_B[inds_B]
+    else:
+        if return_inds:
+            return torch.cat((inds_A, inds_B), dim=-1)
+        else:
+            return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
             
     @torch.inference_mode()
     def match(