|
|
@@ -524,22 +524,40 @@ class RegressionMatcher(nn.Module):
|
|
|
scale_factor=scale_factor)
|
|
|
return corresps
|
|
|
|
|
|
- def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
|
|
|
- kpts_A, kpts_B = matches[...,:2], matches[...,2:]
|
|
|
+ def to_pixel_coordinates(self, coords, H_A, W_A, H_B, W_B):
|
|
|
+ if isinstance(coords, (list, tuple)):
|
|
|
+ kpts_A, kpts_B = coords[0], coords[1]
|
|
|
+ else:
|
|
|
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
|
|
|
kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
|
|
|
kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
|
|
|
return kpts_A, kpts_B
|
|
|
+
|
|
|
+ def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
|
|
|
+ if isinstance(coords, (list, tuple)):
|
|
|
+ kpts_A, kpts_B = coords[0], coords[1]
|
|
|
+ else:
|
|
|
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
|
|
|
+ kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
|
|
|
+ kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
|
|
|
+ return kpts_A, kpts_B
|
|
|
|
|
|
-
|
|
|
- def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True):
|
|
|
+ def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
|
|
|
x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
|
|
|
cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
|
|
|
D = torch.cdist(x_A_to_B, x_B)
|
|
|
inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
|
|
|
+
|
|
|
if return_tuple:
|
|
|
- return x_A[inds_A], x_B[inds_B]
|
|
|
+ if return_inds:
|
|
|
+ return inds_A, inds_B
|
|
|
+ else:
|
|
|
+ return x_A[inds_A], x_B[inds_B]
|
|
|
else:
|
|
|
- return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
|
|
|
+ if return_inds:
|
|
|
+ return torch.cat((inds_A, inds_B),dim=-1)
|
|
|
+ else:
|
|
|
+ return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
|
|
|
|
|
|
def get_roi(self, certainty, W, H, thr = 0.025):
|
|
|
raise NotImplementedError("WIP, disable for now")
|