matcher.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772
  1. import os
  2. import math
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from einops import rearrange
  8. import warnings
  9. from warnings import warn
  10. from PIL import Image
  11. import romatch
  12. from romatch.utils import get_tuple_transform_ops
  13. from romatch.utils.local_correlation import local_correlation
  14. from romatch.utils.utils import cls_to_flow_refine
  15. from romatch.utils.kde import kde
  16. from typing import Union
  17. class ConvRefiner(nn.Module):
  18. def __init__(
  19. self,
  20. in_dim=6,
  21. hidden_dim=16,
  22. out_dim=2,
  23. dw=False,
  24. kernel_size=5,
  25. hidden_blocks=3,
  26. displacement_emb = None,
  27. displacement_emb_dim = None,
  28. local_corr_radius = None,
  29. corr_in_other = None,
  30. no_im_B_fm = False,
  31. amp = False,
  32. concat_logits = False,
  33. use_bias_block_1 = True,
  34. use_cosine_corr = False,
  35. disable_local_corr_grad = False,
  36. is_classifier = False,
  37. sample_mode = "bilinear",
  38. norm_type = nn.BatchNorm2d,
  39. bn_momentum = 0.1,
  40. amp_dtype = torch.float16,
  41. ):
  42. super().__init__()
  43. self.bn_momentum = bn_momentum
  44. self.block1 = self.create_block(
  45. in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
  46. )
  47. self.hidden_blocks = nn.Sequential(
  48. *[
  49. self.create_block(
  50. hidden_dim,
  51. hidden_dim,
  52. dw=dw,
  53. kernel_size=kernel_size,
  54. norm_type=norm_type,
  55. )
  56. for hb in range(hidden_blocks)
  57. ]
  58. )
  59. self.hidden_blocks = self.hidden_blocks
  60. self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
  61. if displacement_emb:
  62. self.has_displacement_emb = True
  63. self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
  64. else:
  65. self.has_displacement_emb = False
  66. self.local_corr_radius = local_corr_radius
  67. self.corr_in_other = corr_in_other
  68. self.no_im_B_fm = no_im_B_fm
  69. self.amp = amp
  70. self.concat_logits = concat_logits
  71. self.use_cosine_corr = use_cosine_corr
  72. self.disable_local_corr_grad = disable_local_corr_grad
  73. self.is_classifier = is_classifier
  74. self.sample_mode = sample_mode
  75. self.amp_dtype = amp_dtype
  76. def create_block(
  77. self,
  78. in_dim,
  79. out_dim,
  80. dw=False,
  81. kernel_size=5,
  82. bias = True,
  83. norm_type = nn.BatchNorm2d,
  84. ):
  85. num_groups = 1 if not dw else in_dim
  86. if dw:
  87. assert (
  88. out_dim % in_dim == 0
  89. ), "outdim must be divisible by indim for depthwise"
  90. conv1 = nn.Conv2d(
  91. in_dim,
  92. out_dim,
  93. kernel_size=kernel_size,
  94. stride=1,
  95. padding=kernel_size // 2,
  96. groups=num_groups,
  97. bias=bias,
  98. )
  99. norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
  100. relu = nn.ReLU(inplace=True)
  101. conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
  102. return nn.Sequential(conv1, norm, relu, conv2)
  103. def forward(self, x, y, flow, scale_factor = 1, logits = None):
  104. b,c,hs,ws = x.shape
  105. with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
  106. with torch.no_grad():
  107. x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
  108. if self.has_displacement_emb:
  109. im_A_coords = torch.meshgrid(
  110. (
  111. torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
  112. torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
  113. )
  114. )
  115. im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
  116. im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
  117. in_displacement = flow-im_A_coords
  118. emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
  119. if self.local_corr_radius:
  120. if self.corr_in_other:
  121. # Corr in other means take a kxk grid around the predicted coordinate in other image
  122. local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
  123. sample_mode = self.sample_mode)
  124. else:
  125. raise NotImplementedError("Local corr in own frame should not be used.")
  126. if self.no_im_B_fm:
  127. x_hat = torch.zeros_like(x)
  128. d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
  129. else:
  130. d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
  131. else:
  132. if self.no_im_B_fm:
  133. x_hat = torch.zeros_like(x)
  134. d = torch.cat((x, x_hat), dim=1)
  135. if self.concat_logits:
  136. d = torch.cat((d, logits), dim=1)
  137. d = self.block1(d)
  138. d = self.hidden_blocks(d)
  139. d = self.out_conv(d.float())
  140. displacement, certainty = d[:, :-1], d[:, -1:]
  141. return displacement, certainty
  142. class CosKernel(nn.Module): # similar to softmax kernel
  143. def __init__(self, T, learn_temperature=False):
  144. super().__init__()
  145. self.learn_temperature = learn_temperature
  146. if self.learn_temperature:
  147. self.T = nn.Parameter(torch.tensor(T))
  148. else:
  149. self.T = T
  150. def __call__(self, x, y, eps=1e-6):
  151. c = torch.einsum("bnd,bmd->bnm", x, y) / (
  152. x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
  153. )
  154. if self.learn_temperature:
  155. T = self.T.abs() + 0.01
  156. else:
  157. T = torch.tensor(self.T, device=c.device)
  158. K = ((c - 1.0) / T).exp()
  159. return K
  160. class GP(nn.Module):
  161. def __init__(
  162. self,
  163. kernel,
  164. T=1,
  165. learn_temperature=False,
  166. only_attention=False,
  167. gp_dim=64,
  168. basis="fourier",
  169. covar_size=5,
  170. only_nearest_neighbour=False,
  171. sigma_noise=0.1,
  172. no_cov=False,
  173. predict_features = False,
  174. ):
  175. super().__init__()
  176. self.K = kernel(T=T, learn_temperature=learn_temperature)
  177. self.sigma_noise = sigma_noise
  178. self.covar_size = covar_size
  179. self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
  180. self.only_attention = only_attention
  181. self.only_nearest_neighbour = only_nearest_neighbour
  182. self.basis = basis
  183. self.no_cov = no_cov
  184. self.dim = gp_dim
  185. self.predict_features = predict_features
  186. def get_local_cov(self, cov):
  187. K = self.covar_size
  188. b, h, w, h, w = cov.shape
  189. hw = h * w
  190. cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
  191. delta = torch.stack(
  192. torch.meshgrid(
  193. torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
  194. ),
  195. dim=-1,
  196. )
  197. positions = torch.stack(
  198. torch.meshgrid(
  199. torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
  200. ),
  201. dim=-1,
  202. )
  203. neighbours = positions[:, :, None, None, :] + delta[None, :, :]
  204. points = torch.arange(hw)[:, None].expand(hw, K**2)
  205. local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
  206. :,
  207. points.flatten(),
  208. neighbours[..., 0].flatten(),
  209. neighbours[..., 1].flatten(),
  210. ].reshape(b, h, w, K**2)
  211. return local_cov
  212. def reshape(self, x):
  213. return rearrange(x, "b d h w -> b (h w) d")
  214. def project_to_basis(self, x):
  215. if self.basis == "fourier":
  216. return torch.cos(8 * math.pi * self.pos_conv(x))
  217. elif self.basis == "linear":
  218. return self.pos_conv(x)
  219. else:
  220. raise ValueError(
  221. "No other bases other than fourier and linear currently im_Bed in public release"
  222. )
  223. def get_pos_enc(self, y):
  224. b, c, h, w = y.shape
  225. coarse_coords = torch.meshgrid(
  226. (
  227. torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
  228. torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
  229. )
  230. )
  231. coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
  232. None
  233. ].expand(b, h, w, 2)
  234. coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
  235. coarse_embedded_coords = self.project_to_basis(coarse_coords)
  236. return coarse_embedded_coords
  237. def forward(self, x, y, **kwargs):
  238. b, c, h1, w1 = x.shape
  239. b, c, h2, w2 = y.shape
  240. f = self.get_pos_enc(y)
  241. b, d, h2, w2 = f.shape
  242. x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
  243. K_xx = self.K(x, x)
  244. K_yy = self.K(y, y)
  245. K_xy = self.K(x, y)
  246. K_yx = K_xy.permute(0, 2, 1)
  247. sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
  248. with warnings.catch_warnings():
  249. K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
  250. mu_x = K_xy.matmul(K_yy_inv.matmul(f))
  251. mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
  252. if not self.no_cov:
  253. cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
  254. cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
  255. local_cov_x = self.get_local_cov(cov_x)
  256. local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
  257. gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
  258. else:
  259. gp_feats = mu_x
  260. return gp_feats
  261. class Decoder(nn.Module):
  262. def __init__(
  263. self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
  264. num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
  265. flow_upsample_mode = "bilinear", amp_dtype = torch.float16,
  266. ):
  267. super().__init__()
  268. self.embedding_decoder = embedding_decoder
  269. self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
  270. self.gps = gps
  271. self.proj = proj
  272. self.conv_refiner = conv_refiner
  273. self.detach = detach
  274. if pos_embeddings is None:
  275. self.pos_embeddings = {}
  276. else:
  277. self.pos_embeddings = pos_embeddings
  278. if scales == "all":
  279. self.scales = ["32", "16", "8", "4", "2", "1"]
  280. else:
  281. self.scales = scales
  282. self.warp_noise_std = warp_noise_std
  283. self.refine_init = 4
  284. self.displacement_dropout_p = displacement_dropout_p
  285. self.gm_warp_dropout_p = gm_warp_dropout_p
  286. self.flow_upsample_mode = flow_upsample_mode
  287. self.amp_dtype = amp_dtype
  288. def get_placeholder_flow(self, b, h, w, device):
  289. coarse_coords = torch.meshgrid(
  290. (
  291. torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
  292. torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
  293. )
  294. )
  295. coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
  296. None
  297. ].expand(b, h, w, 2)
  298. coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
  299. return coarse_coords
  300. def get_positional_embedding(self, b, h ,w, device):
  301. coarse_coords = torch.meshgrid(
  302. (
  303. torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
  304. torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
  305. )
  306. )
  307. coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
  308. None
  309. ].expand(b, h, w, 2)
  310. coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
  311. coarse_embedded_coords = self.pos_embedding(coarse_coords)
  312. return coarse_embedded_coords
  313. def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
  314. coarse_scales = self.embedding_decoder.scales()
  315. all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
  316. sizes = {scale: f1[scale].shape[-2:] for scale in f1}
  317. h, w = sizes[1]
  318. b = f1[1].shape[0]
  319. device = f1[1].device
  320. coarsest_scale = int(all_scales[0])
  321. old_stuff = torch.zeros(
  322. b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
  323. )
  324. corresps = {}
  325. if not upsample:
  326. flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
  327. certainty = 0.0
  328. else:
  329. flow = F.interpolate(
  330. flow,
  331. size=sizes[coarsest_scale],
  332. align_corners=False,
  333. mode="bilinear",
  334. )
  335. certainty = F.interpolate(
  336. certainty,
  337. size=sizes[coarsest_scale],
  338. align_corners=False,
  339. mode="bilinear",
  340. )
  341. displacement = 0.0
  342. for new_scale in all_scales:
  343. ins = int(new_scale)
  344. corresps[ins] = {}
  345. f1_s, f2_s = f1[ins], f2[ins]
  346. if new_scale in self.proj:
  347. with torch.autocast("cuda", dtype = self.amp_dtype):
  348. f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
  349. if ins in coarse_scales:
  350. old_stuff = F.interpolate(
  351. old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
  352. )
  353. gp_posterior = self.gps[new_scale](f1_s, f2_s)
  354. gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
  355. gp_posterior, f1_s, old_stuff, new_scale
  356. )
  357. if self.embedding_decoder.is_classifier:
  358. flow = cls_to_flow_refine(
  359. gm_warp_or_cls,
  360. ).permute(0,3,1,2)
  361. corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
  362. else:
  363. corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
  364. flow = gm_warp_or_cls.detach()
  365. if new_scale in self.conv_refiner:
  366. corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
  367. delta_flow, delta_certainty = self.conv_refiner[new_scale](
  368. f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
  369. )
  370. corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
  371. displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
  372. delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
  373. flow = flow + displacement
  374. certainty = (
  375. certainty + delta_certainty
  376. ) # predict both certainty and displacement
  377. corresps[ins].update({
  378. "certainty": certainty,
  379. "flow": flow,
  380. })
  381. if new_scale != "1":
  382. flow = F.interpolate(
  383. flow,
  384. size=sizes[ins // 2],
  385. mode=self.flow_upsample_mode,
  386. )
  387. certainty = F.interpolate(
  388. certainty,
  389. size=sizes[ins // 2],
  390. mode=self.flow_upsample_mode,
  391. )
  392. if self.detach:
  393. flow = flow.detach()
  394. certainty = certainty.detach()
  395. #torch.cuda.empty_cache()
  396. return corresps
  397. class RegressionMatcher(nn.Module):
  398. def __init__(
  399. self,
  400. encoder,
  401. decoder,
  402. h=448,
  403. w=448,
  404. sample_mode = "threshold_balanced",
  405. upsample_preds = False,
  406. symmetric = False,
  407. name = None,
  408. attenuate_cert = None,
  409. recrop_upsample = False,
  410. ):
  411. super().__init__()
  412. self.attenuate_cert = attenuate_cert
  413. self.encoder = encoder
  414. self.decoder = decoder
  415. self.name = name
  416. self.w_resized = w
  417. self.h_resized = h
  418. self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
  419. self.sample_mode = sample_mode
  420. self.upsample_preds = upsample_preds
  421. self.upsample_res = (14*16*6, 14*16*6)
  422. self.symmetric = symmetric
  423. self.sample_thresh = 0.05
  424. self.recrop_upsample = recrop_upsample
  425. def get_output_resolution(self):
  426. if not self.upsample_preds:
  427. return self.h_resized, self.w_resized
  428. else:
  429. return self.upsample_res
  430. def extract_backbone_features(self, batch, batched = True, upsample = False):
  431. x_q = batch["im_A"]
  432. x_s = batch["im_B"]
  433. if batched:
  434. X = torch.cat((x_q, x_s), dim = 0)
  435. feature_pyramid = self.encoder(X, upsample = upsample)
  436. else:
  437. feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
  438. return feature_pyramid
  439. def sample(
  440. self,
  441. matches,
  442. certainty,
  443. num=10000,
  444. ):
  445. if "threshold" in self.sample_mode:
  446. upper_thresh = self.sample_thresh
  447. certainty = certainty.clone()
  448. certainty[certainty > upper_thresh] = 1
  449. matches, certainty = (
  450. matches.reshape(-1, 4),
  451. certainty.reshape(-1),
  452. )
  453. expansion_factor = 4 if "balanced" in self.sample_mode else 1
  454. good_samples = torch.multinomial(certainty,
  455. num_samples = min(expansion_factor*num, len(certainty)),
  456. replacement=False)
  457. good_matches, good_certainty = matches[good_samples], certainty[good_samples]
  458. if "balanced" not in self.sample_mode:
  459. return good_matches, good_certainty
  460. density = kde(good_matches, std=0.1)
  461. p = 1 / (density+1)
  462. p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
  463. balanced_samples = torch.multinomial(p,
  464. num_samples = min(num,len(good_certainty)),
  465. replacement=False)
  466. return good_matches[balanced_samples], good_certainty[balanced_samples]
  467. def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
  468. feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
  469. if batched:
  470. f_q_pyramid = {
  471. scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
  472. }
  473. f_s_pyramid = {
  474. scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
  475. }
  476. else:
  477. f_q_pyramid, f_s_pyramid = feature_pyramid
  478. corresps = self.decoder(f_q_pyramid,
  479. f_s_pyramid,
  480. upsample = upsample,
  481. **(batch["corresps"] if "corresps" in batch else {}),
  482. scale_factor=scale_factor)
  483. return corresps
  484. def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
  485. feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
  486. f_q_pyramid = feature_pyramid
  487. f_s_pyramid = {
  488. scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
  489. for scale, f_scale in feature_pyramid.items()
  490. }
  491. corresps = self.decoder(f_q_pyramid,
  492. f_s_pyramid,
  493. upsample = upsample,
  494. **(batch["corresps"] if "corresps" in batch else {}),
  495. scale_factor=scale_factor)
  496. return corresps
  497. def conf_from_fb_consistency(self, flow_forward, flow_backward, th = 2):
  498. # assumes that flow forward is of shape (..., H, W, 2)
  499. has_batch = False
  500. if len(flow_forward.shape) == 3:
  501. flow_forward, flow_backward = flow_forward[None], flow_backward[None]
  502. else:
  503. has_batch = True
  504. H,W = flow_forward.shape[-3:-1]
  505. th_n = 2 * th / max(H,W)
  506. coords = torch.stack(torch.meshgrid(
  507. torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
  508. torch.linspace(-1 + 1 / H, 1 - 1 / H, H), indexing = "xy"),
  509. dim = -1).to(flow_forward.device)
  510. coords_fb = F.grid_sample(
  511. flow_backward.permute(0, 3, 1, 2),
  512. flow_forward,
  513. align_corners=False, mode="bilinear").permute(0, 2, 3, 1)
  514. diff = (coords - coords_fb).norm(dim=-1)
  515. in_th = (diff < th_n).float()
  516. if not has_batch:
  517. in_th = in_th[0]
  518. return in_th
  519. def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
  520. if coords.shape[-1] == 2:
  521. return self._to_pixel_coordinates(coords, H_A, W_A)
  522. if isinstance(coords, (list, tuple)):
  523. kpts_A, kpts_B = coords[0], coords[1]
  524. else:
  525. kpts_A, kpts_B = coords[...,:2], coords[...,2:]
  526. return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
  527. def _to_pixel_coordinates(self, coords, H, W):
  528. kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
  529. return kpts
  530. def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
  531. if isinstance(coords, (list, tuple)):
  532. kpts_A, kpts_B = coords[0], coords[1]
  533. else:
  534. kpts_A, kpts_B = coords[...,:2], coords[...,2:]
  535. kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
  536. kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
  537. return kpts_A, kpts_B
  538. def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
  539. x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
  540. cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
  541. D = torch.cdist(x_A_to_B, x_B)
  542. inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
  543. if return_tuple:
  544. if return_inds:
  545. return inds_A, inds_B
  546. else:
  547. return x_A[inds_A], x_B[inds_B]
  548. else:
  549. if return_inds:
  550. return torch.cat((inds_A, inds_B),dim=-1)
  551. else:
  552. return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
  553. def get_roi(self, certainty, W, H, thr = 0.025):
  554. raise NotImplementedError("WIP, disable for now")
  555. hs,ws = certainty.shape
  556. certainty = certainty/certainty.sum(dim=(-1,-2))
  557. cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
  558. cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
  559. print(cum_certainty_w)
  560. print(torch.min(torch.nonzero(cum_certainty_w > thr)))
  561. print(torch.min(torch.nonzero(cum_certainty_w < thr)))
  562. left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
  563. right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
  564. top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
  565. bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
  566. print(left, right, top, bottom)
  567. return left, top, right, bottom
  568. def recrop(self, certainty, image_path):
  569. roi = self.get_roi(certainty, *Image.open(image_path).size)
  570. return Image.open(image_path).convert("RGB").crop(roi)
  571. @torch.inference_mode()
  572. def match(
  573. self,
  574. im_A_path: Union[str, os.PathLike, Image.Image],
  575. im_B_path: Union[str, os.PathLike, Image.Image],
  576. *args,
  577. batched=False,
  578. device = None,
  579. ):
  580. if device is None:
  581. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  582. if isinstance(im_A_path, (str, os.PathLike)):
  583. im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
  584. else:
  585. im_A, im_B = im_A_path, im_B_path
  586. symmetric = self.symmetric
  587. self.train(False)
  588. with torch.no_grad():
  589. if not batched:
  590. b = 1
  591. w, h = im_A.size
  592. w2, h2 = im_B.size
  593. # Get images in good format
  594. ws = self.w_resized
  595. hs = self.h_resized
  596. test_transform = get_tuple_transform_ops(
  597. resize=(hs, ws), normalize=True, clahe = False
  598. )
  599. im_A, im_B = test_transform((im_A, im_B))
  600. batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
  601. else:
  602. b, c, h, w = im_A.shape
  603. b, c, h2, w2 = im_B.shape
  604. assert w == w2 and h == h2, "For batched images we assume same size"
  605. batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
  606. if h != self.h_resized or self.w_resized != w:
  607. warn("Model resolution and batch resolution differ, may produce unexpected results")
  608. hs, ws = h, w
  609. finest_scale = 1
  610. # Run matcher
  611. if symmetric:
  612. corresps = self.forward_symmetric(batch)
  613. else:
  614. corresps = self.forward(batch, batched = True)
  615. if self.upsample_preds:
  616. hs, ws = self.upsample_res
  617. if self.attenuate_cert:
  618. low_res_certainty = F.interpolate(
  619. corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
  620. )
  621. cert_clamp = 0
  622. factor = 0.5
  623. low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
  624. if self.upsample_preds:
  625. finest_corresps = corresps[finest_scale]
  626. torch.cuda.empty_cache()
  627. test_transform = get_tuple_transform_ops(
  628. resize=(hs, ws), normalize=True
  629. )
  630. if self.recrop_upsample:
  631. raise NotImplementedError("recrop_upsample not implemented")
  632. certainty = corresps[finest_scale]["certainty"]
  633. print(certainty.shape)
  634. im_A = self.recrop(certainty[0,0], im_A_path)
  635. im_B = self.recrop(certainty[1,0], im_B_path)
  636. #TODO: need to adjust corresps when doing this
  637. im_A, im_B = test_transform((im_A, im_B))
  638. im_A, im_B = im_A[None].to(device), im_B[None].to(device)
  639. scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
  640. batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
  641. if symmetric:
  642. corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
  643. else:
  644. corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
  645. im_A_to_im_B = corresps[finest_scale]["flow"]
  646. certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
  647. if finest_scale != 1:
  648. im_A_to_im_B = F.interpolate(
  649. im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
  650. )
  651. certainty = F.interpolate(
  652. certainty, size=(hs, ws), align_corners=False, mode="bilinear"
  653. )
  654. im_A_to_im_B = im_A_to_im_B.permute(
  655. 0, 2, 3, 1
  656. )
  657. # Create im_A meshgrid
  658. im_A_coords = torch.meshgrid(
  659. (
  660. torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
  661. torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
  662. )
  663. )
  664. im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
  665. im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
  666. certainty = certainty.sigmoid() # logits -> probs
  667. im_A_coords = im_A_coords.permute(0, 2, 3, 1)
  668. if (im_A_to_im_B.abs() > 1).any() and True:
  669. wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
  670. certainty[wrong[:,None]] = 0
  671. im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
  672. if symmetric:
  673. A_to_B, B_to_A = im_A_to_im_B.chunk(2)
  674. q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
  675. im_B_coords = im_A_coords
  676. s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
  677. warp = torch.cat((q_warp, s_warp),dim=2)
  678. certainty = torch.cat(certainty.chunk(2), dim=3)
  679. else:
  680. warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
  681. if batched:
  682. return (
  683. warp,
  684. certainty[:, 0]
  685. )
  686. else:
  687. return (
  688. warp[0],
  689. certainty[0, 0],
  690. )
  691. def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
  692. im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None, unnormalize = False):
  693. #assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
  694. H,W2,_ = warp.shape
  695. W = W2//2 if symmetric else W2
  696. if im_A is None:
  697. from PIL import Image
  698. im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
  699. if not isinstance(im_A, torch.Tensor):
  700. im_A = im_A.resize((W,H))
  701. im_B = im_B.resize((W,H))
  702. x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
  703. if symmetric:
  704. x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
  705. else:
  706. if symmetric:
  707. x_A = im_A
  708. x_B = im_B
  709. im_A_transfer_rgb = F.grid_sample(
  710. x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
  711. )[0]
  712. if symmetric:
  713. im_B_transfer_rgb = F.grid_sample(
  714. x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
  715. )[0]
  716. warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
  717. white_im = torch.ones((H,2*W),device=device)
  718. else:
  719. warp_im = im_A_transfer_rgb
  720. white_im = torch.ones((H, W), device = device)
  721. vis_im = certainty * warp_im + (1 - certainty) * white_im
  722. if save_path is not None:
  723. from romatch.utils import tensor_to_pil
  724. tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
  725. return vis_im