|
|
@@ -0,0 +1,649 @@
|
|
|
+import os
|
|
|
+import math
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+import torch.nn.functional as F
|
|
|
+from einops import rearrange
|
|
|
+import warnings
|
|
|
+from warnings import warn
|
|
|
+
|
|
|
+import roma
|
|
|
+from roma.utils import get_tuple_transform_ops
|
|
|
+from roma.utils.local_correlation import local_correlation
|
|
|
+from roma.utils.utils import cls_to_flow_refine
|
|
|
+from roma.utils.kde import kde
|
|
|
+
|
|
|
+class ConvRefiner(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ in_dim=6,
|
|
|
+ hidden_dim=16,
|
|
|
+ out_dim=2,
|
|
|
+ dw=False,
|
|
|
+ kernel_size=5,
|
|
|
+ hidden_blocks=3,
|
|
|
+ displacement_emb = None,
|
|
|
+ displacement_emb_dim = None,
|
|
|
+ local_corr_radius = None,
|
|
|
+ corr_in_other = None,
|
|
|
+ no_im_B_fm = False,
|
|
|
+ amp = False,
|
|
|
+ concat_logits = False,
|
|
|
+ use_bias_block_1 = True,
|
|
|
+ use_cosine_corr = False,
|
|
|
+ disable_local_corr_grad = False,
|
|
|
+ is_classifier = False,
|
|
|
+ sample_mode = "bilinear",
|
|
|
+ norm_type = nn.BatchNorm2d,
|
|
|
+ bn_momentum = 0.1,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ self.bn_momentum = bn_momentum
|
|
|
+ self.block1 = self.create_block(
|
|
|
+ in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
|
|
|
+ )
|
|
|
+ self.hidden_blocks = nn.Sequential(
|
|
|
+ *[
|
|
|
+ self.create_block(
|
|
|
+ hidden_dim,
|
|
|
+ hidden_dim,
|
|
|
+ dw=dw,
|
|
|
+ kernel_size=kernel_size,
|
|
|
+ norm_type=norm_type,
|
|
|
+ )
|
|
|
+ for hb in range(hidden_blocks)
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ self.hidden_blocks = self.hidden_blocks
|
|
|
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
|
|
|
+ if displacement_emb:
|
|
|
+ self.has_displacement_emb = True
|
|
|
+ self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
|
|
|
+ else:
|
|
|
+ self.has_displacement_emb = False
|
|
|
+ self.local_corr_radius = local_corr_radius
|
|
|
+ self.corr_in_other = corr_in_other
|
|
|
+ self.no_im_B_fm = no_im_B_fm
|
|
|
+ self.amp = amp
|
|
|
+ self.concat_logits = concat_logits
|
|
|
+ self.use_cosine_corr = use_cosine_corr
|
|
|
+ self.disable_local_corr_grad = disable_local_corr_grad
|
|
|
+ self.is_classifier = is_classifier
|
|
|
+ self.sample_mode = sample_mode
|
|
|
+ self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
|
+
|
|
|
+ def create_block(
|
|
|
+ self,
|
|
|
+ in_dim,
|
|
|
+ out_dim,
|
|
|
+ dw=False,
|
|
|
+ kernel_size=5,
|
|
|
+ bias = True,
|
|
|
+ norm_type = nn.BatchNorm2d,
|
|
|
+ ):
|
|
|
+ num_groups = 1 if not dw else in_dim
|
|
|
+ if dw:
|
|
|
+ assert (
|
|
|
+ out_dim % in_dim == 0
|
|
|
+ ), "outdim must be divisible by indim for depthwise"
|
|
|
+ conv1 = nn.Conv2d(
|
|
|
+ in_dim,
|
|
|
+ out_dim,
|
|
|
+ kernel_size=kernel_size,
|
|
|
+ stride=1,
|
|
|
+ padding=kernel_size // 2,
|
|
|
+ groups=num_groups,
|
|
|
+ bias=bias,
|
|
|
+ )
|
|
|
+ norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
|
|
|
+ relu = nn.ReLU(inplace=True)
|
|
|
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
|
|
|
+ return nn.Sequential(conv1, norm, relu, conv2)
|
|
|
+
|
|
|
+ def forward(self, x, y, flow, scale_factor = 1, logits = None):
|
|
|
+ b,c,hs,ws = x.shape
|
|
|
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
|
|
+ with torch.no_grad():
|
|
|
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
|
|
|
+ if self.has_displacement_emb:
|
|
|
+ im_A_coords = torch.meshgrid(
|
|
|
+ (
|
|
|
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
|
|
|
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
+ im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
|
|
|
+ in_displacement = flow-im_A_coords
|
|
|
+ emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
|
|
|
+ if self.local_corr_radius:
|
|
|
+ if self.corr_in_other:
|
|
|
+ # Corr in other means take a kxk grid around the predicted coordinate in other image
|
|
|
+ local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
|
|
|
+ sample_mode = self.sample_mode)
|
|
|
+ else:
|
|
|
+ raise NotImplementedError("Local corr in own frame should not be used.")
|
|
|
+ if self.no_im_B_fm:
|
|
|
+ x_hat = torch.zeros_like(x)
|
|
|
+ d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
|
|
|
+ else:
|
|
|
+ d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
|
|
|
+ else:
|
|
|
+ if self.no_im_B_fm:
|
|
|
+ x_hat = torch.zeros_like(x)
|
|
|
+ d = torch.cat((x, x_hat), dim=1)
|
|
|
+ if self.concat_logits:
|
|
|
+ d = torch.cat((d, logits), dim=1)
|
|
|
+ d = self.block1(d)
|
|
|
+ d = self.hidden_blocks(d)
|
|
|
+ d = self.out_conv(d.float())
|
|
|
+ displacement, certainty = d[:, :-1], d[:, -1:]
|
|
|
+ return displacement, certainty
|
|
|
+
|
|
|
+class CosKernel(nn.Module): # similar to softmax kernel
|
|
|
+ def __init__(self, T, learn_temperature=False):
|
|
|
+ super().__init__()
|
|
|
+ self.learn_temperature = learn_temperature
|
|
|
+ if self.learn_temperature:
|
|
|
+ self.T = nn.Parameter(torch.tensor(T))
|
|
|
+ else:
|
|
|
+ self.T = T
|
|
|
+
|
|
|
+ def __call__(self, x, y, eps=1e-6):
|
|
|
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
|
|
|
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
|
|
|
+ )
|
|
|
+ if self.learn_temperature:
|
|
|
+ T = self.T.abs() + 0.01
|
|
|
+ else:
|
|
|
+ T = torch.tensor(self.T, device=c.device)
|
|
|
+ K = ((c - 1.0) / T).exp()
|
|
|
+ return K
|
|
|
+
|
|
|
+class GP(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ kernel,
|
|
|
+ T=1,
|
|
|
+ learn_temperature=False,
|
|
|
+ only_attention=False,
|
|
|
+ gp_dim=64,
|
|
|
+ basis="fourier",
|
|
|
+ covar_size=5,
|
|
|
+ only_nearest_neighbour=False,
|
|
|
+ sigma_noise=0.1,
|
|
|
+ no_cov=False,
|
|
|
+ predict_features = False,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ self.K = kernel(T=T, learn_temperature=learn_temperature)
|
|
|
+ self.sigma_noise = sigma_noise
|
|
|
+ self.covar_size = covar_size
|
|
|
+ self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
|
|
|
+ self.only_attention = only_attention
|
|
|
+ self.only_nearest_neighbour = only_nearest_neighbour
|
|
|
+ self.basis = basis
|
|
|
+ self.no_cov = no_cov
|
|
|
+ self.dim = gp_dim
|
|
|
+ self.predict_features = predict_features
|
|
|
+
|
|
|
+ def get_local_cov(self, cov):
|
|
|
+ K = self.covar_size
|
|
|
+ b, h, w, h, w = cov.shape
|
|
|
+ hw = h * w
|
|
|
+ cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
|
|
|
+ delta = torch.stack(
|
|
|
+ torch.meshgrid(
|
|
|
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
|
|
|
+ ),
|
|
|
+ dim=-1,
|
|
|
+ )
|
|
|
+ positions = torch.stack(
|
|
|
+ torch.meshgrid(
|
|
|
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
|
|
|
+ ),
|
|
|
+ dim=-1,
|
|
|
+ )
|
|
|
+ neighbours = positions[:, :, None, None, :] + delta[None, :, :]
|
|
|
+ points = torch.arange(hw)[:, None].expand(hw, K**2)
|
|
|
+ local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
|
|
|
+ :,
|
|
|
+ points.flatten(),
|
|
|
+ neighbours[..., 0].flatten(),
|
|
|
+ neighbours[..., 1].flatten(),
|
|
|
+ ].reshape(b, h, w, K**2)
|
|
|
+ return local_cov
|
|
|
+
|
|
|
+ def reshape(self, x):
|
|
|
+ return rearrange(x, "b d h w -> b (h w) d")
|
|
|
+
|
|
|
+ def project_to_basis(self, x):
|
|
|
+ if self.basis == "fourier":
|
|
|
+ return torch.cos(8 * math.pi * self.pos_conv(x))
|
|
|
+ elif self.basis == "linear":
|
|
|
+ return self.pos_conv(x)
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ "No other bases other than fourier and linear currently im_Bed in public release"
|
|
|
+ )
|
|
|
+
|
|
|
+ def get_pos_enc(self, y):
|
|
|
+ b, c, h, w = y.shape
|
|
|
+ coarse_coords = torch.meshgrid(
|
|
|
+ (
|
|
|
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
|
|
|
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
|
|
+ None
|
|
|
+ ].expand(b, h, w, 2)
|
|
|
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
|
|
+ coarse_embedded_coords = self.project_to_basis(coarse_coords)
|
|
|
+ return coarse_embedded_coords
|
|
|
+
|
|
|
+ def forward(self, x, y, **kwargs):
|
|
|
+ b, c, h1, w1 = x.shape
|
|
|
+ b, c, h2, w2 = y.shape
|
|
|
+ f = self.get_pos_enc(y)
|
|
|
+ b, d, h2, w2 = f.shape
|
|
|
+ x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
|
|
|
+ K_xx = self.K(x, x)
|
|
|
+ K_yy = self.K(y, y)
|
|
|
+ K_xy = self.K(x, y)
|
|
|
+ K_yx = K_xy.permute(0, 2, 1)
|
|
|
+ sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
|
|
|
+ with warnings.catch_warnings():
|
|
|
+ K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
|
|
|
+
|
|
|
+ mu_x = K_xy.matmul(K_yy_inv.matmul(f))
|
|
|
+ mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
|
|
|
+ if not self.no_cov:
|
|
|
+ cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
|
|
|
+ cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
|
|
|
+ local_cov_x = self.get_local_cov(cov_x)
|
|
|
+ local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
|
|
|
+ gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
|
|
|
+ else:
|
|
|
+ gp_feats = mu_x
|
|
|
+ return gp_feats
|
|
|
+
|
|
|
+class Decoder(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
|
|
|
+ num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
|
|
|
+ flow_upsample_mode = "bilinear"
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ self.embedding_decoder = embedding_decoder
|
|
|
+ self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
|
|
|
+ self.gps = gps
|
|
|
+ self.proj = proj
|
|
|
+ self.conv_refiner = conv_refiner
|
|
|
+ self.detach = detach
|
|
|
+ if pos_embeddings is None:
|
|
|
+ self.pos_embeddings = {}
|
|
|
+ else:
|
|
|
+ self.pos_embeddings = pos_embeddings
|
|
|
+ if scales == "all":
|
|
|
+ self.scales = ["32", "16", "8", "4", "2", "1"]
|
|
|
+ else:
|
|
|
+ self.scales = scales
|
|
|
+ self.warp_noise_std = warp_noise_std
|
|
|
+ self.refine_init = 4
|
|
|
+ self.displacement_dropout_p = displacement_dropout_p
|
|
|
+ self.gm_warp_dropout_p = gm_warp_dropout_p
|
|
|
+ self.flow_upsample_mode = flow_upsample_mode
|
|
|
+ self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
|
+
|
|
|
+ def get_placeholder_flow(self, b, h, w, device):
|
|
|
+ coarse_coords = torch.meshgrid(
|
|
|
+ (
|
|
|
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
|
|
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
|
|
+ None
|
|
|
+ ].expand(b, h, w, 2)
|
|
|
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
|
|
+ return coarse_coords
|
|
|
+
|
|
|
+ def get_positional_embedding(self, b, h ,w, device):
|
|
|
+ coarse_coords = torch.meshgrid(
|
|
|
+ (
|
|
|
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
|
|
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
|
|
+ None
|
|
|
+ ].expand(b, h, w, 2)
|
|
|
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
|
|
+ coarse_embedded_coords = self.pos_embedding(coarse_coords)
|
|
|
+ return coarse_embedded_coords
|
|
|
+
|
|
|
+ def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
|
|
|
+ coarse_scales = self.embedding_decoder.scales()
|
|
|
+ all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
|
|
|
+ sizes = {scale: f1[scale].shape[-2:] for scale in f1}
|
|
|
+ h, w = sizes[1]
|
|
|
+ b = f1[1].shape[0]
|
|
|
+ device = f1[1].device
|
|
|
+ coarsest_scale = int(all_scales[0])
|
|
|
+ old_stuff = torch.zeros(
|
|
|
+ b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
|
|
|
+ )
|
|
|
+ corresps = {}
|
|
|
+ if not upsample:
|
|
|
+ flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
|
|
|
+ certainty = 0.0
|
|
|
+ else:
|
|
|
+ flow = F.interpolate(
|
|
|
+ flow,
|
|
|
+ size=sizes[coarsest_scale],
|
|
|
+ align_corners=False,
|
|
|
+ mode="bilinear",
|
|
|
+ )
|
|
|
+ certainty = F.interpolate(
|
|
|
+ certainty,
|
|
|
+ size=sizes[coarsest_scale],
|
|
|
+ align_corners=False,
|
|
|
+ mode="bilinear",
|
|
|
+ )
|
|
|
+ displacement = 0.0
|
|
|
+ for new_scale in all_scales:
|
|
|
+ ins = int(new_scale)
|
|
|
+ corresps[ins] = {}
|
|
|
+ f1_s, f2_s = f1[ins], f2[ins]
|
|
|
+ if new_scale in self.proj:
|
|
|
+ with torch.autocast("cuda", self.amp_dtype):
|
|
|
+ f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
|
|
|
+
|
|
|
+ if ins in coarse_scales:
|
|
|
+ old_stuff = F.interpolate(
|
|
|
+ old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
|
|
|
+ )
|
|
|
+ gp_posterior = self.gps[new_scale](f1_s, f2_s)
|
|
|
+ gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
|
|
|
+ gp_posterior, f1_s, old_stuff, new_scale
|
|
|
+ )
|
|
|
+
|
|
|
+ if self.embedding_decoder.is_classifier:
|
|
|
+ flow = cls_to_flow_refine(
|
|
|
+ gm_warp_or_cls,
|
|
|
+ ).permute(0,3,1,2)
|
|
|
+ corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
|
|
|
+ else:
|
|
|
+ corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
|
|
|
+ flow = gm_warp_or_cls.detach()
|
|
|
+
|
|
|
+ if new_scale in self.conv_refiner:
|
|
|
+ corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
|
|
|
+ delta_flow, delta_certainty = self.conv_refiner[new_scale](
|
|
|
+ f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
|
|
|
+ )
|
|
|
+ corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
|
|
|
+ displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
|
|
|
+ delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
|
|
|
+ flow = flow + displacement
|
|
|
+ certainty = (
|
|
|
+ certainty + delta_certainty
|
|
|
+ ) # predict both certainty and displacement
|
|
|
+ corresps[ins].update({
|
|
|
+ "certainty": certainty,
|
|
|
+ "flow": flow,
|
|
|
+ })
|
|
|
+ if new_scale != "1":
|
|
|
+ flow = F.interpolate(
|
|
|
+ flow,
|
|
|
+ size=sizes[ins // 2],
|
|
|
+ mode=self.flow_upsample_mode,
|
|
|
+ )
|
|
|
+ certainty = F.interpolate(
|
|
|
+ certainty,
|
|
|
+ size=sizes[ins // 2],
|
|
|
+ mode=self.flow_upsample_mode,
|
|
|
+ )
|
|
|
+ if self.detach:
|
|
|
+ flow = flow.detach()
|
|
|
+ certainty = certainty.detach()
|
|
|
+ #torch.cuda.empty_cache()
|
|
|
+ return corresps
|
|
|
+
|
|
|
+
|
|
|
+class RegressionMatcher(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ encoder,
|
|
|
+ decoder,
|
|
|
+ h=448,
|
|
|
+ w=448,
|
|
|
+ sample_mode = "threshold",
|
|
|
+ upsample_preds = False,
|
|
|
+ symmetric = False,
|
|
|
+ name = None,
|
|
|
+ attenuate_cert = None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ self.attenuate_cert = attenuate_cert
|
|
|
+ self.encoder = encoder
|
|
|
+ self.decoder = decoder
|
|
|
+ self.name = name
|
|
|
+ self.w_resized = w
|
|
|
+ self.h_resized = h
|
|
|
+ self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
|
|
|
+ self.sample_mode = sample_mode
|
|
|
+ self.upsample_preds = upsample_preds
|
|
|
+ self.upsample_res = (14*16*6, 14*16*6)
|
|
|
+ self.symmetric = symmetric
|
|
|
+ self.sample_thresh = 0.05
|
|
|
+
|
|
|
+ def get_output_resolution(self):
|
|
|
+ if not self.upsample_preds:
|
|
|
+ return self.h_resized, self.w_resized
|
|
|
+ else:
|
|
|
+ return self.upsample_res
|
|
|
+
|
|
|
+ def extract_backbone_features(self, batch, batched = True, upsample = False):
|
|
|
+ x_q = batch["im_A"]
|
|
|
+ x_s = batch["im_B"]
|
|
|
+ if batched:
|
|
|
+ X = torch.cat((x_q, x_s), dim = 0)
|
|
|
+ feature_pyramid = self.encoder(X, upsample = upsample)
|
|
|
+ else:
|
|
|
+ feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
|
|
|
+ return feature_pyramid
|
|
|
+
|
|
|
+ def sample(
|
|
|
+ self,
|
|
|
+ matches,
|
|
|
+ certainty,
|
|
|
+ num=10000,
|
|
|
+ ):
|
|
|
+ if "threshold" in self.sample_mode:
|
|
|
+ upper_thresh = self.sample_thresh
|
|
|
+ certainty = certainty.clone()
|
|
|
+ certainty[certainty > upper_thresh] = 1
|
|
|
+ matches, certainty = (
|
|
|
+ matches.reshape(-1, 4),
|
|
|
+ certainty.reshape(-1),
|
|
|
+ )
|
|
|
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
|
|
|
+ good_samples = torch.multinomial(certainty,
|
|
|
+ num_samples = min(expansion_factor*num, len(certainty)),
|
|
|
+ replacement=False)
|
|
|
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
|
|
|
+ if "balanced" not in self.sample_mode:
|
|
|
+ return good_matches, good_certainty
|
|
|
+ density = kde(good_matches, std=0.1)
|
|
|
+ p = 1 / (density+1)
|
|
|
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
|
|
|
+ balanced_samples = torch.multinomial(p,
|
|
|
+ num_samples = min(num,len(good_certainty)),
|
|
|
+ replacement=False)
|
|
|
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
|
|
|
+
|
|
|
+ def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
|
|
|
+ feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
|
|
|
+ if batched:
|
|
|
+ f_q_pyramid = {
|
|
|
+ scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
|
|
|
+ }
|
|
|
+ f_s_pyramid = {
|
|
|
+ scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ f_q_pyramid, f_s_pyramid = feature_pyramid
|
|
|
+ corresps = self.decoder(f_q_pyramid,
|
|
|
+ f_s_pyramid,
|
|
|
+ upsample = upsample,
|
|
|
+ **(batch["corresps"] if "corresps" in batch else {}),
|
|
|
+ scale_factor=scale_factor)
|
|
|
+
|
|
|
+ return corresps
|
|
|
+
|
|
|
+ def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
|
|
|
+ feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
|
|
|
+ f_q_pyramid = feature_pyramid
|
|
|
+ f_s_pyramid = {
|
|
|
+ scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
|
|
|
+ for scale, f_scale in feature_pyramid.items()
|
|
|
+ }
|
|
|
+ corresps = self.decoder(f_q_pyramid,
|
|
|
+ f_s_pyramid,
|
|
|
+ upsample = upsample,
|
|
|
+ **(batch["corresps"] if "corresps" in batch else {}),
|
|
|
+ scale_factor=scale_factor)
|
|
|
+ return corresps
|
|
|
+
|
|
|
+ def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
|
|
|
+ kpts_A, kpts_B = matches[...,:2], matches[...,2:]
|
|
|
+ kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
|
|
|
+ kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
|
|
|
+ return kpts_A, kpts_B
|
|
|
+
|
|
|
+ def match(
|
|
|
+ self,
|
|
|
+ im_A_path,
|
|
|
+ im_B_path,
|
|
|
+ *args,
|
|
|
+ batched=False,
|
|
|
+ device = None,
|
|
|
+ ):
|
|
|
+ if device is None:
|
|
|
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
+ from PIL import Image
|
|
|
+ if isinstance(im_A_path, (str, os.PathLike)):
|
|
|
+ im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
|
|
|
+ else:
|
|
|
+ # Assume its not a path
|
|
|
+ im_A, im_B = im_A_path, im_B_path
|
|
|
+ symmetric = self.symmetric
|
|
|
+ self.train(False)
|
|
|
+ with torch.no_grad():
|
|
|
+ if not batched:
|
|
|
+ b = 1
|
|
|
+ w, h = im_A.size
|
|
|
+ w2, h2 = im_B.size
|
|
|
+ # Get images in good format
|
|
|
+ ws = self.w_resized
|
|
|
+ hs = self.h_resized
|
|
|
+
|
|
|
+ test_transform = get_tuple_transform_ops(
|
|
|
+ resize=(hs, ws), normalize=True, clahe = False
|
|
|
+ )
|
|
|
+ im_A, im_B = test_transform((im_A, im_B))
|
|
|
+ batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
|
|
|
+ else:
|
|
|
+ b, c, h, w = im_A.shape
|
|
|
+ b, c, h2, w2 = im_B.shape
|
|
|
+ assert w == w2 and h == h2, "For batched images we assume same size"
|
|
|
+ batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
|
|
|
+ if h != self.h_resized or self.w_resized != w:
|
|
|
+ warn("Model resolution and batch resolution differ, may produce unexpected results")
|
|
|
+ hs, ws = h, w
|
|
|
+ finest_scale = 1
|
|
|
+ # Run matcher
|
|
|
+ if symmetric:
|
|
|
+ corresps = self.forward_symmetric(batch)
|
|
|
+ else:
|
|
|
+ corresps = self.forward(batch, batched = True)
|
|
|
+
|
|
|
+ if self.upsample_preds:
|
|
|
+ hs, ws = self.upsample_res
|
|
|
+
|
|
|
+ if self.attenuate_cert:
|
|
|
+ low_res_certainty = F.interpolate(
|
|
|
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
+ )
|
|
|
+ cert_clamp = 0
|
|
|
+ factor = 0.5
|
|
|
+ low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
|
|
|
+
|
|
|
+ if self.upsample_preds:
|
|
|
+ finest_corresps = corresps[finest_scale]
|
|
|
+ torch.cuda.empty_cache()
|
|
|
+ test_transform = get_tuple_transform_ops(
|
|
|
+ resize=(hs, ws), normalize=True
|
|
|
+ )
|
|
|
+ im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
|
|
|
+ im_A, im_B = test_transform((im_A, im_B))
|
|
|
+ im_A, im_B = im_A[None].to(device), im_B[None].to(device)
|
|
|
+ scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
|
|
|
+ batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
|
|
|
+ if symmetric:
|
|
|
+ corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
|
|
|
+ else:
|
|
|
+ corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
|
|
|
+
|
|
|
+ im_A_to_im_B = corresps[finest_scale]["flow"]
|
|
|
+ certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
|
|
|
+ if finest_scale != 1:
|
|
|
+ im_A_to_im_B = F.interpolate(
|
|
|
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
+ )
|
|
|
+ certainty = F.interpolate(
|
|
|
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
|
|
|
+ )
|
|
|
+ im_A_to_im_B = im_A_to_im_B.permute(
|
|
|
+ 0, 2, 3, 1
|
|
|
+ )
|
|
|
+ # Create im_A meshgrid
|
|
|
+ im_A_coords = torch.meshgrid(
|
|
|
+ (
|
|
|
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
|
|
|
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
+ im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
|
|
|
+ certainty = certainty.sigmoid() # logits -> probs
|
|
|
+ im_A_coords = im_A_coords.permute(0, 2, 3, 1)
|
|
|
+ if (im_A_to_im_B.abs() > 1).any() and True:
|
|
|
+ wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
|
|
|
+ certainty[wrong[:,None]] = 0
|
|
|
+ im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
|
|
|
+ if symmetric:
|
|
|
+ A_to_B, B_to_A = im_A_to_im_B.chunk(2)
|
|
|
+ q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
|
|
|
+ im_B_coords = im_A_coords
|
|
|
+ s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
|
|
|
+ warp = torch.cat((q_warp, s_warp),dim=2)
|
|
|
+ certainty = torch.cat(certainty.chunk(2), dim=3)
|
|
|
+ else:
|
|
|
+ warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
|
|
|
+ if batched:
|
|
|
+ return (
|
|
|
+ warp,
|
|
|
+ certainty[:, 0]
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return (
|
|
|
+ warp[0],
|
|
|
+ certainty[0, 0],
|
|
|
+ )
|
|
|
+
|