matcher.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001
  1. import os
  2. import math
  3. import sys
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from einops import rearrange
  9. from warnings import warn
  10. from PIL import Image
  11. from romatch.utils import get_tuple_transform_ops
  12. from romatch.utils.local_correlation import local_correlation
  13. from romatch.utils.utils import (
  14. check_rgb,
  15. cls_to_flow_refine,
  16. get_autocast_params,
  17. check_not_i16,
  18. )
  19. from romatch.utils.kde import kde
  20. from romatch.models.encoders import CNNandDinov2
  21. class ConvRefiner(nn.Module):
  22. def __init__(
  23. self,
  24. in_dim=6,
  25. hidden_dim=16,
  26. out_dim=2,
  27. dw=False,
  28. kernel_size=5,
  29. hidden_blocks=3,
  30. displacement_emb=None,
  31. displacement_emb_dim=None,
  32. local_corr_radius=None,
  33. corr_in_other=None,
  34. no_im_B_fm=False,
  35. amp=False,
  36. concat_logits=False,
  37. use_bias_block_1=True,
  38. use_cosine_corr=False,
  39. disable_local_corr_grad=False,
  40. is_classifier=False,
  41. sample_mode="bilinear",
  42. norm_type=nn.BatchNorm2d,
  43. bn_momentum=0.1,
  44. amp_dtype=torch.float16,
  45. use_custom_corr=False,
  46. ):
  47. super().__init__()
  48. if sys.platform != "linux":
  49. warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
  50. use_custom_corr = False
  51. self.bn_momentum = bn_momentum
  52. self.block1 = self.create_block(
  53. in_dim,
  54. hidden_dim,
  55. dw=dw,
  56. kernel_size=kernel_size,
  57. bias=use_bias_block_1,
  58. )
  59. self.hidden_blocks = nn.Sequential(
  60. *[
  61. self.create_block(
  62. hidden_dim,
  63. hidden_dim,
  64. dw=dw,
  65. kernel_size=kernel_size,
  66. norm_type=norm_type,
  67. )
  68. for hb in range(hidden_blocks)
  69. ]
  70. )
  71. self.hidden_blocks = self.hidden_blocks
  72. self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
  73. if displacement_emb:
  74. self.has_displacement_emb = True
  75. self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
  76. else:
  77. self.has_displacement_emb = False
  78. self.local_corr_radius = local_corr_radius
  79. self.corr_in_other = corr_in_other
  80. self.no_im_B_fm = no_im_B_fm
  81. self.amp = amp
  82. self.concat_logits = concat_logits
  83. self.use_cosine_corr = use_cosine_corr
  84. self.disable_local_corr_grad = disable_local_corr_grad
  85. self.is_classifier = is_classifier
  86. self.sample_mode = sample_mode
  87. self.amp_dtype = amp_dtype
  88. self.use_custom_corr = use_custom_corr
  89. def create_block(
  90. self,
  91. in_dim,
  92. out_dim,
  93. dw=False,
  94. kernel_size=5,
  95. bias=True,
  96. norm_type=nn.BatchNorm2d,
  97. ):
  98. num_groups = 1 if not dw else in_dim
  99. if dw:
  100. assert out_dim % in_dim == 0, (
  101. "outdim must be divisible by indim for depthwise"
  102. )
  103. conv1 = nn.Conv2d(
  104. in_dim,
  105. out_dim,
  106. kernel_size=kernel_size,
  107. stride=1,
  108. padding=kernel_size // 2,
  109. groups=num_groups,
  110. bias=bias,
  111. )
  112. norm = (
  113. norm_type(out_dim, momentum=self.bn_momentum)
  114. if norm_type is nn.BatchNorm2d
  115. else norm_type(num_channels=out_dim)
  116. )
  117. relu = nn.ReLU(inplace=True)
  118. conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
  119. return nn.Sequential(conv1, norm, relu, conv2)
  120. def forward(self, x, y, warp, scale_factor=1, logits=None):
  121. b, c, hs, ws = x.shape
  122. autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(
  123. x.device, enabled=self.amp, dtype=self.amp_dtype
  124. )
  125. with torch.autocast(
  126. autocast_device, enabled=autocast_enabled, dtype=autocast_dtype
  127. ):
  128. x_hat = F.grid_sample(
  129. y, warp.permute(0, 2, 3, 1), align_corners=False, mode=self.sample_mode
  130. )
  131. if self.has_displacement_emb:
  132. im_A_coords = torch.meshgrid(
  133. (
  134. torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
  135. torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
  136. ),
  137. indexing="ij",
  138. )
  139. im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
  140. im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
  141. in_displacement = warp - im_A_coords
  142. emb_in_displacement = self.disp_emb(
  143. 40 / 32 * scale_factor * in_displacement
  144. )
  145. if self.local_corr_radius:
  146. if self.corr_in_other:
  147. # Corr in other means take a kxk grid around the predicted coordinate in other image
  148. local_corr = local_correlation(
  149. x,
  150. y,
  151. self.local_corr_radius,
  152. warp,
  153. sample_mode=self.sample_mode,
  154. use_custom_corr=self.use_custom_corr,
  155. )
  156. else:
  157. raise NotImplementedError(
  158. "Local corr in own frame should not be used."
  159. )
  160. if self.no_im_B_fm:
  161. x_hat = torch.zeros_like(x)
  162. d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
  163. else:
  164. d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
  165. else:
  166. if self.no_im_B_fm:
  167. x_hat = torch.zeros_like(x)
  168. d = torch.cat((x, x_hat), dim=1)
  169. if self.concat_logits:
  170. d = torch.cat((d, logits), dim=1)
  171. # pad d if needed
  172. channel_d = d.shape[1]
  173. channel_block1 = self.block1[0].in_channels
  174. if channel_d != channel_block1:
  175. d = F.pad(d, (0, 0, 0, 0, 0, channel_block1 - channel_d))
  176. d = self.block1(d)
  177. d = self.hidden_blocks(d)
  178. d = self.out_conv(d.float())
  179. displacement, certainty = d[:, :-1], d[:, -1:]
  180. return displacement, certainty
  181. class CosKernel(nn.Module): # similar to softmax kernel
  182. def __init__(self, T, learn_temperature=False):
  183. super().__init__()
  184. self.learn_temperature = learn_temperature
  185. if self.learn_temperature:
  186. self.T = nn.Parameter(torch.tensor(T))
  187. else:
  188. self.T = T
  189. def __call__(self, x, y, eps=1e-6):
  190. c = torch.einsum("bnd,bmd->bnm", x, y) / (
  191. x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
  192. )
  193. if self.learn_temperature:
  194. T = self.T.abs() + 0.01
  195. else:
  196. T = torch.tensor(self.T, device=c.device)
  197. K = ((c - 1.0) / T).exp()
  198. return K
  199. class GP(nn.Module):
  200. def __init__(
  201. self,
  202. kernel,
  203. T=1,
  204. learn_temperature=False,
  205. only_attention=False,
  206. gp_dim=64,
  207. basis="fourier",
  208. covar_size=5,
  209. only_nearest_neighbour=False,
  210. sigma_noise=0.1,
  211. no_cov=False,
  212. predict_features=False,
  213. ):
  214. super().__init__()
  215. self.K = kernel(T=T, learn_temperature=learn_temperature)
  216. self.sigma_noise = sigma_noise
  217. self.covar_size = covar_size
  218. self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
  219. self.only_attention = only_attention
  220. self.only_nearest_neighbour = only_nearest_neighbour
  221. self.basis = basis
  222. self.no_cov = no_cov
  223. self.dim = gp_dim
  224. self.predict_features = predict_features
  225. def get_local_cov(self, cov):
  226. K = self.covar_size
  227. b, h, w, h, w = cov.shape
  228. hw = h * w
  229. cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
  230. delta = torch.stack(
  231. torch.meshgrid(
  232. torch.arange(-(K // 2), K // 2 + 1),
  233. torch.arange(-(K // 2), K // 2 + 1),
  234. indexing="ij",
  235. ),
  236. dim=-1,
  237. )
  238. positions = torch.stack(
  239. torch.meshgrid(
  240. torch.arange(K // 2, h + K // 2),
  241. torch.arange(K // 2, w + K // 2),
  242. indexing="ij",
  243. ),
  244. dim=-1,
  245. )
  246. neighbours = positions[:, :, None, None, :] + delta[None, :, :]
  247. points = torch.arange(hw)[:, None].expand(hw, K**2)
  248. local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
  249. :,
  250. points.flatten(),
  251. neighbours[..., 0].flatten(),
  252. neighbours[..., 1].flatten(),
  253. ].reshape(b, h, w, K**2)
  254. return local_cov
  255. def reshape(self, x):
  256. return rearrange(x, "b d h w -> b (h w) d")
  257. def project_to_basis(self, x):
  258. if self.basis == "fourier":
  259. return torch.cos(8 * math.pi * self.pos_conv(x))
  260. elif self.basis == "linear":
  261. return self.pos_conv(x)
  262. else:
  263. raise ValueError(
  264. "No other bases other than fourier and linear currently im_Bed in public release"
  265. )
  266. def get_pos_enc(self, y):
  267. b, c, h, w = y.shape
  268. coarse_coords = torch.meshgrid(
  269. (
  270. torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
  271. torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
  272. ),
  273. indexing="ij",
  274. )
  275. coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
  276. None
  277. ].expand(b, h, w, 2)
  278. coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
  279. coarse_embedded_coords = self.project_to_basis(coarse_coords)
  280. return coarse_embedded_coords
  281. def forward(self, x, y, **kwargs):
  282. b, c, h1, w1 = x.shape
  283. b, c, h2, w2 = y.shape
  284. f = self.get_pos_enc(y)
  285. b, d, h2, w2 = f.shape
  286. x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
  287. # K_xx = self.K(x, x)
  288. K_yy = self.K(y, y)
  289. K_xy = self.K(x, y)
  290. K_yx = K_xy.permute(0, 2, 1)
  291. sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
  292. if self.training:
  293. K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
  294. mu_x = K_xy.matmul(K_yy_inv.matmul(f))
  295. else:
  296. # faster inference, possibly also useful for training
  297. L_t = torch.linalg.cholesky(K_yy + sigma_noise)
  298. pos_emb = torch.cholesky_solve(f.reshape(b, h2 * w2, d), L_t, upper=False)
  299. mu_x = K_xy @ pos_emb
  300. mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
  301. # if not self.no_cov:
  302. # cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
  303. # cov_x = rearrange(
  304. # cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
  305. # )
  306. # local_cov_x = self.get_local_cov(cov_x)
  307. # local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
  308. # gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
  309. # else:
  310. gp_feats = mu_x
  311. return gp_feats
  312. class Decoder(nn.Module):
  313. def __init__(
  314. self,
  315. embedding_decoder,
  316. gps,
  317. proj,
  318. conv_refiner,
  319. detach=False,
  320. scales="all",
  321. pos_embeddings=None,
  322. num_refinement_steps_per_scale=1,
  323. warp_noise_std=0.0,
  324. displacement_dropout_p=0.0,
  325. gm_warp_dropout_p=0.0,
  326. flow_upsample_mode="bilinear",
  327. amp_dtype=torch.float16,
  328. ):
  329. super().__init__()
  330. self.embedding_decoder = embedding_decoder
  331. self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
  332. self.gps = gps
  333. self.proj = proj
  334. self.conv_refiner = conv_refiner
  335. self.detach = detach
  336. if pos_embeddings is None:
  337. self.pos_embeddings = {}
  338. else:
  339. self.pos_embeddings = pos_embeddings
  340. if scales == "all":
  341. self.scales = ["32", "16", "8", "4", "2", "1"]
  342. else:
  343. self.scales = scales
  344. self.warp_noise_std = warp_noise_std
  345. self.refine_init = 4
  346. self.displacement_dropout_p = displacement_dropout_p
  347. self.gm_warp_dropout_p = gm_warp_dropout_p
  348. self.flow_upsample_mode = flow_upsample_mode
  349. self.amp_dtype = amp_dtype
  350. def get_placeholder_flow(self, b, h, w, device):
  351. coarse_coords = torch.meshgrid(
  352. (
  353. torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
  354. torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
  355. ),
  356. indexing="ij",
  357. )
  358. coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
  359. None
  360. ].expand(b, h, w, 2)
  361. coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
  362. return coarse_coords
  363. def get_positional_embedding(self, b, h, w, device):
  364. coarse_coords = torch.meshgrid(
  365. (
  366. torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
  367. torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
  368. ),
  369. indexing="ij",
  370. )
  371. coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
  372. None
  373. ].expand(b, h, w, 2)
  374. coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
  375. coarse_embedded_coords = self.pos_embedding(coarse_coords)
  376. return coarse_embedded_coords
  377. def forward(
  378. self,
  379. f1,
  380. f2,
  381. gt_warp=None,
  382. gt_prob=None,
  383. upsample=False,
  384. flow=None,
  385. certainty=None,
  386. scale_factor=1,
  387. ):
  388. coarse_scales = self.embedding_decoder.scales()
  389. all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
  390. sizes = {scale: f1[scale].shape[-2:] for scale in f1}
  391. h, w = sizes[1]
  392. b = f1[1].shape[0]
  393. device = f1[1].device
  394. coarsest_scale = int(all_scales[0])
  395. old_stuff = torch.zeros(
  396. b,
  397. self.embedding_decoder.hidden_dim,
  398. *sizes[coarsest_scale],
  399. device=f1[coarsest_scale].device,
  400. )
  401. corresps = {}
  402. if not upsample:
  403. flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
  404. certainty = 0.0
  405. else:
  406. flow = F.interpolate(
  407. flow,
  408. size=sizes[coarsest_scale],
  409. align_corners=False,
  410. mode="bilinear",
  411. )
  412. certainty = F.interpolate(
  413. certainty,
  414. size=sizes[coarsest_scale],
  415. align_corners=False,
  416. mode="bilinear",
  417. )
  418. displacement = 0.0
  419. for new_scale in all_scales:
  420. ins = int(new_scale)
  421. corresps[ins] = {}
  422. f1_s, f2_s = f1[ins], f2[ins]
  423. if new_scale in self.proj:
  424. autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(
  425. f1_s.device, str(f1_s) == "cuda", self.amp_dtype
  426. )
  427. with torch.autocast(
  428. autocast_device, enabled=autocast_enabled, dtype=autocast_dtype
  429. ):
  430. if not autocast_enabled:
  431. f1_s, f2_s = f1_s.to(torch.float32), f2_s.to(torch.float32)
  432. f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
  433. if ins in coarse_scales:
  434. old_stuff = F.interpolate(
  435. old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
  436. )
  437. gp_posterior = self.gps[new_scale](f1_s, f2_s)
  438. gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
  439. gp_posterior, f1_s, old_stuff, new_scale
  440. )
  441. if self.embedding_decoder.is_classifier:
  442. flow = cls_to_flow_refine(
  443. gm_warp_or_cls,
  444. ).permute(0, 3, 1, 2)
  445. corresps[ins].update(
  446. {
  447. "gm_cls": gm_warp_or_cls,
  448. "gm_certainty": certainty,
  449. }
  450. ) if self.training else None
  451. else:
  452. corresps[ins].update(
  453. {
  454. "gm_flow": gm_warp_or_cls,
  455. "gm_certainty": certainty,
  456. }
  457. ) if self.training else None
  458. flow = gm_warp_or_cls.detach()
  459. if new_scale in self.conv_refiner:
  460. corresps[ins].update(
  461. {"flow_pre_delta": flow}
  462. ) if self.training else None
  463. delta_flow, delta_certainty = self.conv_refiner[new_scale](
  464. f1_s,
  465. f2_s,
  466. flow,
  467. scale_factor=scale_factor,
  468. logits=certainty,
  469. )
  470. corresps[ins].update(
  471. {
  472. "delta_flow": delta_flow,
  473. }
  474. ) if self.training else None
  475. displacement = ins * torch.stack(
  476. (
  477. delta_flow[:, 0].float() / (self.refine_init * w),
  478. delta_flow[:, 1].float() / (self.refine_init * h),
  479. ),
  480. dim=1,
  481. )
  482. flow = flow + displacement
  483. certainty = (
  484. certainty + delta_certainty
  485. ) # predict both certainty and displacement
  486. corresps[ins].update(
  487. {
  488. "certainty": certainty,
  489. "flow": flow,
  490. }
  491. )
  492. if new_scale != "1":
  493. flow = F.interpolate(
  494. flow,
  495. size=sizes[ins // 2],
  496. mode=self.flow_upsample_mode,
  497. )
  498. certainty = F.interpolate(
  499. certainty,
  500. size=sizes[ins // 2],
  501. mode=self.flow_upsample_mode,
  502. )
  503. if self.detach:
  504. flow = flow.detach()
  505. certainty = certainty.detach()
  506. return corresps
  507. def _check_input(im_input):
  508. if isinstance(im_input, (str, os.PathLike)):
  509. im = Image.open(im_input)
  510. check_not_i16(im)
  511. im = im.convert("RGB")
  512. elif isinstance(im_input, Image.Image):
  513. check_rgb(im_input)
  514. im = im_input
  515. else:
  516. assert isinstance(im_input, torch.Tensor), (
  517. "im_input must be a string, path, or PIL image"
  518. )
  519. B, C, H, W = im_input.shape
  520. assert C == 3, "im_input must be a RGB image"
  521. assert H % 14 == 0, "im_input must be a multiple of 14"
  522. assert W % 14 == 0, "im_input must be a multiple of 14"
  523. im = im_input
  524. return im
  525. class RegressionMatcher(nn.Module):
  526. def __init__(
  527. self,
  528. encoder: CNNandDinov2,
  529. decoder: Decoder,
  530. h=448,
  531. w=448,
  532. sample_mode="threshold_balanced",
  533. upsample_preds=False,
  534. symmetric=False,
  535. sample_thresh=0.05,
  536. name=None,
  537. attenuate_cert=None,
  538. upsample_res=None,
  539. ):
  540. super().__init__()
  541. self.attenuate_cert = attenuate_cert
  542. self.encoder = encoder
  543. self.decoder = decoder
  544. self.name = name
  545. self.w_resized = w
  546. self.h_resized = h
  547. self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
  548. self.sample_mode = sample_mode
  549. self.upsample_preds = upsample_preds
  550. self.upsample_res = upsample_res or (14 * 16 * 6, 14 * 16 * 6)
  551. self.symmetric = symmetric
  552. self.sample_thresh = sample_thresh
  553. def get_output_resolution(self):
  554. if not self.upsample_preds:
  555. return self.h_resized, self.w_resized
  556. else:
  557. return self.upsample_res
  558. def extract_backbone_features(self, batch, batched=True, upsample=False):
  559. if 'unique_images' in batch:
  560. unique_images = batch['unique_images']
  561. im_AB_idx = batch['im_AB_idx']
  562. feature_pyramid0 = self.encoder(unique_images, upsample=upsample)
  563. feature_pyramid = {
  564. scale: feature_pyramid0[scale][im_AB_idx]
  565. for scale in feature_pyramid0
  566. }
  567. return feature_pyramid
  568. x_q = batch["im_A"]
  569. x_s = batch["im_B"]
  570. if batched:
  571. X = torch.cat((x_q, x_s), dim=0)
  572. feature_pyramid = self.encoder(X, upsample=upsample)
  573. else:
  574. feature_pyramid = (
  575. self.encoder(x_q, upsample=upsample),
  576. self.encoder(x_s, upsample=upsample),
  577. )
  578. return feature_pyramid
  579. def sample(
  580. self,
  581. matches,
  582. certainty,
  583. num=10000,
  584. ):
  585. if "threshold" in self.sample_mode:
  586. upper_thresh = self.sample_thresh
  587. certainty = certainty.clone()
  588. certainty[certainty > upper_thresh] = 1
  589. matches, certainty = (
  590. matches.reshape(-1, 4),
  591. certainty.reshape(-1),
  592. )
  593. expansion_factor = 4 if "balanced" in self.sample_mode else 1
  594. good_samples = torch.multinomial(
  595. certainty,
  596. num_samples=min(expansion_factor * num, len(certainty)),
  597. replacement=False,
  598. )
  599. good_matches, good_certainty = matches[good_samples], certainty[good_samples]
  600. if "balanced" not in self.sample_mode:
  601. return good_matches, good_certainty
  602. density = kde(good_matches, std=0.1)
  603. p = 1 / (density + 1)
  604. p[density < 10] = (
  605. 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
  606. )
  607. balanced_samples = torch.multinomial(
  608. p, num_samples=min(num, len(good_certainty)), replacement=False
  609. )
  610. return good_matches[balanced_samples], good_certainty[balanced_samples]
  611. def forward(self, batch, batched=True, upsample=False, scale_factor=1):
  612. feature_pyramid = self.extract_backbone_features(
  613. batch, batched=batched, upsample=upsample
  614. )
  615. if batched:
  616. f_q_pyramid = {
  617. scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
  618. }
  619. f_s_pyramid = {
  620. scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
  621. }
  622. else:
  623. f_q_pyramid, f_s_pyramid = feature_pyramid
  624. corresps = self.decoder(
  625. f_q_pyramid,
  626. f_s_pyramid,
  627. upsample=upsample,
  628. **(batch["corresps"] if "corresps" in batch else {}),
  629. scale_factor=scale_factor,
  630. )
  631. return corresps
  632. def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
  633. feature_pyramid = self.extract_backbone_features(
  634. batch, batched=batched, upsample=upsample
  635. )
  636. f_q_pyramid = feature_pyramid
  637. f_s_pyramid = {
  638. scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
  639. for scale, f_scale in feature_pyramid.items()
  640. }
  641. corresps = self.decoder(
  642. f_q_pyramid,
  643. f_s_pyramid,
  644. upsample=upsample,
  645. **(batch["corresps"] if "corresps" in batch else {}),
  646. scale_factor=scale_factor,
  647. )
  648. return corresps
  649. def conf_from_fb_consistency(self, flow_forward, flow_backward, th=2):
  650. # assumes that flow forward is of shape (..., H, W, 2)
  651. has_batch = False
  652. if len(flow_forward.shape) == 3:
  653. flow_forward, flow_backward = flow_forward[None], flow_backward[None]
  654. else:
  655. has_batch = True
  656. H, W = flow_forward.shape[-3:-1]
  657. th_n = 2 * th / max(H, W)
  658. coords = torch.stack(
  659. torch.meshgrid(
  660. torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
  661. torch.linspace(-1 + 1 / H, 1 - 1 / H, H),
  662. indexing="xy",
  663. ),
  664. dim=-1,
  665. ).to(flow_forward.device)
  666. coords_fb = F.grid_sample(
  667. flow_backward.permute(0, 3, 1, 2),
  668. flow_forward,
  669. align_corners=False,
  670. mode="bilinear",
  671. ).permute(0, 2, 3, 1)
  672. diff = (coords - coords_fb).norm(dim=-1)
  673. in_th = (diff < th_n).float()
  674. if not has_batch:
  675. in_th = in_th[0]
  676. return in_th
  677. def to_pixel_coordinates(self, coords, H_A, W_A, H_B=None, W_B=None):
  678. if coords.shape[-1] == 2:
  679. return self._to_pixel_coordinates(coords, H_A, W_A)
  680. if isinstance(coords, (list, tuple)):
  681. kpts_A, kpts_B = coords[0], coords[1]
  682. else:
  683. kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
  684. return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(
  685. kpts_B, H_B, W_B
  686. )
  687. def _to_pixel_coordinates(self, coords, H, W):
  688. kpts = torch.stack(
  689. (W / 2 * (coords[..., 0] + 1), H / 2 * (coords[..., 1] + 1)), axis=-1
  690. )
  691. return kpts
  692. def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
  693. if isinstance(coords, (list, tuple)):
  694. kpts_A, kpts_B = coords[0], coords[1]
  695. else:
  696. kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
  697. kpts_A = torch.stack(
  698. (2 / W_A * kpts_A[..., 0] - 1, 2 / H_A * kpts_A[..., 1] - 1), axis=-1
  699. )
  700. kpts_B = torch.stack(
  701. (2 / W_B * kpts_B[..., 0] - 1, 2 / H_B * kpts_B[..., 1] - 1), axis=-1
  702. )
  703. return kpts_A, kpts_B
  704. def match_keypoints(
  705. self,
  706. x_A,
  707. x_B,
  708. warp,
  709. certainty,
  710. return_tuple=True,
  711. return_inds=False,
  712. max_dist=0.005,
  713. cert_th=0,
  714. ):
  715. x_A_to_B = F.grid_sample(
  716. warp[..., -2:].permute(2, 0, 1)[None],
  717. x_A[None, None],
  718. align_corners=False,
  719. mode="bilinear",
  720. )[0, :, 0].mT
  721. cert_A_to_B = F.grid_sample(
  722. certainty[None, None, ...],
  723. x_A[None, None],
  724. align_corners=False,
  725. mode="bilinear",
  726. )[0, 0, 0]
  727. D = torch.cdist(x_A_to_B, x_B)
  728. inds_A, inds_B = torch.nonzero(
  729. (D == D.min(dim=-1, keepdim=True).values)
  730. * (D == D.min(dim=-2, keepdim=True).values)
  731. * (cert_A_to_B[:, None] > cert_th)
  732. * (D < max_dist),
  733. as_tuple=True,
  734. )
  735. if return_tuple:
  736. if return_inds:
  737. return inds_A, inds_B
  738. else:
  739. return x_A[inds_A], x_B[inds_B]
  740. else:
  741. if return_inds:
  742. return torch.cat((inds_A, inds_B), dim=-1)
  743. else:
  744. return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
  745. def _get_device(self):
  746. # let's hope this is same for all weights
  747. return self.encoder.cnn.layers[0].weight.device
  748. @torch.inference_mode()
  749. def match(
  750. self,
  751. im_A_input,
  752. im_B_input,
  753. *args,
  754. im_A_high_res=None,
  755. im_B_high_res=None,
  756. batched=True,
  757. device=None,
  758. ):
  759. self.train(False)
  760. if not batched:
  761. raise ValueError("batched must be True, non-batched inference is no longer supported.")
  762. if device is None and not isinstance(im_A_input, torch.Tensor):
  763. device = self._get_device()
  764. elif device is None and isinstance(im_A_input, torch.Tensor):
  765. device = im_A_input.device
  766. # Check if inputs are file paths or already loaded images
  767. im_A = _check_input(im_A_input)
  768. im_B = _check_input(im_B_input)
  769. symmetric = self.symmetric
  770. ws = self.w_resized
  771. hs = self.h_resized
  772. scale_factor = math.sqrt(hs * ws / (560**2)) # divide by training resolution
  773. if isinstance(im_A, Image.Image) and isinstance(im_B, Image.Image):
  774. b = 1
  775. w, h = im_A.size
  776. w2, h2 = im_B.size
  777. # Get images in good format
  778. test_transform = get_tuple_transform_ops(
  779. resize=(hs, ws), normalize=True, clahe=False
  780. )
  781. im_A, im_B = test_transform((im_A, im_B))
  782. batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
  783. elif isinstance(im_A, torch.Tensor) and isinstance(im_B, torch.Tensor):
  784. b, c, h, w = im_A.shape
  785. b, c, h2, w2 = im_B.shape
  786. assert w == w2 and h == h2, "For batched images we assume same size"
  787. batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
  788. if h != self.h_resized or self.w_resized != w:
  789. warn(
  790. "Model resolution and batch resolution differ, may produce unexpected results"
  791. )
  792. hs, ws = h, w
  793. else:
  794. raise ValueError(f"Unsupported input type: {type(im_A)=} and {type(im_B)=}")
  795. finest_scale = 1
  796. # Run matcher
  797. if symmetric:
  798. corresps = self.forward_symmetric(batch, scale_factor=scale_factor)
  799. else:
  800. corresps = self(batch, batched=True, scale_factor=scale_factor)
  801. if self.upsample_preds:
  802. hs, ws = self.upsample_res
  803. if self.attenuate_cert:
  804. low_res_certainty = F.interpolate(
  805. corresps[16]["certainty"],
  806. size=(hs, ws),
  807. align_corners=False,
  808. mode="bilinear",
  809. )
  810. cert_clamp = 0
  811. factor = 0.5
  812. low_res_certainty = (
  813. factor * low_res_certainty * (low_res_certainty < cert_clamp)
  814. )
  815. finest_corresps = corresps[finest_scale]
  816. if self.upsample_preds and im_A_high_res is None and im_B_high_res is None:
  817. torch.cuda.empty_cache()
  818. test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
  819. if isinstance(im_A_input, (str, os.PathLike)):
  820. im_A, im_B = test_transform(
  821. (
  822. Image.open(im_A_input).convert("RGB"),
  823. Image.open(im_B_input).convert("RGB"),
  824. )
  825. )
  826. else:
  827. assert isinstance(im_A_input, Image.Image), f"Unsupported input type: {type(im_A_input)=}"
  828. assert isinstance(im_B_input, Image.Image), f"Unsupported input type: {type(im_B_input)=}"
  829. im_A, im_B = test_transform((im_A_input, im_B_input))
  830. im_A, im_B = im_A[None].to(device), im_B[None].to(device)
  831. batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
  832. elif self.upsample_preds and im_A_high_res is not None and im_B_high_res is not None:
  833. batch = {"im_A": im_A_high_res, "im_B": im_B_high_res, "corresps": finest_corresps}
  834. elif self.upsample_preds:
  835. raise ValueError(f"Invalid upsample_preds and high_res inputs with {im_A=},{im_A_high_res=},{im_B=} and {im_B_high_res=}")
  836. if self.upsample_preds:
  837. scale_factor = math.sqrt(
  838. self.upsample_res[0]
  839. * self.upsample_res[1]
  840. / (560**2) # divide by training resolution
  841. )
  842. if symmetric:
  843. corresps = self.forward_symmetric(
  844. batch, upsample=True, batched=True, scale_factor=scale_factor
  845. )
  846. else:
  847. corresps = self(
  848. batch, batched=True, upsample=True, scale_factor=scale_factor
  849. )
  850. im_A_to_im_B = corresps[finest_scale]["flow"]
  851. certainty = corresps[finest_scale]["certainty"] - (
  852. low_res_certainty if self.attenuate_cert else 0
  853. )
  854. if finest_scale != 1:
  855. im_A_to_im_B = F.interpolate(
  856. im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
  857. )
  858. certainty = F.interpolate(
  859. certainty, size=(hs, ws), align_corners=False, mode="bilinear"
  860. )
  861. im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
  862. # Create im_A meshgrid
  863. im_A_coords = torch.meshgrid(
  864. (
  865. torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
  866. torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
  867. ),
  868. indexing="ij",
  869. )
  870. im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
  871. im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
  872. certainty = certainty.sigmoid() # logits -> probs
  873. im_A_coords = im_A_coords.permute(0, 2, 3, 1)
  874. if (im_A_to_im_B.abs() > 1).any() and True:
  875. wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
  876. certainty[wrong[:, None]] = 0
  877. im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
  878. if symmetric:
  879. A_to_B, B_to_A = im_A_to_im_B.chunk(2)
  880. q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
  881. im_B_coords = im_A_coords
  882. s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
  883. warp = torch.cat((q_warp, s_warp), dim=2)
  884. certainty = torch.cat(certainty.chunk(2), dim=3)
  885. else:
  886. warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
  887. if batched:
  888. return (warp, certainty[:, 0])
  889. else:
  890. return (
  891. warp[0],
  892. certainty[0, 0],
  893. )
  894. def visualize_warp(
  895. self,
  896. warp,
  897. certainty,
  898. im_A=None,
  899. im_B=None,
  900. im_A_path=None,
  901. im_B_path=None,
  902. device="cuda",
  903. symmetric=True,
  904. save_path=None,
  905. unnormalize=False,
  906. ):
  907. # assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
  908. H, W2, _ = warp.shape
  909. W = W2 // 2 if symmetric else W2
  910. if im_A is None:
  911. from PIL import Image
  912. im_A, im_B = (
  913. Image.open(im_A_path).convert("RGB"),
  914. Image.open(im_B_path).convert("RGB"),
  915. )
  916. if not isinstance(im_A, torch.Tensor):
  917. im_A = im_A.resize((W, H))
  918. im_B = im_B.resize((W, H))
  919. x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
  920. if symmetric:
  921. x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
  922. else:
  923. if symmetric:
  924. x_A = im_A
  925. x_B = im_B
  926. im_A_transfer_rgb = F.grid_sample(
  927. x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
  928. )[0]
  929. if symmetric:
  930. im_B_transfer_rgb = F.grid_sample(
  931. x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
  932. )[0]
  933. warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
  934. white_im = torch.ones((H, 2 * W), device=device)
  935. else:
  936. warp_im = im_A_transfer_rgb
  937. white_im = torch.ones((H, W), device=device)
  938. vis_im = certainty * warp_im + (1 - certainty) * white_im
  939. if save_path is not None:
  940. from romatch.utils import tensor_to_pil
  941. tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
  942. return vis_im