| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001 |
- import os
- import math
- import sys
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- from warnings import warn
- from PIL import Image
- from romatch.utils import get_tuple_transform_ops
- from romatch.utils.local_correlation import local_correlation
- from romatch.utils.utils import (
- check_rgb,
- cls_to_flow_refine,
- get_autocast_params,
- check_not_i16,
- )
- from romatch.utils.kde import kde
- from romatch.models.encoders import CNNandDinov2
class ConvRefiner(nn.Module):
    """Convolutional head that refines a warp at a single scale.

    ``forward`` samples ``y`` at the current ``warp``, concatenates the result
    with ``x`` (and, depending on configuration, an embedding of the current
    displacement, a local correlation volume and incoming logits), runs the
    conv blocks and splits the output into a displacement (all channels but
    the last) and a certainty (the last channel).

    NOTE(review): the source this class was recovered from had lost its
    indentation; block nesting was reconstructed from syntax. In particular
    the extent of the autocast region in ``forward`` should be confirmed.
    """

    def __init__(
        self,
        in_dim=6,
        hidden_dim=16,
        out_dim=2,
        dw=False,
        kernel_size=5,
        hidden_blocks=3,
        displacement_emb=None,
        displacement_emb_dim=None,
        local_corr_radius=None,
        corr_in_other=None,
        no_im_B_fm=False,
        amp=False,
        concat_logits=False,
        use_bias_block_1=True,
        use_cosine_corr=False,
        disable_local_corr_grad=False,
        is_classifier=False,
        sample_mode="bilinear",
        norm_type=nn.BatchNorm2d,
        bn_momentum=0.1,
        amp_dtype=torch.float16,
        use_custom_corr=False,
    ):
        """Build the refiner.

        Args:
            in_dim: Input channels of the first conv block.
            hidden_dim: Channels of every hidden block.
            out_dim: Output channels; the last one is returned as certainty.
            dw: Use depthwise convolutions (``groups=in_dim``).
            kernel_size: Spatial kernel size of the first conv in each block.
            hidden_blocks: Number of hidden blocks after ``block1``.
            displacement_emb: If truthy, embed the current displacement with a
                1x1 conv of ``displacement_emb_dim`` output channels.
            local_corr_radius: If truthy, add a local correlation volume
                (only used together with the displacement embedding).
            corr_in_other: Must be truthy when local correlation is used.
            no_im_B_fm: Replace the sampled ``y`` features with zeros.
            amp / amp_dtype: Passed to ``get_autocast_params``.
            concat_logits: Concatenate ``logits`` to the block input.
            sample_mode: Interpolation mode for ``F.grid_sample``.
            norm_type: Normalisation layer class for the hidden blocks.
            bn_momentum: Momentum used when the norm is ``nn.BatchNorm2d``.
            use_custom_corr: Forwarded to ``local_correlation``; forced to
                ``False`` on non-Linux platforms.
        """
        super().__init__()
        if sys.platform != "linux":
            # Only warn when the caller actually asked for the custom kernel;
            # the flag is forced off on these platforms either way.
            if use_custom_corr:
                warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
            use_custom_corr = False
        self.bn_momentum = bn_momentum
        # NOTE(review): block1 does not receive ``norm_type`` and therefore
        # always uses the create_block default (BatchNorm2d). Left as-is
        # because changing it would alter the module's parameters.
        self.block1 = self.create_block(
            in_dim,
            hidden_dim,
            dw=dw,
            kernel_size=kernel_size,
            bias=use_bias_block_1,
        )
        self.hidden_blocks = nn.Sequential(
            *[
                self.create_block(
                    hidden_dim,
                    hidden_dim,
                    dw=dw,
                    kernel_size=kernel_size,
                    norm_type=norm_type,
                )
                for _ in range(hidden_blocks)
            ]
        )
        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
        if displacement_emb:
            self.has_displacement_emb = True
            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
        else:
            self.has_displacement_emb = False
        self.local_corr_radius = local_corr_radius
        self.corr_in_other = corr_in_other
        self.no_im_B_fm = no_im_B_fm
        self.amp = amp
        self.concat_logits = concat_logits
        self.use_cosine_corr = use_cosine_corr
        self.disable_local_corr_grad = disable_local_corr_grad
        self.is_classifier = is_classifier
        self.sample_mode = sample_mode
        self.amp_dtype = amp_dtype
        self.use_custom_corr = use_custom_corr

    def create_block(
        self,
        in_dim,
        out_dim,
        dw=False,
        kernel_size=5,
        bias=True,
        norm_type=nn.BatchNorm2d,
    ):
        """Return ``Sequential(conv kxk, norm, ReLU, conv 1x1)``.

        With ``dw=True`` the first conv is grouped with ``groups=in_dim``,
        which requires ``out_dim`` to be a multiple of ``in_dim``.
        """
        num_groups = in_dim if dw else 1
        if dw:
            assert out_dim % in_dim == 0, (
                "outdim must be divisible by indim for depthwise"
            )
        conv1 = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=num_groups,
            bias=bias,
        )
        if norm_type is nn.BatchNorm2d:
            norm = norm_type(out_dim, momentum=self.bn_momentum)
        else:
            norm = norm_type(num_channels=out_dim)
        relu = nn.ReLU(inplace=True)
        conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
        return nn.Sequential(conv1, norm, relu, conv2)

    def forward(self, x, y, warp, scale_factor=1, logits=None):
        """Predict a displacement and certainty for the given warp.

        Args:
            x: Feature map unpacked as ``(b, c, hs, ws)``.
            y: Feature map sampled at ``warp`` via ``F.grid_sample``.
            warp: Current warp, channel-first; permuted to channel-last
                before sampling.
            scale_factor: Multiplier applied to the displacement before it
                is embedded.
            logits: Extra channels concatenated when ``concat_logits`` is set.

        Returns:
            ``(displacement, certainty)`` = ``(d[:, :-1], d[:, -1:])`` of the
            output convolution.

        Raises:
            NotImplementedError: Local correlation requested without
                ``corr_in_other``.
        """
        b, c, hs, ws = x.shape
        autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(
            x.device, enabled=self.amp, dtype=self.amp_dtype
        )
        with torch.autocast(
            autocast_device, enabled=autocast_enabled, dtype=autocast_dtype
        ):
            x_hat = F.grid_sample(
                y, warp.permute(0, 2, 3, 1), align_corners=False, mode=self.sample_mode
            )
            if self.has_displacement_emb:
                # Pixel-centre coordinates of x in [-1, 1]; channel 0 is x, 1 is y.
                im_A_coords = torch.meshgrid(
                    (
                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
                    ),
                    indexing="ij",
                )
                im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
                im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
                in_displacement = warp - im_A_coords
                emb_in_displacement = self.disp_emb(
                    40 / 32 * scale_factor * in_displacement
                )
                if self.local_corr_radius:
                    if not self.corr_in_other:
                        raise NotImplementedError(
                            "Local corr in own frame should not be used."
                        )
                    # Correlate against a grid around the predicted
                    # coordinate in the other image.
                    local_corr = local_correlation(
                        x,
                        y,
                        self.local_corr_radius,
                        warp,
                        sample_mode=self.sample_mode,
                        use_custom_corr=self.use_custom_corr,
                    )
                    if self.no_im_B_fm:
                        x_hat = torch.zeros_like(x)
                    d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
                else:
                    d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
            else:
                if self.no_im_B_fm:
                    x_hat = torch.zeros_like(x)
                d = torch.cat((x, x_hat), dim=1)
            if self.concat_logits:
                d = torch.cat((d, logits), dim=1)
            # Zero-pad the channel dimension up to what block1 expects.
            channel_d = d.shape[1]
            channel_block1 = self.block1[0].in_channels
            if channel_d != channel_block1:
                d = F.pad(d, (0, 0, 0, 0, 0, channel_block1 - channel_d))
            d = self.block1(d)
            d = self.hidden_blocks(d)
        # NOTE(review): placed outside the autocast region because of the
        # explicit .float() cast; confirm against the original indentation.
        d = self.out_conv(d.float())
        displacement, certainty = d[:, :-1], d[:, -1:]
        return displacement, certainty
