|
|
@@ -593,18 +593,25 @@ class RegressionMatcher(nn.Module):
|
|
|
@torch.inference_mode()
|
|
|
def match(
|
|
|
self,
|
|
|
- im_A_path,
|
|
|
- im_B_path,
|
|
|
+ im_A_input,
|
|
|
+ im_B_input,
|
|
|
*args,
|
|
|
batched=False,
|
|
|
- device = None,
|
|
|
+ device=None,
|
|
|
):
|
|
|
if device is None:
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
- if isinstance(im_A_path, (str, os.PathLike)):
|
|
|
- im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
|
|
|
+
|
|
|
+ # Check if inputs are file paths or already loaded images
|
|
|
+ if isinstance(im_A_input, (str, os.PathLike)):
|
|
|
+ im_A = Image.open(im_A_input).convert("RGB")
|
|
|
+ else:
|
|
|
+ im_A = im_A_input
|
|
|
+
|
|
|
+ if isinstance(im_B_input, (str, os.PathLike)):
|
|
|
+ im_B = Image.open(im_B_input).convert("RGB")
|
|
|
else:
|
|
|
- im_A, im_B = im_A_path, im_B_path
|
|
|
+ im_B = im_B_input
|
|
|
|
|
|
symmetric = self.symmetric
|
|
|
self.train(False)
|
|
|
@@ -616,9 +623,9 @@ class RegressionMatcher(nn.Module):
|
|
|
# Get images in good format
|
|
|
ws = self.w_resized
|
|
|
hs = self.h_resized
|
|
|
-
|
|
|
+
|
|
|
test_transform = get_tuple_transform_ops(
|
|
|
- resize=(hs, ws), normalize=True, clahe = False
|
|
|
+ resize=(hs, ws), normalize=True, clahe=False
|
|
|
)
|
|
|
im_A, im_B = test_transform((im_A, im_B))
|
|
|
batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
|
|
|
@@ -633,20 +640,20 @@ class RegressionMatcher(nn.Module):
|
|
|
finest_scale = 1
|
|
|
# Run matcher
|
|
|
if symmetric:
|
|
|
- corresps = self.forward_symmetric(batch)
|
|
|
+ corresps = self.forward_symmetric(batch)
|
|
|
else:
|
|
|
- corresps = self.forward(batch, batched = True)
|
|
|
+ corresps = self.forward(batch, batched=True)
|
|
|
|
|
|
if self.upsample_preds:
|
|
|
hs, ws = self.upsample_res
|
|
|
-
|
|
|
+
|
|
|
if self.attenuate_cert:
|
|
|
low_res_certainty = F.interpolate(
|
|
|
- corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
)
|
|
|
cert_clamp = 0
|
|
|
factor = 0.5
|
|
|
- low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
|
|
|
+ low_res_certainty = factor * low_res_certainty * (low_res_certainty < cert_clamp)
|
|
|
|
|
|
if self.upsample_preds:
|
|
|
finest_corresps = corresps[finest_scale]
|
|
|
@@ -654,34 +661,39 @@ class RegressionMatcher(nn.Module):
|
|
|
test_transform = get_tuple_transform_ops(
|
|
|
resize=(hs, ws), normalize=True
|
|
|
)
|
|
|
- im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB')))
|
|
|
+ if isinstance(im_A_input, (str, os.PathLike)):
|
|
|
+ im_A, im_B = test_transform(
|
|
|
+ (Image.open(im_A_input).convert('RGB'), Image.open(im_B_input).convert('RGB')))
|
|
|
+ else:
|
|
|
+ im_A, im_B = test_transform((im_A_input, im_B_input))
|
|
|
+
|
|
|
im_A, im_B = im_A[None].to(device), im_B[None].to(device)
|
|
|
scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
|
|
|
batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
|
|
|
if symmetric:
|
|
|
- corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
|
|
|
+ corresps = self.forward_symmetric(batch, upsample=True, batched=True, scale_factor=scale_factor)
|
|
|
else:
|
|
|
- corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
|
|
|
-
|
|
|
- im_A_to_im_B = corresps[finest_scale]["flow"]
|
|
|
+ corresps = self.forward(batch, batched=True, upsample=True, scale_factor=scale_factor)
|
|
|
+
|
|
|
+ im_A_to_im_B = corresps[finest_scale]["flow"]
|
|
|
certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
|
|
|
if finest_scale != 1:
|
|
|
im_A_to_im_B = F.interpolate(
|
|
|
- im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
)
|
|
|
certainty = F.interpolate(
|
|
|
- certainty, size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
)
|
|
|
im_A_to_im_B = im_A_to_im_B.permute(
|
|
|
0, 2, 3, 1
|
|
|
- )
|
|
|
+ )
|
|
|
# Create im_A meshgrid
|
|
|
im_A_coords = torch.meshgrid(
|
|
|
(
|
|
|
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
|
|
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
|
|
),
|
|
|
- indexing = 'ij'
|
|
|
+ indexing='ij'
|
|
|
)
|
|
|
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
|
|
|
@@ -689,14 +701,14 @@ class RegressionMatcher(nn.Module):
|
|
|
im_A_coords = im_A_coords.permute(0, 2, 3, 1)
|
|
|
if (im_A_to_im_B.abs() > 1).any() and True:
|
|
|
wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
|
|
|
- certainty[wrong[:,None]] = 0
|
|
|
+ certainty[wrong[:, None]] = 0
|
|
|
im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
|
|
|
if symmetric:
|
|
|
A_to_B, B_to_A = im_A_to_im_B.chunk(2)
|
|
|
q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
|
|
|
im_B_coords = im_A_coords
|
|
|
s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
|
|
|
- warp = torch.cat((q_warp, s_warp),dim=2)
|
|
|
+ warp = torch.cat((q_warp, s_warp), dim=2)
|
|
|
certainty = torch.cat(certainty.chunk(2), dim=3)
|
|
|
else:
|
|
|
warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
|