Просмотр исходного кода

add simple method to match keypoints with roma

Johan Edstedt 2 лет назад
Родитель
Сommit
d08585c8be
1 измененных файлов с 13 добавлено и 0 удалено
  1. 13 0
      roma/models/matcher.py

+ 13 - 0
roma/models/matcher.py

@@ -526,6 +526,19 @@ class RegressionMatcher(nn.Module):
         kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
         return kpts_A, kpts_B
 
+
+    def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True):
+        x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
+        cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
+        D = torch.cdist(x_A_to_B, x_B)
+        inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
+        if return_tuple:
+            return x_A[inds_A], x_B[inds_B]
+        else:
+            return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
+
+
+    @torch.inference_mode()
     def match(
         self,
         im_A_path,