class CosKernel(nn.Module):  # similar to softmax kernel
    """Exponentiated cosine-similarity kernel ``exp((cos(x, y) - 1) / T)``."""

    def __init__(self, T, learn_temperature=False):
        """Store the temperature ``T``.

        Args:
            T: Temperature. Wrapped in an ``nn.Parameter`` when
                ``learn_temperature`` is set, otherwise kept as given.
            learn_temperature: Whether ``T`` is a trainable parameter.
        """
        super().__init__()
        self.learn_temperature = learn_temperature
        if self.learn_temperature:
            # float(): an integer T (e.g. 1) would create an integer tensor,
            # which cannot require gradients and makes nn.Parameter raise.
            self.T = nn.Parameter(torch.tensor(float(T)))
        else:
            self.T = T

    def forward(self, x, y, eps=1e-6):
        """Return the kernel matrix between ``x`` and ``y``.

        Args:
            x: Tensor indexed as ``(b, n, d)`` by the einsum below.
            y: Tensor indexed as ``(b, m, d)``.
            eps: Added to the norm product to avoid division by zero.

        Returns:
            Tensor of shape ``(b, n, m)``.
        """
        c = torch.einsum("bnd,bmd->bnm", x, y) / (
            x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
        )
        if self.learn_temperature:
            # Keep the learned temperature strictly positive.
            T = self.T.abs() + 0.01
        else:
            T = torch.tensor(self.T, device=c.device)
        K = ((c - 1.0) / T).exp()
        return K
