Johan Edstedt 1 рік тому
батько
коміт
edd1b8b35a
1 змінених файлів з 33 додано та 33 видалено
  1. 33 33
      romatch/models/matcher.py

+ 33 - 33
romatch/models/matcher.py

@@ -573,40 +573,40 @@ class RegressionMatcher(nn.Module):
         kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
         return kpts_A, kpts_B
 
-def match_keypoints(
-    self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False, max_dist = 0.005, cert_th = 0,
-):
-    x_A_to_B = F.grid_sample(
-        warp[..., -2:].permute(2, 0, 1)[None],
-        x_A[None, None],
-        align_corners=False,
-        mode="bilinear",
-    )[0, :, 0].mT
-    cert_A_to_B = F.grid_sample(
-        certainty[None, None, ...],
-        x_A[None, None],
-        align_corners=False,
-        mode="bilinear",
-    )[0, 0, 0]
-    D = torch.cdist(x_A_to_B, x_B)
-    inds_A, inds_B = torch.nonzero(
-        (D == D.min(dim=-1, keepdim=True).values)
-        * (D == D.min(dim=-2, keepdim=True).values)
-        * (cert_A_to_B[:, None] > cert_th)
-        * (D < max_dist),
-        as_tuple=True,
-    )
-
-    if return_tuple:
-        if return_inds:
-            return inds_A, inds_B
-        else:
-            return x_A[inds_A], x_B[inds_B]
-    else:
-        if return_inds:
-            return torch.cat((inds_A, inds_B), dim=-1)
+    def match_keypoints(
+        self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False, max_dist = 0.005, cert_th = 0,
+    ):
+        x_A_to_B = F.grid_sample(
+            warp[..., -2:].permute(2, 0, 1)[None],
+            x_A[None, None],
+            align_corners=False,
+            mode="bilinear",
+        )[0, :, 0].mT
+        cert_A_to_B = F.grid_sample(
+            certainty[None, None, ...],
+            x_A[None, None],
+            align_corners=False,
+            mode="bilinear",
+        )[0, 0, 0]
+        D = torch.cdist(x_A_to_B, x_B)
+        inds_A, inds_B = torch.nonzero(
+            (D == D.min(dim=-1, keepdim=True).values)
+            * (D == D.min(dim=-2, keepdim=True).values)
+            * (cert_A_to_B[:, None] > cert_th)
+            * (D < max_dist),
+            as_tuple=True,
+        )
+
+        if return_tuple:
+            if return_inds:
+                return inds_A, inds_B
+            else:
+                return x_A[inds_A], x_B[inds_B]
         else:
-            return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
+            if return_inds:
+                return torch.cat((inds_A, inds_B), dim=-1)
+            else:
+                return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
             
     @torch.inference_mode()
     def match(