|
|
@@ -7,6 +7,7 @@ import torch.nn.functional as F
|
|
|
from einops import rearrange
|
|
|
import warnings
|
|
|
from warnings import warn
|
|
|
+from PIL import Image
|
|
|
|
|
|
import roma
|
|
|
from roma.utils import get_tuple_transform_ops
|
|
|
@@ -427,6 +428,7 @@ class RegressionMatcher(nn.Module):
|
|
|
symmetric = False,
|
|
|
name = None,
|
|
|
attenuate_cert = None,
|
|
|
+ recrop_upsample = False,
|
|
|
):
|
|
|
super().__init__()
|
|
|
self.attenuate_cert = attenuate_cert
|
|
|
@@ -441,6 +443,7 @@ class RegressionMatcher(nn.Module):
|
|
|
self.upsample_res = (14*16*6, 14*16*6)
|
|
|
self.symmetric = symmetric
|
|
|
self.sample_thresh = 0.05
|
|
|
+ self.recrop_upsample = recrop_upsample
|
|
|
|
|
|
def get_output_resolution(self):
|
|
|
if not self.upsample_preds:
|
|
|
@@ -536,8 +539,27 @@ class RegressionMatcher(nn.Module):
|
|
|
return x_A[inds_A], x_B[inds_B]
|
|
|
else:
|
|
|
return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
|
|
|
+
|
|
|
+ def get_roi(self, certainty, W, H, thr = 0.025):
|
|
|
+ raise NotImplementedError("WIP, disable for now")
|
|
|
+ hs,ws = certainty.shape
|
|
|
+ certainty = certainty/certainty.sum(dim=(-1,-2))
|
|
|
+ cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
|
|
|
+ cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
|
|
|
+ print(cum_certainty_w)
|
|
|
+ print(torch.min(torch.nonzero(cum_certainty_w > thr)))
|
|
|
+ print(torch.min(torch.nonzero(cum_certainty_w < thr)))
|
|
|
+ left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
|
|
|
+ right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
|
|
|
+ top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
|
|
|
+ bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
|
|
|
+ print(left, right, top, bottom)
|
|
|
+ return left, top, right, bottom
|
|
|
|
|
|
-
|
|
|
+ def recrop(self, certainty, image_path):
|
|
|
+ roi = self.get_roi(certainty, *Image.open(image_path).size)
|
|
|
+ return Image.open(image_path).crop(roi)
|
|
|
+
|
|
|
@torch.inference_mode()
|
|
|
def match(
|
|
|
self,
|
|
|
@@ -549,7 +571,6 @@ class RegressionMatcher(nn.Module):
|
|
|
):
|
|
|
if device is None:
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
- from PIL import Image
|
|
|
if isinstance(im_A_path, (str, os.PathLike)):
|
|
|
im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
|
|
|
else:
|
|
|
@@ -603,7 +624,14 @@ class RegressionMatcher(nn.Module):
|
|
|
test_transform = get_tuple_transform_ops(
|
|
|
resize=(hs, ws), normalize=True
|
|
|
)
|
|
|
- im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
|
|
|
+ if self.recrop_upsample:
|
|
|
+ certainty = corresps[finest_scale]["certainty"]
|
|
|
+ print(certainty.shape)
|
|
|
+ im_A = self.recrop(certainty[0,0], im_A_path)
|
|
|
+ im_B = self.recrop(certainty[1,0], im_B_path)
|
|
|
+ #TODO: need to adjust corresps when doing this
|
|
|
+ else:
|
|
|
+ im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
|
|
|
im_A, im_B = test_transform((im_A, im_B))
|
|
|
im_A, im_B = im_A[None].to(device), im_B[None].to(device)
|
|
|
scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
|