class GP(nn.Module):
    """Gaussian-process style regression of positional embeddings.

    ``forward`` embeds the pixel-centre coordinates of ``y``'s grid, then
    computes ``K_xy (K_yy + sigma_noise * I)^-1 f`` so that every location
    of ``x`` receives a predicted embedding.
    """

    def __init__(
        self,
        kernel,
        T=1,
        learn_temperature=False,
        only_attention=False,
        gp_dim=64,
        basis="fourier",
        covar_size=5,
        only_nearest_neighbour=False,
        sigma_noise=0.1,
        no_cov=False,
        predict_features=False,
    ):
        """Configure the module.

        Args:
            kernel: Callable invoked as ``kernel(T=..., learn_temperature=...)``
                that returns the kernel function stored as ``self.K``.
            T, learn_temperature: Forwarded to ``kernel``.
            gp_dim: Output channels of the 1x1 positional conv.
            basis: ``"fourier"`` or ``"linear"`` (see ``project_to_basis``).
            covar_size: Neighbourhood size used by ``get_local_cov``.
            sigma_noise: Scale of the identity added to ``K_yy``.
            only_attention, only_nearest_neighbour, no_cov, predict_features:
                Stored on the instance; not read anywhere in this class.
        """
        super().__init__()
        self.K = kernel(T=T, learn_temperature=learn_temperature)
        self.sigma_noise = sigma_noise
        self.covar_size = covar_size
        self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
        self.only_attention = only_attention
        self.only_nearest_neighbour = only_nearest_neighbour
        self.basis = basis
        self.no_cov = no_cov
        self.dim = gp_dim
        self.predict_features = predict_features

    def get_local_cov(self, cov):
        """Gather a ``K x K`` neighbourhood of ``cov`` around every position.

        Args:
            cov: Five-dimensional tensor; the last two dims are zero-padded by
                ``K // 2`` on each side, with ``K = self.covar_size``.

        Returns:
            Tensor of shape ``(b, h, w, K**2)``.
        """
        K = self.covar_size
        # The original unpacked h and w twice; the trailing pair wins.
        b, _, _, h, w = cov.shape
        hw = h * w
        cov = F.pad(cov, 4 * (K // 2,))  # pad v_q
        delta = torch.stack(
            torch.meshgrid(
                torch.arange(-(K // 2), K // 2 + 1),
                torch.arange(-(K // 2), K // 2 + 1),
                indexing="ij",
            ),
            dim=-1,
        )
        positions = torch.stack(
            torch.meshgrid(
                torch.arange(K // 2, h + K // 2),
                torch.arange(K // 2, w + K // 2),
                indexing="ij",
            ),
            dim=-1,
        )
        neighbours = positions[:, :, None, None, :] + delta[None, :, :]
        points = torch.arange(hw)[:, None].expand(hw, K**2)
        local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
            :,
            points.flatten(),
            neighbours[..., 0].flatten(),
            neighbours[..., 1].flatten(),
        ].reshape(b, h, w, K**2)
        return local_cov

    def reshape(self, x):
        """Flatten the spatial dims: ``(b, d, h, w) -> (b, h*w, d)``."""
        return rearrange(x, "b d h w -> b (h w) d")

    def project_to_basis(self, x):
        """Embed coordinates ``x`` with the configured basis.

        Raises:
            ValueError: ``self.basis`` is neither ``"fourier"`` nor ``"linear"``.
        """
        if self.basis == "fourier":
            return torch.cos(8 * math.pi * self.pos_conv(x))
        elif self.basis == "linear":
            return self.pos_conv(x)
        else:
            raise ValueError(
                "No other bases other than fourier and linear currently supported in public release"
            )

    def get_pos_enc(self, y):
        """Return the basis embedding of the pixel-centre grid of ``y``."""
        b, c, h, w = y.shape
        coarse_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
            ),
            indexing="ij",
        )
        # Channel order is (x, y).
        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
            None
        ].expand(b, h, w, 2)
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        coarse_embedded_coords = self.project_to_basis(coarse_coords)
        return coarse_embedded_coords

    def forward(self, x, y, **kwargs):
        """Regress positional embeddings of ``y``'s grid onto ``x``.

        Args:
            x: Feature map ``(b, c, h1, w1)``.
            y: Feature map ``(b, c, h2, w2)``.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor ``(b, d, h1, w1)`` where ``d`` is the embedding dimension.
            Only the posterior mean is returned; no covariance features are
            computed.
        """
        b, c, h1, w1 = x.shape
        b, c, h2, w2 = y.shape
        f = self.get_pos_enc(y)
        b, d, h2, w2 = f.shape
        x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
        K_yy = self.K(y, y)
        K_xy = self.K(x, y)
        sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
        if self.training:
            K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
            mu_x = K_xy.matmul(K_yy_inv.matmul(f))
        else:
            # faster inference, possibly also useful for training
            L_t = torch.linalg.cholesky(K_yy + sigma_noise)
            pos_emb = torch.cholesky_solve(f.reshape(b, h2 * w2, d), L_t, upper=False)
            mu_x = K_xy @ pos_emb
        mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
        gp_feats = mu_x
        return gp_feats
class Decoder(nn.Module):
    """Coarse-to-fine decoder producing flow and certainty per scale.

    At each scale ``forward`` optionally projects the features, runs the
    GP plus embedding decoder on coarse scales, applies the conv refiner
    for that scale, records the result and upsamples flow/certainty to
    the next finer scale.

    NOTE(review): the source this class was recovered from had lost its
    indentation; block nesting was reconstructed from syntax.
    """

    def __init__(
        self,
        embedding_decoder,
        gps,
        proj,
        conv_refiner,
        detach=False,
        scales="all",
        pos_embeddings=None,
        num_refinement_steps_per_scale=1,
        warp_noise_std=0.0,
        displacement_dropout_p=0.0,
        gm_warp_dropout_p=0.0,
        flow_upsample_mode="bilinear",
        amp_dtype=torch.float16,
    ):
        """Store the sub-modules and options.

        Args:
            embedding_decoder: Module called on coarse scales; must expose
                ``scales()``, ``hidden_dim`` and ``is_classifier``.
            gps: Mapping from scale string to GP module.
            proj: Mapping from scale string to feature projection.
            conv_refiner: Mapping from scale string to refiner.
            detach: Detach flow and certainty after each upsampling step.
            scales: ``"all"`` for ``["32", "16", "8", "4", "2", "1"]`` or an
                explicit list of scale strings.
            pos_embeddings: Optional mapping; defaults to an empty dict.
            flow_upsample_mode: Mode passed to ``F.interpolate`` between scales.
            amp_dtype: dtype passed to ``get_autocast_params``.
        """
        super().__init__()
        self.embedding_decoder = embedding_decoder
        self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
        self.gps = gps
        self.proj = proj
        self.conv_refiner = conv_refiner
        self.detach = detach
        self.pos_embeddings = {} if pos_embeddings is None else pos_embeddings
        if scales == "all":
            self.scales = ["32", "16", "8", "4", "2", "1"]
        else:
            self.scales = scales
        self.warp_noise_std = warp_noise_std
        self.refine_init = 4
        self.displacement_dropout_p = displacement_dropout_p
        self.gm_warp_dropout_p = gm_warp_dropout_p
        self.flow_upsample_mode = flow_upsample_mode
        self.amp_dtype = amp_dtype

    @staticmethod
    def _pixel_center_grid(b, h, w, device):
        """Return a ``(b, 2, h, w)`` grid of pixel centres in [-1, 1]; channel 0 is x."""
        ys, xs = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
            ),
            indexing="ij",
        )
        return torch.stack((xs, ys), dim=0)[None].expand(b, 2, h, w)

    def get_placeholder_flow(self, b, h, w, device):
        """Return the identity warp (every pixel maps to itself), ``(b, 2, h, w)``."""
        return self._pixel_center_grid(b, h, w, device)

    def get_positional_embedding(self, b, h, w, device):
        """Embed the pixel-centre grid with ``self.pos_embedding``.

        NOTE(review): ``__init__`` only defines ``self.pos_embeddings``
        (plural), so this call raises ``AttributeError`` unless something
        else sets ``pos_embedding``. Left unchanged because the intended
        lookup is not visible here.
        """
        coarse_coords = self._pixel_center_grid(b, h, w, device)
        coarse_embedded_coords = self.pos_embedding(coarse_coords)
        return coarse_embedded_coords

    def forward(
        self,
        f1,
        f2,
        gt_warp=None,
        gt_prob=None,
        upsample=False,
        flow=None,
        certainty=None,
        scale_factor=1,
    ):
        """Decode correspondences from two feature pyramids.

        Args:
            f1, f2: Mappings from integer scale to feature map.
            gt_warp, gt_prob: Accepted but not used in this method.
            upsample: When set, only scales 8/4/2/1 are processed and the
                supplied ``flow``/``certainty`` are resized to the coarsest
                of those as the starting point.
            flow, certainty: Initial estimates, used only when ``upsample``.
            scale_factor: Forwarded to each conv refiner.

        Returns:
            Dict keyed by integer scale. Every entry holds ``"certainty"``
            and ``"flow"``; in training mode intermediate tensors are added.
        """
        coarse_scales = self.embedding_decoder.scales()
        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
        sizes = {scale: f1[scale].shape[-2:] for scale in f1}
        h, w = sizes[1]
        b = f1[1].shape[0]
        device = f1[1].device
        coarsest_scale = int(all_scales[0])
        old_stuff = torch.zeros(
            b,
            self.embedding_decoder.hidden_dim,
            *sizes[coarsest_scale],
            device=f1[coarsest_scale].device,
        )
        corresps = {}
        if not upsample:
            flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
            certainty = 0.0
        else:
            flow = F.interpolate(
                flow,
                size=sizes[coarsest_scale],
                align_corners=False,
                mode="bilinear",
            )
            certainty = F.interpolate(
                certainty,
                size=sizes[coarsest_scale],
                align_corners=False,
                mode="bilinear",
            )
        for new_scale in all_scales:
            ins = int(new_scale)
            corresps[ins] = {}
            f1_s, f2_s = f1[ins], f2[ins]
            if new_scale in self.proj:
                # The original compared str(f1_s) (the printed tensor) with
                # "cuda", which is never true and formats the whole tensor.
                autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(
                    f1_s.device, f1_s.device.type == "cuda", self.amp_dtype
                )
                with torch.autocast(
                    autocast_device, enabled=autocast_enabled, dtype=autocast_dtype
                ):
                    if not autocast_enabled:
                        f1_s, f2_s = f1_s.to(torch.float32), f2_s.to(torch.float32)
                    f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
            if ins in coarse_scales:
                old_stuff = F.interpolate(
                    old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
                )
                gp_posterior = self.gps[new_scale](f1_s, f2_s)
                gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
                    gp_posterior, f1_s, old_stuff, new_scale
                )
                if self.embedding_decoder.is_classifier:
                    flow = cls_to_flow_refine(
                        gm_warp_or_cls,
                    ).permute(0, 3, 1, 2)
                    if self.training:
                        corresps[ins].update(
                            {
                                "gm_cls": gm_warp_or_cls,
                                "gm_certainty": certainty,
                            }
                        )
                else:
                    if self.training:
                        corresps[ins].update(
                            {
                                "gm_flow": gm_warp_or_cls,
                                "gm_certainty": certainty,
                            }
                        )
                    flow = gm_warp_or_cls.detach()
            if new_scale in self.conv_refiner:
                if self.training:
                    corresps[ins].update({"flow_pre_delta": flow})
                delta_flow, delta_certainty = self.conv_refiner[new_scale](
                    f1_s,
                    f2_s,
                    flow,
                    scale_factor=scale_factor,
                    logits=certainty,
                )
                if self.training:
                    corresps[ins].update({"delta_flow": delta_flow})
                # Scale the refiner output by the scale index and normalise
                # by refine_init times the scale-1 width/height.
                displacement = ins * torch.stack(
                    (
                        delta_flow[:, 0].float() / (self.refine_init * w),
                        delta_flow[:, 1].float() / (self.refine_init * h),
                    ),
                    dim=1,
                )
                flow = flow + displacement
                # predict both certainty and displacement
                certainty = certainty + delta_certainty
            corresps[ins].update(
                {
                    "certainty": certainty,
                    "flow": flow,
                }
            )
            if new_scale != "1":
                flow = F.interpolate(
                    flow,
                    size=sizes[ins // 2],
                    mode=self.flow_upsample_mode,
                )
                certainty = F.interpolate(
                    certainty,
                    size=sizes[ins // 2],
                    mode=self.flow_upsample_mode,
                )
                if self.detach:
                    flow = flow.detach()
                    certainty = certainty.detach()
        return corresps
def _check_input(im_input):
    """Validate an image input and return it in a usable form.

    Args:
        im_input: A path (``str`` / ``os.PathLike``), a PIL image, or a
            tensor unpacked as ``(B, C, H, W)``.

    Returns:
        An RGB PIL image for paths, the input itself for PIL images (after
        ``check_rgb``) and for tensors (after shape validation).

    Raises:
        AssertionError: Unsupported input type, ``C != 3``, or ``H``/``W``
            not a multiple of 14. Raised explicitly so validation still runs
            under ``python -O``.
    """
    if isinstance(im_input, torch.Tensor):
        B, C, H, W = im_input.shape
        if C != 3:
            raise AssertionError("im_input must be a RGB image")
        if H % 14 != 0:
            raise AssertionError("im_input must be a multiple of 14")
        if W % 14 != 0:
            raise AssertionError("im_input must be a multiple of 14")
        return im_input
    if isinstance(im_input, (str, os.PathLike)):
        # convert() returns a new, loaded image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(im_input) as opened:
            check_not_i16(opened)
            return opened.convert("RGB")
    if isinstance(im_input, Image.Image):
        check_rgb(im_input)
        return im_input
    raise AssertionError(
        "im_input must be a string, path, PIL image, or torch.Tensor"
    )
class RegressionMatcher(nn.Module):
    """Dense matcher combining a feature encoder with a coarse-to-fine decoder.

    ``match`` is the inference entry point; the remaining methods are
    helpers for sampling, coordinate conversion and visualisation.

    NOTE(review): the source this class was recovered from had lost its
    indentation; block nesting was reconstructed from syntax.
    """

    def __init__(
        self,
        encoder: "CNNandDinov2",
        decoder: "Decoder",
        h=448,
        w=448,
        sample_mode="threshold_balanced",
        upsample_preds=False,
        symmetric=False,
        sample_thresh=0.05,
        name=None,
        attenuate_cert=None,
        upsample_res=None,
    ):
        """Store the sub-modules and inference options.

        Args:
            encoder: Module returning a feature pyramid keyed by scale.
            decoder: Module turning two pyramids into correspondences.
            h, w: Resolution images are resized to in ``match``.
            sample_mode: Substrings ``"threshold"`` and ``"balanced"`` select
                the behaviour of ``sample``.
            upsample_preds: Run a second, upsampling pass in ``match``.
            symmetric: Predict both warp directions in ``match``.
            sample_thresh: Certainty above this is set to 1 in ``sample``.
            name: Stored as-is.
            attenuate_cert: Subtract an attenuation term derived from the
                scale-16 certainty in ``match``.
            upsample_res: ``(h, w)`` of the upsampling pass; defaults to
                ``(14 * 16 * 6, 14 * 16 * 6)``.
        """
        super().__init__()
        self.attenuate_cert = attenuate_cert
        self.encoder = encoder
        self.decoder = decoder
        self.name = name
        self.w_resized = w
        self.h_resized = h
        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
        self.sample_mode = sample_mode
        self.upsample_preds = upsample_preds
        self.upsample_res = upsample_res or (14 * 16 * 6, 14 * 16 * 6)
        self.symmetric = symmetric
        self.sample_thresh = sample_thresh

    def get_output_resolution(self):
        """Return ``upsample_res`` when upsampling, else ``(h_resized, w_resized)``."""
        if not self.upsample_preds:
            return self.h_resized, self.w_resized
        else:
            return self.upsample_res

    def extract_backbone_features(self, batch, batched=True, upsample=False):
        """Run the encoder on the images in ``batch``.

        If ``batch`` has ``"unique_images"``, those are encoded once and the
        result is indexed with ``batch["im_AB_idx"]``. Otherwise ``im_A`` and
        ``im_B`` are encoded together along the batch dimension
        (``batched=True``) or separately, returning a 2-tuple of pyramids.
        """
        if 'unique_images' in batch:
            unique_images = batch['unique_images']
            im_AB_idx = batch['im_AB_idx']
            feature_pyramid0 = self.encoder(unique_images, upsample=upsample)
            feature_pyramid = {
                scale: feature_pyramid0[scale][im_AB_idx]
                for scale in feature_pyramid0
            }
            return feature_pyramid

        x_q = batch["im_A"]
        x_s = batch["im_B"]
        if batched:
            X = torch.cat((x_q, x_s), dim=0)
            feature_pyramid = self.encoder(X, upsample=upsample)
        else:
            feature_pyramid = (
                self.encoder(x_q, upsample=upsample),
                self.encoder(x_s, upsample=upsample),
            )
        return feature_pyramid

    def sample(
        self,
        matches,
        certainty,
        num=10000,
    ):
        """Draw up to ``num`` matches with probability given by certainty.

        In ``"threshold"`` modes certainty above ``sample_thresh`` is set to
        1 first. In ``"balanced"`` modes four times as many candidates are
        drawn and then resampled with weights ``1 / (density + 1)`` from a
        KDE over the candidates.

        Returns:
            ``(matches, certainty)`` for the selected samples; ``matches`` is
            flattened to rows of 4 values.
        """
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
        matches, certainty = (
            matches.reshape(-1, 4),
            certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(
            certainty,
            num_samples=min(expansion_factor * num, len(certainty)),
            replacement=False,
        )
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
        density = kde(good_matches, std=0.1)
        p = 1 / (density + 1)
        # Basically should have at least 10 perfect neighbours, or around 100 ok ones
        p[density < 10] = 1e-7
        balanced_samples = torch.multinomial(
            p, num_samples=min(num, len(good_certainty)), replacement=False
        )
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
        """Encode both images and decode A-to-B correspondences.

        ``batch["corresps"]``, if present, is unpacked into the decoder call
        as keyword arguments.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, batched=batched, upsample=upsample
        )
        if batched:
            # First half of the batch dimension is im_A, second half im_B.
            f_q_pyramid = {
                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
            }
            f_s_pyramid = {
                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
            }
        else:
            f_q_pyramid, f_s_pyramid = feature_pyramid
        corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **(batch["corresps"] if "corresps" in batch else {}),
            scale_factor=scale_factor,
        )
        return corresps

    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
        """Decode both directions in one pass.

        The query pyramid is the stacked ``[A, B]`` features and the support
        pyramid the swapped ``[B, A]`` features, so the decoder output holds
        A-to-B in its first half and B-to-A in its second half.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, batched=batched, upsample=upsample
        )
        f_q_pyramid = feature_pyramid
        f_s_pyramid = {
            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
            for scale, f_scale in feature_pyramid.items()
        }
        corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **(batch["corresps"] if "corresps" in batch else {}),
            scale_factor=scale_factor,
        )
        return corresps

    def conf_from_fb_consistency(self, flow_forward, flow_backward, th=2):
        """Return a 0/1 mask of forward-backward consistent locations.

        The backward flow is sampled at the forward flow and compared with
        the identity grid; a location passes when the distance is below
        ``2 * th / max(H, W)``.

        Args:
            flow_forward, flow_backward: Shape ``(..., H, W, 2)``; a
                3-dimensional input is treated as unbatched.
            th: Threshold used in the formula above.

        Returns:
            Float mask of shape ``(H, W)``, or ``(B, H, W)`` for batched input.
        """
        # assumes that flow forward is of shape (..., H, W, 2)
        has_batch = len(flow_forward.shape) != 3
        if not has_batch:
            flow_forward, flow_backward = flow_forward[None], flow_backward[None]
        H, W = flow_forward.shape[-3:-1]
        th_n = 2 * th / max(H, W)
        coords = torch.stack(
            torch.meshgrid(
                torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
                torch.linspace(-1 + 1 / H, 1 - 1 / H, H),
                indexing="xy",
            ),
            dim=-1,
        ).to(flow_forward.device)
        coords_fb = F.grid_sample(
            flow_backward.permute(0, 3, 1, 2),
            flow_forward,
            align_corners=False,
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        diff = (coords - coords_fb).norm(dim=-1)
        in_th = (diff < th_n).float()
        if not has_batch:
            in_th = in_th[0]
        return in_th

    def to_pixel_coordinates(self, coords, H_A, W_A, H_B=None, W_B=None):
        """Convert normalised [-1, 1] coordinates to pixel coordinates.

        Args:
            coords: A ``(kpts_A, kpts_B)`` list/tuple, a tensor whose last
                dim is 2 (single image), or a tensor whose last dim holds
                ``(x_A, y_A, x_B, y_B)``.
            H_A, W_A: Size of image A.
            H_B, W_B: Size of image B; needed for paired input.

        Returns:
            One tensor for single-image input, otherwise ``(kpts_A, kpts_B)``.
        """
        # Check the container type first: lists/tuples have no ``.shape``.
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        elif coords.shape[-1] == 2:
            return self._to_pixel_coordinates(coords, H_A, W_A)
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(
            kpts_B, H_B, W_B
        )

    def _to_pixel_coordinates(self, coords, H, W):
        """Map ``(x, y)`` in [-1, 1] to ``(W/2 * (x + 1), H/2 * (y + 1))``."""
        kpts = torch.stack(
            (W / 2 * (coords[..., 0] + 1), H / 2 * (coords[..., 1] + 1)), dim=-1
        )
        return kpts

    def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
        """Convert paired pixel coordinates to normalised [-1, 1] coordinates.

        Args:
            coords: A ``(kpts_A, kpts_B)`` list/tuple or a tensor whose last
                dim holds ``(x_A, y_A, x_B, y_B)``.

        Returns:
            ``(kpts_A, kpts_B)``.
        """
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        kpts_A = torch.stack(
            (2 / W_A * kpts_A[..., 0] - 1, 2 / H_A * kpts_A[..., 1] - 1), dim=-1
        )
        kpts_B = torch.stack(
            (2 / W_B * kpts_B[..., 0] - 1, 2 / H_B * kpts_B[..., 1] - 1), dim=-1
        )
        return kpts_A, kpts_B

    def match_keypoints(
        self,
        x_A,
        x_B,
        warp,
        certainty,
        return_tuple=True,
        return_inds=False,
        max_dist=0.005,
        cert_th=0,
    ):
        """Match two keypoint sets through a dense warp.

        Keypoints ``x_A`` are mapped through the last two channels of
        ``warp`` and paired with keypoints in ``x_B`` that are mutual
        nearest neighbours, closer than ``max_dist`` and whose sampled
        certainty exceeds ``cert_th``.

        Returns:
            Indices or coordinates of the matches (``return_inds``), as a
            tuple or concatenated along the last dim (``return_tuple``).
        """
        x_A_to_B = F.grid_sample(
            warp[..., -2:].permute(2, 0, 1)[None],
            x_A[None, None],
            align_corners=False,
            mode="bilinear",
        )[0, :, 0].mT
        cert_A_to_B = F.grid_sample(
            certainty[None, None, ...],
            x_A[None, None],
            align_corners=False,
            mode="bilinear",
        )[0, 0, 0]
        D = torch.cdist(x_A_to_B, x_B)
        inds_A, inds_B = torch.nonzero(
            (D == D.min(dim=-1, keepdim=True).values)
            * (D == D.min(dim=-2, keepdim=True).values)
            * (cert_A_to_B[:, None] > cert_th)
            * (D < max_dist),
            as_tuple=True,
        )
        if return_tuple:
            if return_inds:
                return inds_A, inds_B
            else:
                return x_A[inds_A], x_B[inds_B]
        else:
            if return_inds:
                return torch.cat((inds_A, inds_B), dim=-1)
            else:
                return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)

    def _get_device(self):
        """Return the device of the first CNN layer's weight."""
        # let's hope this is same for all weights
        return self.encoder.cnn.layers[0].weight.device

    @torch.inference_mode()
    def match(
        self,
        im_A_input,
        im_B_input,
        *args,
        im_A_high_res=None,
        im_B_high_res=None,
        batched=True,
        device=None,
    ):
        """Compute a dense warp and certainty between two images.

        Puts the module in eval mode and runs under ``torch.inference_mode``.

        Args:
            im_A_input, im_B_input: Paths, PIL images, or ``(B, 3, H, W)``
                tensors (validated by ``_check_input``).
            *args: Accepted and ignored.
            im_A_high_res, im_B_high_res: Optional inputs for the upsampling
                pass; give both or neither.
            batched: Must be ``True``.
            device: Target device; defaults to the input tensor's device, or
                the model's device for non-tensor input.

        Returns:
            ``(warp, certainty)``. ``warp`` concatenates image-A coordinates
            and their predicted image-B coordinates on the last dim (with
            ``symmetric`` the reverse direction is concatenated along dim 2);
            ``certainty`` is the sigmoid of the predicted logits with the
            channel dim removed.

        Raises:
            ValueError: ``batched`` is false, the input types are
                unsupported or mixed, or only one high-res input is given.
        """
        self.train(False)
        if not batched:
            raise ValueError("batched must be True, non-batched inference is no longer supported.")
        if device is None:
            if isinstance(im_A_input, torch.Tensor):
                device = im_A_input.device
            else:
                device = self._get_device()
        # Check if inputs are file paths or already loaded images
        im_A = _check_input(im_A_input)
        im_B = _check_input(im_B_input)
        symmetric = self.symmetric
        ws = self.w_resized
        hs = self.h_resized
        scale_factor = math.sqrt(hs * ws / (560**2))  # divide by training resolution
        if isinstance(im_A, Image.Image) and isinstance(im_B, Image.Image):
            b = 1
            w, h = im_A.size
            w2, h2 = im_B.size
            # Get images in good format
            test_transform = get_tuple_transform_ops(
                resize=(hs, ws), normalize=True, clahe=False
            )
            im_A, im_B = test_transform((im_A, im_B))
            batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
        elif isinstance(im_A, torch.Tensor) and isinstance(im_B, torch.Tensor):
            b, c, h, w = im_A.shape
            b, c, h2, w2 = im_B.shape
            assert w == w2 and h == h2, "For batched images we assume same size"
            batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
            if h != self.h_resized or self.w_resized != w:
                warn(
                    "Model resolution and batch resolution differ, may produce unexpected results"
                )
            hs, ws = h, w
        else:
            raise ValueError(f"Unsupported input type: {type(im_A)=} and {type(im_B)=}")
        finest_scale = 1
        # Run matcher
        if symmetric:
            corresps = self.forward_symmetric(batch, scale_factor=scale_factor)
        else:
            corresps = self(batch, batched=True, scale_factor=scale_factor)
        if self.upsample_preds:
            hs, ws = self.upsample_res
        if self.attenuate_cert:
            low_res_certainty = F.interpolate(
                corresps[16]["certainty"],
                size=(hs, ws),
                align_corners=False,
                mode="bilinear",
            )
            cert_clamp = 0
            factor = 0.5
            # Keep only the negative scale-16 logits, scaled by ``factor``.
            low_res_certainty = (
                factor * low_res_certainty * (low_res_certainty < cert_clamp)
            )
        finest_corresps = corresps[finest_scale]
        if self.upsample_preds and im_A_high_res is None and im_B_high_res is None:
            torch.cuda.empty_cache()
            test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
            if isinstance(im_A_input, (str, os.PathLike)):
                im_A, im_B = test_transform(
                    (
                        Image.open(im_A_input).convert("RGB"),
                        Image.open(im_B_input).convert("RGB"),
                    )
                )
            else:
                assert isinstance(im_A_input, Image.Image), f"Unsupported input type: {type(im_A_input)=}"
                assert isinstance(im_B_input, Image.Image), f"Unsupported input type: {type(im_B_input)=}"
                im_A, im_B = test_transform((im_A_input, im_B_input))
            im_A, im_B = im_A[None].to(device), im_B[None].to(device)

            batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
        elif self.upsample_preds and im_A_high_res is not None and im_B_high_res is not None:
            batch = {"im_A": im_A_high_res, "im_B": im_B_high_res, "corresps": finest_corresps}
        elif self.upsample_preds:
            raise ValueError(f"Invalid upsample_preds and high_res inputs with {im_A=},{im_A_high_res=},{im_B=} and {im_B_high_res=}")
        if self.upsample_preds:
            scale_factor = math.sqrt(
                self.upsample_res[0]
                * self.upsample_res[1]
                / (560**2)  # divide by training resolution
            )
            if symmetric:
                corresps = self.forward_symmetric(
                    batch, upsample=True, batched=True, scale_factor=scale_factor
                )
            else:
                corresps = self(
                    batch, batched=True, upsample=True, scale_factor=scale_factor
                )
        im_A_to_im_B = corresps[finest_scale]["flow"]
        certainty = corresps[finest_scale]["certainty"] - (
            low_res_certainty if self.attenuate_cert else 0
        )
        if finest_scale != 1:
            im_A_to_im_B = F.interpolate(
                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
            )
            certainty = F.interpolate(
                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
            )
        im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
        # Create im_A meshgrid
        im_A_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
            ),
            indexing="ij",
        )
        im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
        im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
        certainty = certainty.sigmoid()  # logits -> probs
        im_A_coords = im_A_coords.permute(0, 2, 3, 1)
        if (im_A_to_im_B.abs() > 1).any():
            # Predictions that fall outside the image get zero certainty.
            wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
            certainty[wrong[:, None]] = 0
        im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
        if symmetric:
            A_to_B, B_to_A = im_A_to_im_B.chunk(2)
            q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
            im_B_coords = im_A_coords
            s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
            warp = torch.cat((q_warp, s_warp), dim=2)
            certainty = torch.cat(certainty.chunk(2), dim=3)
        else:
            warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
        # ``batched`` is always True here (checked at the top).
        return (warp, certainty[:, 0])

    def visualize_warp(
        self,
        warp,
        certainty,
        im_A=None,
        im_B=None,
        im_A_path=None,
        im_B_path=None,
        device="cuda",
        symmetric=True,
        save_path=None,
        unnormalize=False,
    ):
        """Render the images warped onto each other, faded by certainty.

        Args:
            warp: Tensor unpacked as ``(H, W2, _)``; with ``symmetric`` the
                width holds both directions side by side.
            certainty: Blend weight between the warped image and white.
            im_A, im_B: PIL images or tensors; loaded from ``im_A_path`` /
                ``im_B_path`` when ``im_A`` is ``None``.
            device: Device for tensors created from PIL images.
            symmetric: Whether ``warp`` holds both directions.
            save_path: If given, the result is also saved there.
            unnormalize: Forwarded to ``tensor_to_pil`` when saving.

        Returns:
            The blended image tensor.
        """
        H, W2, _ = warp.shape
        W = W2 // 2 if symmetric else W2
        if im_A is None:
            from PIL import Image

            im_A, im_B = (
                Image.open(im_A_path).convert("RGB"),
                Image.open(im_B_path).convert("RGB"),
            )
        if not isinstance(im_A, torch.Tensor):
            im_A = im_A.resize((W, H))
            im_B = im_B.resize((W, H))
            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
            if symmetric:
                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
        else:
            if symmetric:
                x_A = im_A
            x_B = im_B
        im_A_transfer_rgb = F.grid_sample(
            x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
        )[0]
        if symmetric:
            im_B_transfer_rgb = F.grid_sample(
                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
            )[0]
            warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
            white_im = torch.ones((H, 2 * W), device=device)
        else:
            warp_im = im_A_transfer_rgb
            white_im = torch.ones((H, W), device=device)
        vis_im = certainty * warp_im + (1 - certainty) * white_im
        if save_path is not None:
            from romatch.utils import tensor_to_pil

            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
        return vis_im
|