# train_tiny_roma_v1_outdoor.py (23 KB)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import os
  5. import torch
  6. from argparse import ArgumentParser
  7. from pathlib import Path
  8. import math
  9. import numpy as np
  10. from torch import nn
  11. from torch.utils.data import ConcatDataset
  12. import torch.distributed as dist
  13. from torch.nn.parallel import DistributedDataParallel as DDP
  14. import json
  15. import wandb
  16. from PIL import Image
  17. from torchvision.transforms import ToTensor
  18. from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
  19. from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
  20. from romatch.datasets.megadepth import MegadepthBuilder
  21. from romatch.losses.robust_loss_tiny_roma import RobustLosses
  22. from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
  23. from romatch.train.train import train_k_steps
  24. from romatch.checkpointing import CheckPoint
  25. resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6), "xfeat": (600,800), "big": (768, 1024)}
  26. def kde(x, std = 0.1):
  27. # use a gaussian kernel to estimate density
  28. x = x.half() # Do it in half precision TODO: remove hardcoding
  29. scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
  30. density = scores.sum(dim=-1)
  31. return density
  32. class BasicLayer(nn.Module):
  33. """
  34. Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
  35. """
  36. def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
  37. super().__init__()
  38. self.layer = nn.Sequential(
  39. nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
  40. nn.BatchNorm2d(out_channels, affine=False),
  41. nn.ReLU(inplace = True) if relu else nn.Identity()
  42. )
  43. def forward(self, x):
  44. return self.layer(x)
  45. class XFeatModel(nn.Module):
  46. """
  47. Implementation of architecture described in
  48. "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
  49. """
  50. def __init__(self, xfeat = None,
  51. freeze_xfeat = True,
  52. sample_mode = "threshold_balanced",
  53. symmetric = False,
  54. exact_softmax = False):
  55. super().__init__()
  56. if xfeat is None:
  57. xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True, top_k = 4096).net
  58. del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
  59. if freeze_xfeat:
  60. xfeat.train(False)
  61. self.xfeat = [xfeat]# hide params from ddp
  62. else:
  63. self.xfeat = nn.ModuleList([xfeat])
  64. self.freeze_xfeat = freeze_xfeat
  65. match_dim = 256
  66. self.coarse_matcher = nn.Sequential(
  67. BasicLayer(64+64+2, match_dim,),
  68. BasicLayer(match_dim, match_dim,),
  69. BasicLayer(match_dim, match_dim,),
  70. BasicLayer(match_dim, match_dim,),
  71. nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
  72. fine_match_dim = 64
  73. self.fine_matcher = nn.Sequential(
  74. BasicLayer(24+24+2, fine_match_dim,),
  75. BasicLayer(fine_match_dim, fine_match_dim,),
  76. BasicLayer(fine_match_dim, fine_match_dim,),
  77. BasicLayer(fine_match_dim, fine_match_dim,),
  78. nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
  79. self.sample_mode = sample_mode
  80. self.sample_thresh = 0.2
  81. self.symmetric = symmetric
  82. self.exact_softmax = exact_softmax
  83. @property
  84. def device(self):
  85. return self.fine_matcher[-1].weight.device
  86. def preprocess_tensor(self, x):
  87. """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
  88. H, W = x.shape[-2:]
  89. _H, _W = (H//32) * 32, (W//32) * 32
  90. rh, rw = H/_H, W/_W
  91. x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
  92. return x, rh, rw
  93. def forward_single(self, x):
  94. with torch.inference_mode(self.freeze_xfeat or not self.training):
  95. xfeat = self.xfeat[0]
  96. with torch.no_grad():
  97. x = x.mean(dim=1, keepdim = True)
  98. x = xfeat.norm(x)
  99. #main backbone
  100. x1 = xfeat.block1(x)
  101. x2 = xfeat.block2(x1 + xfeat.skip1(x))
  102. x3 = xfeat.block3(x2)
  103. x4 = xfeat.block4(x3)
  104. x5 = xfeat.block5(x4)
  105. x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
  106. x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
  107. feats = xfeat.block_fusion( x3 + x4 + x5 )
  108. if self.freeze_xfeat:
  109. return x2.clone(), feats.clone()
  110. return x2, feats
  111. def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
  112. if coords.shape[-1] == 2:
  113. return self._to_pixel_coordinates(coords, H_A, W_A)
  114. if isinstance(coords, (list, tuple)):
  115. kpts_A, kpts_B = coords[0], coords[1]
  116. else:
  117. kpts_A, kpts_B = coords[...,:2], coords[...,2:]
  118. return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
  119. def _to_pixel_coordinates(self, coords, H, W):
  120. kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
  121. return kpts
  122. def pos_embed(self, corr_volume: torch.Tensor):
  123. B, H1, W1, H0, W0 = corr_volume.shape
  124. grid = torch.stack(
  125. torch.meshgrid(
  126. torch.linspace(-1+1/W1,1-1/W1, W1),
  127. torch.linspace(-1+1/H1,1-1/H1, H1),
  128. indexing = "xy"),
  129. dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
  130. down = 4
  131. if not self.training and not self.exact_softmax:
  132. grid_lr = torch.stack(
  133. torch.meshgrid(
  134. torch.linspace(-1+down/W1,1-down/W1, W1//down),
  135. torch.linspace(-1+down/H1,1-down/H1, H1//down),
  136. indexing = "xy"),
  137. dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
  138. cv = corr_volume
  139. best_match = cv.reshape(B,H1*W1,H0,W0).amax(dim=1) # B, HW, H, W
  140. P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
  141. pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
  142. pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
  143. else:
  144. P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
  145. pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
  146. return pos_embeddings
  147. def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
  148. im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
  149. device = warp.device
  150. H,W2,_ = warp.shape
  151. W = W2//2 if symmetric else W2
  152. if im_A is None:
  153. from PIL import Image
  154. im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
  155. if not isinstance(im_A, torch.Tensor):
  156. im_A = im_A.resize((W,H))
  157. im_B = im_B.resize((W,H))
  158. x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
  159. if symmetric:
  160. x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
  161. else:
  162. if symmetric:
  163. x_A = im_A
  164. x_B = im_B
  165. im_A_transfer_rgb = F.grid_sample(
  166. x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
  167. )[0]
  168. if symmetric:
  169. im_B_transfer_rgb = F.grid_sample(
  170. x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
  171. )[0]
  172. warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
  173. white_im = torch.ones((H,2*W),device=device)
  174. else:
  175. warp_im = im_A_transfer_rgb
  176. white_im = torch.ones((H, W), device = device)
  177. vis_im = certainty * warp_im + (1 - certainty) * white_im
  178. if save_path is not None:
  179. from romatch.utils import tensor_to_pil
  180. tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
  181. return vis_im
  182. def corr_volume(self, feat0, feat1):
  183. """
  184. input:
  185. feat0 -> torch.Tensor(B, C, H, W)
  186. feat1 -> torch.Tensor(B, C, H, W)
  187. return:
  188. corr_volume -> torch.Tensor(B, H, W, H, W)
  189. """
  190. B, C, H0, W0 = feat0.shape
  191. B, C, H1, W1 = feat1.shape
  192. feat0 = feat0.view(B, C, H0*W0)
  193. feat1 = feat1.view(B, C, H1*W1)
  194. corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
  195. return corr_volume
  196. @torch.inference_mode()
  197. def match_from_path(self, im0_path, im1_path):
  198. device = self.device
  199. im0 = ToTensor()(Image.open(im0_path))[None].to(device)
  200. im1 = ToTensor()(Image.open(im1_path))[None].to(device)
  201. return self.match(im0, im1, batched = False)
  202. @torch.inference_mode()
  203. def match(self, im0, im1, *args, batched = True):
  204. # stupid
  205. if isinstance(im0, (str, Path)):
  206. return self.match_from_path(im0, im1)
  207. elif isinstance(im0, Image.Image):
  208. batched = False
  209. device = self.device
  210. im0 = ToTensor()(im0)[None].to(device)
  211. im1 = ToTensor()(im1)[None].to(device)
  212. B,C,H0,W0 = im0.shape
  213. B,C,H1,W1 = im1.shape
  214. self.train(False)
  215. corresps = self.forward({"im_A":im0, "im_B":im1})
  216. #return 1,1
  217. flow = F.interpolate(
  218. corresps[4]["flow"],
  219. size = (H0, W0),
  220. mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
  221. grid = torch.stack(
  222. torch.meshgrid(
  223. torch.linspace(-1+1/W0,1-1/W0, W0),
  224. torch.linspace(-1+1/H0,1-1/H0, H0),
  225. indexing = "xy"),
  226. dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
  227. certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
  228. warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
  229. if batched:
  230. return warp, cert
  231. else:
  232. return warp[0], cert[0]
  233. def sample(
  234. self,
  235. matches,
  236. certainty,
  237. num=10000,
  238. ):
  239. if "threshold" in self.sample_mode:
  240. upper_thresh = self.sample_thresh
  241. certainty = certainty.clone()
  242. certainty[certainty > upper_thresh] = 1
  243. matches, certainty = (
  244. matches.reshape(-1, 4),
  245. certainty.reshape(-1),
  246. )
  247. expansion_factor = 4 if "balanced" in self.sample_mode else 1
  248. good_samples = torch.multinomial(certainty,
  249. num_samples = min(expansion_factor*num, len(certainty)),
  250. replacement=False)
  251. good_matches, good_certainty = matches[good_samples], certainty[good_samples]
  252. if "balanced" not in self.sample_mode:
  253. return good_matches, good_certainty
  254. density = kde(good_matches, std=0.1)
  255. p = 1 / (density+1)
  256. p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
  257. balanced_samples = torch.multinomial(p,
  258. num_samples = min(num,len(good_certainty)),
  259. replacement=False)
  260. return good_matches[balanced_samples], good_certainty[balanced_samples]
  261. def forward(self, batch):
  262. """
  263. input:
  264. x -> torch.Tensor(B, C, H, W) grayscale or rgb images
  265. return:
  266. """
  267. im0 = batch["im_A"]
  268. im1 = batch["im_B"]
  269. corresps = {}
  270. im0, rh0, rw0 = self.preprocess_tensor(im0)
  271. im1, rh1, rw1 = self.preprocess_tensor(im1)
  272. B, C, H0, W0 = im0.shape
  273. B, C, H1, W1 = im1.shape
  274. to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
  275. if im0.shape[-2:] == im1.shape[-2:]:
  276. x = torch.cat([im0, im1], dim=0)
  277. x = self.forward_single(x)
  278. feats_x0_c, feats_x1_c = x[1].chunk(2)
  279. feats_x0_f, feats_x1_f = x[0].chunk(2)
  280. else:
  281. feats_x0_f, feats_x0_c = self.forward_single(im0)
  282. feats_x1_f, feats_x1_c = self.forward_single(im1)
  283. corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
  284. coarse_warp = self.pos_embed(corr_volume)
  285. coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
  286. feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
  287. coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
  288. coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
  289. corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
  290. coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
  291. coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
  292. feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
  293. fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
  294. fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
  295. corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
  296. return corresps
  297. def train(args):
  298. rank = 0
  299. gpus = 1
  300. device_id = rank % torch.cuda.device_count()
  301. romatch.LOCAL_RANK = 0
  302. torch.cuda.set_device(device_id)
  303. resolution = "big"
  304. wandb_log = not args.dont_log_wandb
  305. experiment_name = Path(__file__).stem
  306. wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
  307. wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
  308. checkpoint_dir = "workspace/checkpoints/"
  309. h,w = resolutions[resolution]
  310. model = XFeatModel(freeze_xfeat = False).to(device_id)
  311. # Num steps
  312. global_step = 0
  313. batch_size = args.gpu_batch_size
  314. step_size = gpus*batch_size
  315. romatch.STEP_SIZE = step_size
  316. N = 2_000_000 # 2M pairs
  317. # checkpoint every
  318. k = 25000 // romatch.STEP_SIZE
  319. # Data
  320. mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
  321. use_horizontal_flip_aug = True
  322. normalize = False # don't imgnet normalize
  323. rot_prob = 0
  324. depth_interpolation_mode = "bilinear"
  325. megadepth_train1 = mega.build_scenes(
  326. split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
  327. ht=h,wt=w, normalize = normalize
  328. )
  329. megadepth_train2 = mega.build_scenes(
  330. split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
  331. ht=h,wt=w, normalize = normalize
  332. )
  333. megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
  334. mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
  335. # Loss and optimizer
  336. depth_loss = RobustLosses(
  337. ce_weight=0.01,
  338. local_dist={4:4},
  339. depth_interpolation_mode=depth_interpolation_mode,
  340. alpha = {4:0.15, 8:0.15},
  341. c = 1e-4,
  342. epe_mask_prob_th = 0.001,
  343. )
  344. parameters = [
  345. {"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
  346. ]
  347. optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
  348. lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
  349. optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
  350. #megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
  351. mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 30)
  352. checkpointer = CheckPoint(checkpoint_dir, experiment_name)
  353. model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
  354. romatch.GLOBAL_STEP = global_step
  355. grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
  356. grad_clip_norm = 0.01
  357. #megadense_benchmark.benchmark(model)
  358. for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
  359. mega_sampler = torch.utils.data.WeightedRandomSampler(
  360. mega_ws, num_samples = batch_size * k, replacement=False
  361. )
  362. mega_dataloader = iter(
  363. torch.utils.data.DataLoader(
  364. megadepth_train,
  365. batch_size = batch_size,
  366. sampler = mega_sampler,
  367. num_workers = 8,
  368. )
  369. )
  370. train_k_steps(
  371. n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
  372. )
  373. checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
  374. wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step = romatch.GLOBAL_STEP)
  375. def test_mega_8_scenes(model, name):
  376. mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
  377. scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
  378. 'mega_8_scenes_0025_0.1_0.3.npz',
  379. 'mega_8_scenes_0021_0.1_0.3.npz',
  380. 'mega_8_scenes_0008_0.1_0.3.npz',
  381. 'mega_8_scenes_0032_0.1_0.3.npz',
  382. 'mega_8_scenes_1589_0.1_0.3.npz',
  383. 'mega_8_scenes_0063_0.1_0.3.npz',
  384. 'mega_8_scenes_0024_0.1_0.3.npz',
  385. 'mega_8_scenes_0019_0.3_0.5.npz',
  386. 'mega_8_scenes_0025_0.3_0.5.npz',
  387. 'mega_8_scenes_0021_0.3_0.5.npz',
  388. 'mega_8_scenes_0008_0.3_0.5.npz',
  389. 'mega_8_scenes_0032_0.3_0.5.npz',
  390. 'mega_8_scenes_1589_0.3_0.5.npz',
  391. 'mega_8_scenes_0063_0.3_0.5.npz',
  392. 'mega_8_scenes_0024_0.3_0.5.npz'])
  393. mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
  394. print(mega_8_scenes_results)
  395. json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
  396. def test_mega1500(model, name):
  397. mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
  398. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  399. json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
  400. def test_mega1500_poselib(model, name):
  401. mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
  402. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  403. json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
  404. def test_mega_8_scenes_poselib(model, name):
  405. mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
  406. scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
  407. 'mega_8_scenes_0025_0.1_0.3.npz',
  408. 'mega_8_scenes_0021_0.1_0.3.npz',
  409. 'mega_8_scenes_0008_0.1_0.3.npz',
  410. 'mega_8_scenes_0032_0.1_0.3.npz',
  411. 'mega_8_scenes_1589_0.1_0.3.npz',
  412. 'mega_8_scenes_0063_0.1_0.3.npz',
  413. 'mega_8_scenes_0024_0.1_0.3.npz',
  414. 'mega_8_scenes_0019_0.3_0.5.npz',
  415. 'mega_8_scenes_0025_0.3_0.5.npz',
  416. 'mega_8_scenes_0021_0.3_0.5.npz',
  417. 'mega_8_scenes_0008_0.3_0.5.npz',
  418. 'mega_8_scenes_0032_0.3_0.5.npz',
  419. 'mega_8_scenes_1589_0.3_0.5.npz',
  420. 'mega_8_scenes_0063_0.3_0.5.npz',
  421. 'mega_8_scenes_0024_0.3_0.5.npz'])
  422. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  423. json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
  424. def test_scannet_poselib(model, name):
  425. scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
  426. scannet_results = scannet_benchmark.benchmark(model)
  427. json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
  428. def test_scannet(model, name):
  429. scannet_benchmark = ScanNetBenchmark("data/scannet")
  430. scannet_results = scannet_benchmark.benchmark(model)
  431. json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
  432. if __name__ == "__main__":
  433. os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
  434. os.environ["OMP_NUM_THREADS"] = "16"
  435. torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
  436. import romatch
  437. parser = ArgumentParser()
  438. parser.add_argument("--only_test", action='store_true')
  439. parser.add_argument("--debug_mode", action='store_true')
  440. parser.add_argument("--dont_log_wandb", action='store_true')
  441. parser.add_argument("--train_resolution", default='medium')
  442. parser.add_argument("--gpu_batch_size", default=8, type=int)
  443. parser.add_argument("--wandb_entity", required = False)
  444. args, _ = parser.parse_known_args()
  445. romatch.DEBUG_MODE = args.debug_mode
  446. if not args.only_test:
  447. train(args)
  448. experiment_name = "tiny_roma_v1_outdoor"#Path(__file__).stem
  449. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  450. model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
  451. model.load_state_dict(torch.load(f"{experiment_name}.pth"))
  452. test_mega1500_poselib(model, experiment_name)