|
|
@@ -573,40 +573,40 @@ class RegressionMatcher(nn.Module):
|
|
|
kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
|
|
|
return kpts_A, kpts_B
|
|
|
|
|
|
-def match_keypoints(
|
|
|
- self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False, max_dist = 0.005, cert_th = 0,
|
|
|
-):
|
|
|
- x_A_to_B = F.grid_sample(
|
|
|
- warp[..., -2:].permute(2, 0, 1)[None],
|
|
|
- x_A[None, None],
|
|
|
- align_corners=False,
|
|
|
- mode="bilinear",
|
|
|
- )[0, :, 0].mT
|
|
|
- cert_A_to_B = F.grid_sample(
|
|
|
- certainty[None, None, ...],
|
|
|
- x_A[None, None],
|
|
|
- align_corners=False,
|
|
|
- mode="bilinear",
|
|
|
- )[0, 0, 0]
|
|
|
- D = torch.cdist(x_A_to_B, x_B)
|
|
|
- inds_A, inds_B = torch.nonzero(
|
|
|
- (D == D.min(dim=-1, keepdim=True).values)
|
|
|
- * (D == D.min(dim=-2, keepdim=True).values)
|
|
|
- * (cert_A_to_B[:, None] > cert_th)
|
|
|
- * (D < max_dist),
|
|
|
- as_tuple=True,
|
|
|
- )
|
|
|
-
|
|
|
- if return_tuple:
|
|
|
- if return_inds:
|
|
|
- return inds_A, inds_B
|
|
|
- else:
|
|
|
- return x_A[inds_A], x_B[inds_B]
|
|
|
- else:
|
|
|
- if return_inds:
|
|
|
- return torch.cat((inds_A, inds_B), dim=-1)
|
|
|
+ def match_keypoints(
|
|
|
+ self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False, max_dist = 0.005, cert_th = 0,
|
|
|
+ ):
|
|
|
+ x_A_to_B = F.grid_sample(
|
|
|
+ warp[..., -2:].permute(2, 0, 1)[None],
|
|
|
+ x_A[None, None],
|
|
|
+ align_corners=False,
|
|
|
+ mode="bilinear",
|
|
|
+ )[0, :, 0].mT
|
|
|
+ cert_A_to_B = F.grid_sample(
|
|
|
+ certainty[None, None, ...],
|
|
|
+ x_A[None, None],
|
|
|
+ align_corners=False,
|
|
|
+ mode="bilinear",
|
|
|
+ )[0, 0, 0]
|
|
|
+ D = torch.cdist(x_A_to_B, x_B)
|
|
|
+ inds_A, inds_B = torch.nonzero(
|
|
|
+ (D == D.min(dim=-1, keepdim=True).values)
|
|
|
+ * (D == D.min(dim=-2, keepdim=True).values)
|
|
|
+ * (cert_A_to_B[:, None] > cert_th)
|
|
|
+ * (D < max_dist),
|
|
|
+ as_tuple=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ if return_tuple:
|
|
|
+ if return_inds:
|
|
|
+ return inds_A, inds_B
|
|
|
+ else:
|
|
|
+ return x_A[inds_A], x_B[inds_B]
|
|
|
else:
|
|
|
- return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
|
|
|
+ if return_inds:
|
|
|
+ return torch.cat((inds_A, inds_B), dim=-1)
|
|
|
+ else:
|
|
|
+ return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
def match(
|