|
|
@@ -9,12 +9,10 @@ import warnings
|
|
|
from warnings import warn
|
|
|
from PIL import Image
|
|
|
|
|
|
-import romatch
|
|
|
from romatch.utils import get_tuple_transform_ops
|
|
|
from romatch.utils.local_correlation import local_correlation
|
|
|
from romatch.utils.utils import cls_to_flow_refine, get_autocast_params
|
|
|
from romatch.utils.kde import kde
|
|
|
-from typing import Union
|
|
|
|
|
|
class ConvRefiner(nn.Module):
|
|
|
def __init__(
|
|
|
@@ -431,10 +429,8 @@ class RegressionMatcher(nn.Module):
|
|
|
symmetric = False,
|
|
|
name = None,
|
|
|
attenuate_cert = None,
|
|
|
- recrop_upsample = False,
|
|
|
):
|
|
|
super().__init__()
|
|
|
- raise NotImplementedError("Possibly bugged, please use commit ca2615f for now. Should be fixed soon.")
|
|
|
self.attenuate_cert = attenuate_cert
|
|
|
self.encoder = encoder
|
|
|
self.decoder = decoder
|
|
|
@@ -447,7 +443,6 @@ class RegressionMatcher(nn.Module):
|
|
|
self.upsample_res = (14*16*6, 14*16*6)
|
|
|
self.symmetric = symmetric
|
|
|
self.sample_thresh = 0.05
|
|
|
- self.recrop_upsample = recrop_upsample
|
|
|
|
|
|
def get_output_resolution(self):
|
|
|
if not self.upsample_preds:
|
|
|
@@ -589,32 +584,12 @@ class RegressionMatcher(nn.Module):
|
|
|
return torch.cat((inds_A, inds_B),dim=-1)
|
|
|
else:
|
|
|
return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
|
|
|
-
|
|
|
- def get_roi(self, certainty, W, H, thr = 0.025):
|
|
|
- raise NotImplementedError("WIP, disable for now")
|
|
|
- hs,ws = certainty.shape
|
|
|
- certainty = certainty/certainty.sum(dim=(-1,-2))
|
|
|
- cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
|
|
|
- cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
|
|
|
- print(cum_certainty_w)
|
|
|
- print(torch.min(torch.nonzero(cum_certainty_w > thr)))
|
|
|
- print(torch.min(torch.nonzero(cum_certainty_w < thr)))
|
|
|
- left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
|
|
|
- right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
|
|
|
- top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
|
|
|
- bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
|
|
|
- print(left, right, top, bottom)
|
|
|
- return left, top, right, bottom
|
|
|
-
|
|
|
- def recrop(self, certainty, image_path):
|
|
|
- roi = self.get_roi(certainty, *Image.open(image_path).size)
|
|
|
- return Image.open(image_path).convert("RGB").crop(roi)
|
|
|
-
|
|
|
+
|
|
|
@torch.inference_mode()
|
|
|
def match(
|
|
|
self,
|
|
|
- im_A_path: Union[str, os.PathLike, Image.Image],
|
|
|
- im_B_path: Union[str, os.PathLike, Image.Image],
|
|
|
+ im_A_path,
|
|
|
+ im_B_path,
|
|
|
*args,
|
|
|
batched=False,
|
|
|
device = None,
|
|
|
@@ -674,14 +649,7 @@ class RegressionMatcher(nn.Module):
|
|
|
test_transform = get_tuple_transform_ops(
|
|
|
resize=(hs, ws), normalize=True
|
|
|
)
|
|
|
- if self.recrop_upsample:
|
|
|
- raise NotImplementedError("recrop_upsample not implemented")
|
|
|
- certainty = corresps[finest_scale]["certainty"]
|
|
|
- print(certainty.shape)
|
|
|
- im_A = self.recrop(certainty[0,0], im_A_path)
|
|
|
- im_B = self.recrop(certainty[1,0], im_B_path)
|
|
|
- #TODO: need to adjust corresps when doing this
|
|
|
- im_A, im_B = test_transform((im_A, im_B))
|
|
|
+ im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB')))
|
|
|
im_A, im_B = im_A[None].to(device), im_B[None].to(device)
|
|
|
scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
|
|
|
batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
|