# tiny.py
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import os
  5. import torch
  6. from pathlib import Path
  7. import math
  8. import numpy as np
  9. from torch import nn
  10. from PIL import Image
  11. from torchvision.transforms import ToTensor
  12. from romatch.utils.kde import kde
  13. class BasicLayer(nn.Module):
  14. """
  15. Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
  16. """
  17. def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
  18. super().__init__()
  19. self.layer = nn.Sequential(
  20. nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
  21. nn.BatchNorm2d(out_channels, affine=False),
  22. nn.ReLU(inplace = True) if relu else nn.Identity()
  23. )
  24. def forward(self, x):
  25. return self.layer(x)
  26. class TinyRoMa(nn.Module):
  27. """
  28. Implementation of architecture described in
  29. "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
  30. """
  31. def __init__(self, xfeat = None,
  32. freeze_xfeat = True,
  33. sample_mode = "threshold_balanced",
  34. symmetric = False,
  35. exact_softmax = False):
  36. super().__init__()
  37. del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
  38. if freeze_xfeat:
  39. xfeat.train(False)
  40. self.xfeat = [xfeat]# hide params from ddp
  41. else:
  42. self.xfeat = nn.ModuleList([xfeat])
  43. self.freeze_xfeat = freeze_xfeat
  44. match_dim = 256
  45. self.coarse_matcher = nn.Sequential(
  46. BasicLayer(64+64+2, match_dim,),
  47. BasicLayer(match_dim, match_dim,),
  48. BasicLayer(match_dim, match_dim,),
  49. BasicLayer(match_dim, match_dim,),
  50. nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
  51. fine_match_dim = 64
  52. self.fine_matcher = nn.Sequential(
  53. BasicLayer(24+24+2, fine_match_dim,),
  54. BasicLayer(fine_match_dim, fine_match_dim,),
  55. BasicLayer(fine_match_dim, fine_match_dim,),
  56. BasicLayer(fine_match_dim, fine_match_dim,),
  57. nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
  58. self.sample_mode = sample_mode
  59. self.sample_thresh = 0.05
  60. self.symmetric = symmetric
  61. self.exact_softmax = exact_softmax
  62. @property
  63. def device(self):
  64. return self.fine_matcher[-1].weight.device
  65. def preprocess_tensor(self, x):
  66. """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
  67. H, W = x.shape[-2:]
  68. _H, _W = (H//32) * 32, (W//32) * 32
  69. rh, rw = H/_H, W/_W
  70. x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
  71. return x, rh, rw
  72. def forward_single(self, x):
  73. with torch.inference_mode(self.freeze_xfeat or not self.training):
  74. xfeat = self.xfeat[0]
  75. with torch.no_grad():
  76. x = x.mean(dim=1, keepdim = True)
  77. x = xfeat.norm(x)
  78. #main backbone
  79. x1 = xfeat.block1(x)
  80. x2 = xfeat.block2(x1 + xfeat.skip1(x))
  81. x3 = xfeat.block3(x2)
  82. x4 = xfeat.block4(x3)
  83. x5 = xfeat.block5(x4)
  84. x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
  85. x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
  86. feats = xfeat.block_fusion( x3 + x4 + x5 )
  87. if self.freeze_xfeat:
  88. return x2.clone(), feats.clone()
  89. return x2, feats
  90. def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
  91. if coords.shape[-1] == 2:
  92. return self._to_pixel_coordinates(coords, H_A, W_A)
  93. if isinstance(coords, (list, tuple)):
  94. kpts_A, kpts_B = coords[0], coords[1]
  95. else:
  96. kpts_A, kpts_B = coords[...,:2], coords[...,2:]
  97. return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
  98. def _to_pixel_coordinates(self, coords, H, W):
  99. kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
  100. return kpts
  101. def pos_embed(self, corr_volume: torch.Tensor):
  102. B, H1, W1, H0, W0 = corr_volume.shape
  103. grid = torch.stack(
  104. torch.meshgrid(
  105. torch.linspace(-1+1/W1,1-1/W1, W1),
  106. torch.linspace(-1+1/H1,1-1/H1, H1),
  107. indexing = "xy"),
  108. dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
  109. down = 4
  110. if not self.training and not self.exact_softmax:
  111. grid_lr = torch.stack(
  112. torch.meshgrid(
  113. torch.linspace(-1+down/W1,1-down/W1, W1//down),
  114. torch.linspace(-1+down/H1,1-down/H1, H1//down),
  115. indexing = "xy"),
  116. dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
  117. cv = corr_volume
  118. best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, HW, H, W
  119. P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
  120. pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
  121. pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
  122. #print("hej")
  123. else:
  124. P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
  125. pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
  126. return pos_embeddings
  127. def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
  128. im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
  129. device = warp.device
  130. H,W2,_ = warp.shape
  131. W = W2//2 if symmetric else W2
  132. if im_A is None:
  133. from PIL import Image
  134. im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
  135. if not isinstance(im_A, torch.Tensor):
  136. im_A = im_A.resize((W,H))
  137. im_B = im_B.resize((W,H))
  138. x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
  139. if symmetric:
  140. x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
  141. else:
  142. if symmetric:
  143. x_A = im_A
  144. x_B = im_B
  145. im_A_transfer_rgb = F.grid_sample(
  146. x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
  147. )[0]
  148. if symmetric:
  149. im_B_transfer_rgb = F.grid_sample(
  150. x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
  151. )[0]
  152. warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
  153. white_im = torch.ones((H,2*W),device=device)
  154. else:
  155. warp_im = im_A_transfer_rgb
  156. white_im = torch.ones((H, W), device = device)
  157. vis_im = certainty * warp_im + (1 - certainty) * white_im
  158. if save_path is not None:
  159. from romatch.utils import tensor_to_pil
  160. tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
  161. return vis_im
  162. def corr_volume(self, feat0, feat1):
  163. """
  164. input:
  165. feat0 -> torch.Tensor(B, C, H, W)
  166. feat1 -> torch.Tensor(B, C, H, W)
  167. return:
  168. corr_volume -> torch.Tensor(B, H, W, H, W)
  169. """
  170. B, C, H0, W0 = feat0.shape
  171. B, C, H1, W1 = feat1.shape
  172. feat0 = feat0.view(B, C, H0*W0)
  173. feat1 = feat1.view(B, C, H1*W1)
  174. corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
  175. return corr_volume
  176. @torch.inference_mode()
  177. def match_from_path(self, im0_path, im1_path):
  178. device = self.device
  179. im0 = ToTensor()(Image.open(im0_path))[None].to(device)
  180. im1 = ToTensor()(Image.open(im1_path))[None].to(device)
  181. return self.match(im0, im1, batched = False)
  182. @torch.inference_mode()
  183. def match(self, im0, im1, *args, batched = True):
  184. # stupid
  185. if isinstance(im0, (str, Path)):
  186. return self.match_from_path(im0, im1)
  187. elif isinstance(im0, Image.Image):
  188. batched = False
  189. device = self.device
  190. im0 = ToTensor()(im0)[None].to(device)
  191. im1 = ToTensor()(im1)[None].to(device)
  192. B,C,H0,W0 = im0.shape
  193. B,C,H1,W1 = im1.shape
  194. self.train(False)
  195. corresps = self.forward({"im_A":im0, "im_B":im1})
  196. #return 1,1
  197. flow = F.interpolate(
  198. corresps[4]["flow"],
  199. size = (H0, W0),
  200. mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
  201. grid = torch.stack(
  202. torch.meshgrid(
  203. torch.linspace(-1+1/W0,1-1/W0, W0),
  204. torch.linspace(-1+1/H0,1-1/H0, H0),
  205. indexing = "xy"),
  206. dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
  207. certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
  208. warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
  209. if batched:
  210. return warp, cert
  211. else:
  212. return warp[0], cert[0]
  213. def sample(
  214. self,
  215. matches,
  216. certainty,
  217. num=5_000,
  218. ):
  219. H,W,_ = matches.shape
  220. if "threshold" in self.sample_mode:
  221. upper_thresh = self.sample_thresh
  222. certainty = certainty.clone()
  223. certainty[certainty > upper_thresh] = 1
  224. matches, certainty = (
  225. matches.reshape(-1, 4),
  226. certainty.reshape(-1),
  227. )
  228. expansion_factor = 4 if "balanced" in self.sample_mode else 1
  229. good_samples = torch.multinomial(certainty,
  230. num_samples = min(expansion_factor*num, len(certainty)),
  231. replacement=False)
  232. good_matches, good_certainty = matches[good_samples], certainty[good_samples]
  233. if "balanced" not in self.sample_mode:
  234. return good_matches, good_certainty
  235. use_half = True if matches.device.type == "cuda" else False
  236. down = 1 if matches.device.type == "cuda" else 8
  237. density = kde(good_matches, std=0.1, half = use_half, down = down)
  238. p = 1 / (density+1)
  239. p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
  240. balanced_samples = torch.multinomial(p,
  241. num_samples = min(num,len(good_certainty)),
  242. replacement=False)
  243. return good_matches[balanced_samples], good_certainty[balanced_samples]
  244. def forward(self, batch):
  245. """
  246. input:
  247. x -> torch.Tensor(B, C, H, W) grayscale or rgb images
  248. return:
  249. """
  250. im0 = batch["im_A"]
  251. im1 = batch["im_B"]
  252. corresps = {}
  253. im0, rh0, rw0 = self.preprocess_tensor(im0)
  254. im1, rh1, rw1 = self.preprocess_tensor(im1)
  255. B, C, H0, W0 = im0.shape
  256. B, C, H1, W1 = im1.shape
  257. to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
  258. if im0.shape[-2:] == im1.shape[-2:]:
  259. x = torch.cat([im0, im1], dim=0)
  260. x = self.forward_single(x)
  261. feats_x0_c, feats_x1_c = x[1].chunk(2)
  262. feats_x0_f, feats_x1_f = x[0].chunk(2)
  263. else:
  264. feats_x0_f, feats_x0_c = self.forward_single(im0)
  265. feats_x1_f, feats_x1_c = self.forward_single(im1)
  266. corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
  267. coarse_warp = self.pos_embed(corr_volume)
  268. coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
  269. feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
  270. coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
  271. coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
  272. corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
  273. coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
  274. coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
  275. feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
  276. fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
  277. fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
  278. corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
  279. return corresps