| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import os
- import torch
- from pathlib import Path
- import math
- import numpy as np
- from torch import nn
- from PIL import Image
- from torchvision.transforms import ToTensor
- from romatch.utils.kde import kde
class BasicLayer(nn.Module):
    """
    Convolutional building block: Conv2d -> BatchNorm2d (no affine) -> ReLU.

    The activation is replaced by an identity mapping when ``relu`` is False.
    All remaining arguments are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        # affine=False: the normalisation has no learnable scale/shift.
        normalisation = nn.BatchNorm2d(out_channels, affine=False)
        if relu:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()
        # The attribute name and sub-module order define the state-dict keys.
        self.layer = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.layer(x)
class TinyRoMa(nn.Module):
    """
    Dense two-view matcher built on top of an XFeat feature backbone.

    A coarse warp is read out of a feature correlation volume and then
    refined by two small convolutional heads (``coarse_matcher`` and
    ``fine_matcher``), each predicting a 3-channel update: a 2D flow delta
    plus a certainty logit.

    The backbone is the architecture described in
    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
    """

    def __init__(self, xfeat = None,
                 freeze_xfeat = True,
                 sample_mode = "threshold_balanced",
                 symmetric = False,
                 exact_softmax = False):
        """
        Args:
            xfeat: backbone module exposing ``norm``, ``block1``..``block5``,
                ``skip1`` and ``block_fusion`` (the attributes used by
                ``forward_single``). Required despite the ``None`` default.
            freeze_xfeat: when True the backbone is put in eval mode and kept
                out of this module's registered sub-modules.
            sample_mode: string inspected by ``sample`` for the substrings
                ``"threshold"`` and ``"balanced"``.
            symmetric: stored on the instance; not read inside this class.
            exact_softmax: when True, ``pos_embed`` always uses the full
                softmax instead of the sub-sampled approximation.

        Raises:
            ValueError: if ``xfeat`` is None.
        """
        super().__init__()
        if xfeat is None:
            raise ValueError("TinyRoMa requires an XFeat backbone instance, got xfeat=None")
        # These heads are never called by this model. Deleting only what is
        # present lets an already stripped backbone be reused.
        for unused_head in ("heatmap_head", "keypoint_head", "fine_matcher"):
            if hasattr(xfeat, unused_head):
                delattr(xfeat, unused_head)
        if freeze_xfeat:
            xfeat.train(False)
            self.xfeat = [xfeat]# hide params from ddp
        else:
            self.xfeat = nn.ModuleList([xfeat])
        self.freeze_xfeat = freeze_xfeat
        match_dim = 256
        # Input: features of A, warped features of B, and the 2-channel warp.
        # The 64-channel figure presumably matches xfeat.block_fusion - confirm.
        self.coarse_matcher = nn.Sequential(
            BasicLayer(64+64+2, match_dim,),
            BasicLayer(match_dim, match_dim,),
            BasicLayer(match_dim, match_dim,),
            BasicLayer(match_dim, match_dim,),
            nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
        fine_match_dim = 64
        # Same layout for the finer level; 24 channels presumably matches
        # the output of xfeat.block2 - confirm.
        self.fine_matcher = nn.Sequential(
            BasicLayer(24+24+2, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            BasicLayer(fine_match_dim, fine_match_dim,),
            nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
        self.sample_mode = sample_mode
        self.sample_thresh = 0.05
        self.symmetric = symmetric
        self.exact_softmax = exact_softmax

    @property
    def device(self):
        """Device holding the weights of the last ``fine_matcher`` layer."""
        return self.fine_matcher[-1].weight.device

    def preprocess_tensor(self, x):
        """
        Guarantee that image is divisible by 32 to avoid aliasing artifacts.

        Height and width are rounded down to the nearest multiple of 32 and
        the tensor is bilinearly resized to that size.

        Returns:
            (resized, rh, rw) where rh / rw are original-over-resized ratios.

        Raises:
            ValueError: if either spatial dimension is smaller than 32.
        """
        H, W = x.shape[-2:]
        _H, _W = (H//32) * 32, (W//32) * 32
        if _H == 0 or _W == 0:
            # Previously surfaced as a bare ZeroDivisionError on the ratio.
            raise ValueError(
                f"input spatial size must be at least 32x32, got {H}x{W}")
        rh, rw = H/_H, W/_W
        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw

    def forward_single(self, x):
        """
        Run the backbone on a batch of images.

        Returns:
            (x2, feats): the output of ``block2`` and the output of
            ``block_fusion`` applied to the sum of blocks 3-5 (blocks 4 and 5
            are resized to block 3's resolution first).
        """
        with torch.inference_mode(self.freeze_xfeat or not self.training):
            xfeat = self.xfeat[0]
            with torch.no_grad():
                # Collapse channels to a single intensity channel, then
                # normalise; no gradient flows through this step.
                x = x.mean(dim=1, keepdim = True)
                x = xfeat.norm(x)
            #main backbone
            x1 = xfeat.block1(x)
            x2 = xfeat.block2(x1 + xfeat.skip1(x))
            x3 = xfeat.block3(x2)
            x4 = xfeat.block4(x3)
            x5 = xfeat.block5(x4)
            x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            feats = xfeat.block_fusion( x3 + x4 + x5 )
        if self.freeze_xfeat:
            # Copies made after leaving the inference-mode context.
            return x2.clone(), feats.clone()
        return x2, feats

    def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
        """
        Convert normalised [-1, 1] coordinates to pixel coordinates.

        ``coords`` may be:
          * a tensor whose last dimension is 2 -> one tensor is returned,
            scaled with (H_A, W_A);
          * a list/tuple ``(kpts_A, kpts_B)`` -> a pair is returned;
          * a tensor whose last dimension holds A in ``[..., :2]`` and B in
            ``[..., 2:]`` -> a pair is returned.

        For the paired forms ``H_B`` and ``W_B`` must be given.
        """
        # Sequences have no ``.shape``, so they must be handled before the
        # tensor checks; otherwise this branch can never be reached.
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        elif coords.shape[-1] == 2:
            return self._to_pixel_coordinates(coords, H_A, W_A)
        else:
            kpts_A, kpts_B = coords[...,:2], coords[...,2:]
        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)

    def _to_pixel_coordinates(self, coords, H, W):
        """Map x, y in [-1, 1] to [0, W] and [0, H] respectively."""
        kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)), dim=-1)
        return kpts

    def pos_embed(self, corr_volume: torch.Tensor):
        """
        Turn a correlation volume into an expected target coordinate.

        Args:
            corr_volume: tensor of shape (B, H1, W1, H0, W0).

        Returns:
            Tensor of shape (B, 2, H0, W0): for every location of image 0,
            the softmax-weighted average of normalised (x, y) grid positions
            in image 1.

        In eval mode with ``exact_softmax`` False, the softmax is taken over a
        4x sub-sampled set of candidates plus one extra entry for the argmax
        candidate, rather than over all H1*W1 candidates. This path requires
        H1 and W1 to be divisible by 4 (the reshape fails otherwise).
        """
        B, H1, W1, H0, W0 = corr_volume.shape
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1+1/W1,1-1/W1, W1),
                torch.linspace(-1+1/H1,1-1/H1, H1),
                indexing = "xy"),
            dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
        down = 4
        if not self.training and not self.exact_softmax:
            grid_lr = torch.stack(
                torch.meshgrid(
                    torch.linspace(-1+down/W1,1-down/W1, W1//down),
                    torch.linspace(-1+down/H1,1-down/H1, H1//down),
                    indexing = "xy"),
                dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
            cv = corr_volume
            best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, H0, W0
            # NOTE(review): the extra logit is the argmax *index*, not the
            # maximum score. It looks unintended, but is kept as-is because
            # changing it would alter the output of trained models - confirm.
            extra_logit = best_match[:,None].to(cv.dtype)
            P_lowres = torch.cat(
                (cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), extra_logit),
                dim=1).softmax(dim=1)
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
            # Keep the channel axis (-1:) so the weight broadcasts over the two
            # coordinate channels; dropping it is only correct for B == 1.
            pos_embeddings += P_lowres[:,-1:] * grid[best_match].permute(0,3,1,2)
        else:
            P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
        return pos_embeddings

    def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
                       im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
        """
        Render the images resampled through ``warp``, blended towards white
        where ``certainty`` is low.

        Args:
            warp: tensor of shape (H, W2, 4); the last axis is split into
                ``[:2]`` and ``[2:]`` sampling grids. With ``symmetric`` the
                width holds two halves of W = W2 // 2 columns each.
            certainty: weights broadcast against the warped image.
            im_A, im_B: PIL images or tensors; loaded from ``im_A_path`` /
                ``im_B_path`` when ``im_A`` is None.
            symmetric: also warp image A into B and concatenate along width.
            save_path: when given, the result is saved there via
                ``romatch.utils.tensor_to_pil``.
            unnormalize: forwarded to ``tensor_to_pil``.

        Returns:
            The blended visualisation tensor.
        """
        device = warp.device
        H,W2,_ = warp.shape
        W = W2//2 if symmetric else W2
        if im_A is None:
            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
        if not isinstance(im_A, torch.Tensor):
            im_A = im_A.resize((W,H))
            im_B = im_B.resize((W,H))
            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
            if symmetric:
                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
        else:
            # Tensor inputs are used as given (no resize or rescale).
            if symmetric:
                x_A = im_A
            x_B = im_B
        im_A_transfer_rgb = F.grid_sample(
            x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
        )[0]
        if symmetric:
            im_B_transfer_rgb = F.grid_sample(
                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
            )[0]
            warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
            white_im = torch.ones((H,2*W),device=device)
        else:
            warp_im = im_A_transfer_rgb
            white_im = torch.ones((H, W), device = device)
        vis_im = certainty * warp_im + (1 - certainty) * white_im
        if save_path is not None:
            from romatch.utils import tensor_to_pil
            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
        return vis_im

    def corr_volume(self, feat0, feat1):
        """
        Scaled dot-product correlation between two feature maps.

        input:
            feat0 -> torch.Tensor(B, C, H0, W0)
            feat1 -> torch.Tensor(B, C, H1, W1)
        return:
            corr_volume -> torch.Tensor(B, H1, W1, H0, W0), divided by sqrt(C)
        """
        B, C, H0, W0 = feat0.shape
        B, C, H1, W1 = feat1.shape
        # reshape (not view) so non-contiguous feature maps are accepted too.
        feat0 = feat0.reshape(B, C, H0*W0)
        feat1 = feat1.reshape(B, C, H1*W1)
        corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C)
        return corr_volume

    def _image_to_tensor(self, image):
        """Convert a PIL image to a (1, 3, H, W) tensor on the model device."""
        # Forcing RGB keeps both inputs at the same channel count and stops an
        # alpha channel from being averaged into the intensity downstream.
        return ToTensor()(image.convert("RGB"))[None].to(self.device)

    @torch.inference_mode()
    def match_from_path(self, im0_path, im1_path):
        """Load two images from disk and return ``match`` for them, unbatched."""
        # Context managers release the file handles once the pixels are read.
        with Image.open(im0_path) as image0:
            im0 = self._image_to_tensor(image0)
        with Image.open(im1_path) as image1:
            im1 = self._image_to_tensor(image1)
        return self.match(im0, im1, batched = False)

    @torch.inference_mode()
    def match(self, im0, im1, *args, batched = True):
        """
        Compute a dense warp from image 0 to image 1.

        Args:
            im0, im1: file paths (str / Path), PIL images, or batched image
                tensors of shape (B, C, H, W).
            *args: accepted and ignored.
            batched: when False the leading batch axis is dropped from the
                outputs. Forced to False for path and PIL inputs.

        Returns:
            (warp, cert): ``warp`` is (B, H0, W0, 4) holding the normalised
            pixel grid of image 0 followed by the predicted coordinates in
            image 1; ``cert`` is the sigmoid of the certainty logits,
            (B, H0, W0).

        Side effect: the module is switched to eval mode.
        """
        if isinstance(im0, (str, Path)):
            return self.match_from_path(im0, im1)
        elif isinstance(im0, Image.Image):
            batched = False
            im0 = self._image_to_tensor(im0)
            im1 = self._image_to_tensor(im1)

        B,C,H0,W0 = im0.shape
        B,C,H1,W1 = im1.shape
        self.train(False)
        corresps = self.forward({"im_A":im0, "im_B":im1})
        # Upsample the finest prediction back to the input resolution.
        flow = F.interpolate(
            corresps[4]["flow"],
            size = (H0, W0),
            mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1+1/W0,1-1/W0, W0),
                torch.linspace(-1+1/H0,1-1/H0, H0),
                indexing = "xy"),
            dim = -1).float().to(flow.device).expand(B, H0, W0, 2)

        certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
        warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
        if batched:
            return warp, cert
        else:
            return warp[0], cert[0]

    def sample(
        self,
        matches,
        certainty,
        num=5_000,
    ):
        """
        Draw up to ``num`` matches without replacement, weighted by certainty.

        Args:
            matches: tensor reshaped to (-1, 4) internally.
            certainty: per-match weights, reshaped to (-1,).
            num: maximum number of samples returned.

        Behaviour depends on ``self.sample_mode``:
          * contains ``"threshold"``: certainties above ``self.sample_thresh``
            are set to 1 (on a copy) before sampling;
          * contains ``"balanced"``: 4x as many candidates are drawn first and
            then re-sampled with weight 1 / (density + 1), where the density
            comes from ``romatch.utils.kde.kde``.

        Returns:
            (sampled_matches, sampled_certainty).
        """
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
        matches, certainty = (
            matches.reshape(-1, 4),
            certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(certainty,
                          num_samples = min(expansion_factor*num, len(certainty)),
                          replacement=False)
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
        on_cuda = matches.device.type == "cuda"
        # Half precision and no sub-sampling on CUDA, otherwise down=8;
        # semantics of these arguments are defined by kde.
        density = kde(good_matches, std=0.1, half = on_cuda, down = 1 if on_cuda else 8)
        p = 1 / (density+1)
        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
        balanced_samples = torch.multinomial(p,
                          num_samples = min(num,len(good_certainty)),
                          replacement=False)
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch):
        """
        Predict coarse and refined correspondences for an image pair.

        input:
            batch -> dict with keys "im_A" and "im_B", each a
                     torch.Tensor(B, C, H, W) grayscale or rgb image batch
        return:
            dict with integer keys 8 and 4 (coarse and fine level), each
            mapping to {"flow": (B, 2, h, w), "certainty": (B, 1, h, w)}.
            Flow is in normalised coordinates; certainty is a logit
            (``match`` applies the sigmoid).
        """
        im0 = batch["im_A"]
        im1 = batch["im_B"]
        corresps = {}
        im0, _, _ = self.preprocess_tensor(im0)
        im1, _, _ = self.preprocess_tensor(im1)
        H1, W1 = im1.shape[-2:]
        # Per-channel scale applied to the predicted deltas: x and y are
        # multiplied by 2/W1 and 2/H1, the certainty channel is left unscaled.
        to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]

        if im0.shape[-2:] == im1.shape[-2:]:
            # Same spatial size: one backbone pass over the stacked pair.
            x = torch.cat([im0, im1], dim=0)
            x = self.forward_single(x)
            feats_x0_c, feats_x1_c = x[1].chunk(2)
            feats_x0_f, feats_x1_f = x[0].chunk(2)
        else:
            feats_x0_f, feats_x0_c = self.forward_single(im0)
            feats_x1_f, feats_x1_c = self.forward_single(im1)
        corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
        coarse_warp = self.pos_embed(corr_volume)
        # Append a zero certainty channel so the warp matches the 3-channel delta.
        coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
        feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
        coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
        coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
        corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
        coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
        coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
        feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
        fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
        fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
        corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
        return corresps
|