|
|
@@ -526,6 +526,19 @@ class RegressionMatcher(nn.Module):
|
|
|
kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
|
|
|
return kpts_A, kpts_B
|
|
|
|
|
|
+
|
|
|
+ def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True):
|
|
|
+ x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
|
|
|
+ cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
|
|
|
+ D = torch.cdist(x_A_to_B, x_B)
|
|
|
+ inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
|
|
|
+ if return_tuple:
|
|
|
+ return x_A[inds_A], x_B[inds_B]
|
|
|
+ else:
|
|
|
+ return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
|
|
|
+
|
|
|
+
|
|
|
+ @torch.inference_mode()
|
|
|
def match(
|
|
|
self,
|
|
|
im_A_path,
|