roma_indoor.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. import os
  2. import torch
  3. from argparse import ArgumentParser
  4. from warnings import warn
  5. from torch import nn
  6. from torch.utils.data import ConcatDataset
  7. import torch.distributed as dist
  8. from torch.nn.parallel import DistributedDataParallel as DDP
  9. import json
  10. import wandb
  11. from tqdm import tqdm
  12. from romatch.benchmarks import MegadepthDenseBenchmark
  13. from romatch.datasets.megadepth import MegadepthBuilder
  14. from romatch.datasets.scannet import ScanNetBuilder
  15. from romatch.losses.robust_loss import RobustLosses
  16. from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
  17. from romatch.train.train import train_k_steps
  18. from romatch.models.matcher import *
  19. from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
  20. from romatch.models.encoders import *
  21. from romatch.checkpointing import CheckPoint
  22. resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
  23. def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
  24. gp_dim = 512
  25. feat_dim = 512
  26. decoder_dim = gp_dim + feat_dim
  27. cls_to_coord_res = 64
  28. coordinate_decoder = TransformerDecoder(
  29. nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
  30. decoder_dim,
  31. cls_to_coord_res**2 + 1,
  32. is_classifier=True,
  33. amp = True,
  34. pos_enc = False,)
  35. dw = True
  36. hidden_blocks = 8
  37. kernel_size = 5
  38. displacement_emb = "linear"
  39. disable_local_corr_grad = True
  40. conv_refiner = nn.ModuleDict(
  41. {
  42. "16": ConvRefiner(
  43. 2 * 512+128+(2*7+1)**2,
  44. 2 * 512+128+(2*7+1)**2,
  45. 2 + 1,
  46. kernel_size=kernel_size,
  47. dw=dw,
  48. hidden_blocks=hidden_blocks,
  49. displacement_emb=displacement_emb,
  50. displacement_emb_dim=128,
  51. local_corr_radius = 7,
  52. corr_in_other = True,
  53. amp = True,
  54. disable_local_corr_grad = disable_local_corr_grad,
  55. bn_momentum = 0.01,
  56. ),
  57. "8": ConvRefiner(
  58. 2 * 512+64+(2*3+1)**2,
  59. 2 * 512+64+(2*3+1)**2,
  60. 2 + 1,
  61. kernel_size=kernel_size,
  62. dw=dw,
  63. hidden_blocks=hidden_blocks,
  64. displacement_emb=displacement_emb,
  65. displacement_emb_dim=64,
  66. local_corr_radius = 3,
  67. corr_in_other = True,
  68. amp = True,
  69. disable_local_corr_grad = disable_local_corr_grad,
  70. bn_momentum = 0.01,
  71. ),
  72. "4": ConvRefiner(
  73. 2 * 256+32+(2*2+1)**2,
  74. 2 * 256+32+(2*2+1)**2,
  75. 2 + 1,
  76. kernel_size=kernel_size,
  77. dw=dw,
  78. hidden_blocks=hidden_blocks,
  79. displacement_emb=displacement_emb,
  80. displacement_emb_dim=32,
  81. local_corr_radius = 2,
  82. corr_in_other = True,
  83. amp = True,
  84. disable_local_corr_grad = disable_local_corr_grad,
  85. bn_momentum = 0.01,
  86. ),
  87. "2": ConvRefiner(
  88. 2 * 64+16,
  89. 128+16,
  90. 2 + 1,
  91. kernel_size=kernel_size,
  92. dw=dw,
  93. hidden_blocks=hidden_blocks,
  94. displacement_emb=displacement_emb,
  95. displacement_emb_dim=16,
  96. amp = True,
  97. disable_local_corr_grad = disable_local_corr_grad,
  98. bn_momentum = 0.01,
  99. ),
  100. "1": ConvRefiner(
  101. 2 * 9 + 6,
  102. 24,
  103. 2 + 1,
  104. kernel_size=kernel_size,
  105. dw=dw,
  106. hidden_blocks = hidden_blocks,
  107. displacement_emb = displacement_emb,
  108. displacement_emb_dim = 6,
  109. amp = True,
  110. disable_local_corr_grad = disable_local_corr_grad,
  111. bn_momentum = 0.01,
  112. ),
  113. }
  114. )
  115. kernel_temperature = 0.2
  116. learn_temperature = False
  117. no_cov = True
  118. kernel = CosKernel
  119. only_attention = False
  120. basis = "fourier"
  121. gp16 = GP(
  122. kernel,
  123. T=kernel_temperature,
  124. learn_temperature=learn_temperature,
  125. only_attention=only_attention,
  126. gp_dim=gp_dim,
  127. basis=basis,
  128. no_cov=no_cov,
  129. )
  130. gps = nn.ModuleDict({"16": gp16})
  131. proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
  132. proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
  133. proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
  134. proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
  135. proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
  136. proj = nn.ModuleDict({
  137. "16": proj16,
  138. "8": proj8,
  139. "4": proj4,
  140. "2": proj2,
  141. "1": proj1,
  142. })
  143. displacement_dropout_p = 0.0
  144. gm_warp_dropout_p = 0.0
  145. decoder = Decoder(coordinate_decoder,
  146. gps,
  147. proj,
  148. conv_refiner,
  149. detach=True,
  150. scales=["16", "8", "4", "2", "1"],
  151. displacement_dropout_p = displacement_dropout_p,
  152. gm_warp_dropout_p = gm_warp_dropout_p)
  153. h,w = resolutions[resolution]
  154. encoder = CNNandDinov2(
  155. cnn_kwargs = dict(
  156. pretrained=pretrained_backbone,
  157. amp = True),
  158. amp = True,
  159. use_vgg = True,
  160. )
  161. matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
  162. return matcher
  163. def train(args):
  164. dist.init_process_group('nccl')
  165. #torch._dynamo.config.verbose=True
  166. gpus = int(os.environ['WORLD_SIZE'])
  167. # create model and move it to GPU with id rank
  168. rank = dist.get_rank()
  169. print(f"Start running DDP on rank {rank}")
  170. device_id = rank % torch.cuda.device_count()
  171. romatch.LOCAL_RANK = device_id
  172. torch.cuda.set_device(device_id)
  173. resolution = args.train_resolution
  174. wandb_log = not args.dont_log_wandb
  175. experiment_name = os.path.splitext(os.path.basename(__file__))[0]
  176. wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
  177. wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
  178. checkpoint_dir = "workspace/checkpoints/"
  179. h,w = resolutions[resolution]
  180. model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
  181. # Num steps
  182. global_step = 0
  183. batch_size = args.gpu_batch_size
  184. step_size = gpus*batch_size
  185. romatch.STEP_SIZE = step_size
  186. N = (32 * 250000) # 250k steps of batch size 32
  187. # checkpoint every
  188. k = 25000 // romatch.STEP_SIZE
  189. # Data
  190. mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
  191. use_horizontal_flip_aug = True
  192. rot_prob = 0
  193. depth_interpolation_mode = "bilinear"
  194. megadepth_train1 = mega.build_scenes(
  195. split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
  196. ht=h,wt=w,
  197. )
  198. megadepth_train2 = mega.build_scenes(
  199. split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
  200. ht=h,wt=w,
  201. )
  202. megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
  203. mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
  204. scannet = ScanNetBuilder(data_root="data/scannet")
  205. scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
  206. scannet_train = ConcatDataset(scannet_train)
  207. scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)
  208. # Loss and optimizer
  209. depth_loss_scannet = RobustLosses(
  210. ce_weight=0.0,
  211. local_dist={1:4, 2:4, 4:8, 8:8},
  212. local_largest_scale=8,
  213. depth_interpolation_mode=depth_interpolation_mode,
  214. alpha = 0.5,
  215. c = 1e-4,)
  216. # Loss and optimizer
  217. depth_loss_mega = RobustLosses(
  218. ce_weight=0.01,
  219. local_dist={1:4, 2:4, 4:8, 8:8},
  220. local_largest_scale=8,
  221. depth_interpolation_mode=depth_interpolation_mode,
  222. alpha = 0.5,
  223. c = 1e-4,)
  224. parameters = [
  225. {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
  226. {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
  227. ]
  228. optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
  229. lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
  230. optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
  231. megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
  232. checkpointer = CheckPoint(checkpoint_dir, experiment_name)
  233. model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
  234. romatch.GLOBAL_STEP = global_step
  235. ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
  236. grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
  237. grad_clip_norm = 0.01
  238. for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
  239. mega_sampler = torch.utils.data.WeightedRandomSampler(
  240. mega_ws, num_samples = batch_size * k, replacement=False
  241. )
  242. mega_dataloader = iter(
  243. torch.utils.data.DataLoader(
  244. megadepth_train,
  245. batch_size = batch_size,
  246. sampler = mega_sampler,
  247. num_workers = 8,
  248. )
  249. )
  250. scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
  251. scannet_ws, num_samples=batch_size * k, replacement=False
  252. )
  253. scannet_dataloader = iter(
  254. torch.utils.data.DataLoader(
  255. scannet_train,
  256. batch_size=batch_size,
  257. sampler=scannet_ws_sampler,
  258. num_workers=gpus * 8,
  259. )
  260. )
  261. for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
  262. train_k_steps(
  263. n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
  264. )
  265. train_k_steps(
  266. n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
  267. )
  268. checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
  269. wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
  270. def test_scannet(model, name, resolution, sample_mode):
  271. scannet_benchmark = ScanNetBenchmark("data/scannet")
  272. scannet_results = scannet_benchmark.benchmark(model)
  273. json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
  274. if __name__ == "__main__":
  275. import warnings
  276. warn('Current version of romatch is not tested for training, use at your own risk.')
  277. warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
  278. warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
  279. os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
  280. os.environ["OMP_NUM_THREADS"] = "16"
  281. import romatch
  282. parser = ArgumentParser()
  283. parser.add_argument("--test", action='store_true')
  284. parser.add_argument("--debug_mode", action='store_true')
  285. parser.add_argument("--dont_log_wandb", action='store_true')
  286. parser.add_argument("--train_resolution", default='medium')
  287. parser.add_argument("--gpu_batch_size", default=4, type=int)
  288. parser.add_argument("--wandb_entity", required = False)
  289. args, _ = parser.parse_known_args()
  290. romatch.DEBUG_MODE = args.debug_mode
  291. if not args.test:
  292. train(args)
  293. experiment_name = os.path.splitext(os.path.basename(__file__))[0]
  294. checkpoint_dir = "workspace/"
  295. checkpoint_name = checkpoint_dir + experiment_name + ".pth"
  296. test_resolution = "medium"
  297. sample_mode = "threshold_balanced"
  298. symmetric = True
  299. upsample_preds = False
  300. attenuate_cert = True
  301. model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
  302. model = model.cuda()
  303. states = torch.load(checkpoint_name)
  304. model.load_state_dict(states["model"])
  305. test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)