# roma_indoor.py (12 KB)
  1. import os
  2. import torch
  3. from argparse import ArgumentParser
  4. from torch import nn
  5. from torch.utils.data import ConcatDataset
  6. import torch.distributed as dist
  7. from torch.nn.parallel import DistributedDataParallel as DDP
  8. import json
  9. import wandb
  10. from tqdm import tqdm
  11. from romatch.benchmarks import MegadepthDenseBenchmark
  12. from romatch.datasets.megadepth import MegadepthBuilder
  13. from romatch.datasets.scannet import ScanNetBuilder
  14. from romatch.losses.robust_loss import RobustLosses
  15. from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
  16. from romatch.train.train import train_k_steps
  17. from romatch.models.matcher import *
  18. from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
  19. from romatch.models.encoders import *
  20. from romatch.checkpointing import CheckPoint
  21. resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
  22. def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
  23. gp_dim = 512
  24. feat_dim = 512
  25. decoder_dim = gp_dim + feat_dim
  26. cls_to_coord_res = 64
  27. coordinate_decoder = TransformerDecoder(
  28. nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
  29. decoder_dim,
  30. cls_to_coord_res**2 + 1,
  31. is_classifier=True,
  32. amp = True,
  33. pos_enc = False,)
  34. dw = True
  35. hidden_blocks = 8
  36. kernel_size = 5
  37. displacement_emb = "linear"
  38. disable_local_corr_grad = True
  39. conv_refiner = nn.ModuleDict(
  40. {
  41. "16": ConvRefiner(
  42. 2 * 512+128+(2*7+1)**2,
  43. 2 * 512+128+(2*7+1)**2,
  44. 2 + 1,
  45. kernel_size=kernel_size,
  46. dw=dw,
  47. hidden_blocks=hidden_blocks,
  48. displacement_emb=displacement_emb,
  49. displacement_emb_dim=128,
  50. local_corr_radius = 7,
  51. corr_in_other = True,
  52. amp = True,
  53. disable_local_corr_grad = disable_local_corr_grad,
  54. bn_momentum = 0.01,
  55. ),
  56. "8": ConvRefiner(
  57. 2 * 512+64+(2*3+1)**2,
  58. 2 * 512+64+(2*3+1)**2,
  59. 2 + 1,
  60. kernel_size=kernel_size,
  61. dw=dw,
  62. hidden_blocks=hidden_blocks,
  63. displacement_emb=displacement_emb,
  64. displacement_emb_dim=64,
  65. local_corr_radius = 3,
  66. corr_in_other = True,
  67. amp = True,
  68. disable_local_corr_grad = disable_local_corr_grad,
  69. bn_momentum = 0.01,
  70. ),
  71. "4": ConvRefiner(
  72. 2 * 256+32+(2*2+1)**2,
  73. 2 * 256+32+(2*2+1)**2,
  74. 2 + 1,
  75. kernel_size=kernel_size,
  76. dw=dw,
  77. hidden_blocks=hidden_blocks,
  78. displacement_emb=displacement_emb,
  79. displacement_emb_dim=32,
  80. local_corr_radius = 2,
  81. corr_in_other = True,
  82. amp = True,
  83. disable_local_corr_grad = disable_local_corr_grad,
  84. bn_momentum = 0.01,
  85. ),
  86. "2": ConvRefiner(
  87. 2 * 64+16,
  88. 128+16,
  89. 2 + 1,
  90. kernel_size=kernel_size,
  91. dw=dw,
  92. hidden_blocks=hidden_blocks,
  93. displacement_emb=displacement_emb,
  94. displacement_emb_dim=16,
  95. amp = True,
  96. disable_local_corr_grad = disable_local_corr_grad,
  97. bn_momentum = 0.01,
  98. ),
  99. "1": ConvRefiner(
  100. 2 * 9 + 6,
  101. 24,
  102. 2 + 1,
  103. kernel_size=kernel_size,
  104. dw=dw,
  105. hidden_blocks = hidden_blocks,
  106. displacement_emb = displacement_emb,
  107. displacement_emb_dim = 6,
  108. amp = True,
  109. disable_local_corr_grad = disable_local_corr_grad,
  110. bn_momentum = 0.01,
  111. ),
  112. }
  113. )
  114. kernel_temperature = 0.2
  115. learn_temperature = False
  116. no_cov = True
  117. kernel = CosKernel
  118. only_attention = False
  119. basis = "fourier"
  120. gp16 = GP(
  121. kernel,
  122. T=kernel_temperature,
  123. learn_temperature=learn_temperature,
  124. only_attention=only_attention,
  125. gp_dim=gp_dim,
  126. basis=basis,
  127. no_cov=no_cov,
  128. )
  129. gps = nn.ModuleDict({"16": gp16})
  130. proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
  131. proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
  132. proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
  133. proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
  134. proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
  135. proj = nn.ModuleDict({
  136. "16": proj16,
  137. "8": proj8,
  138. "4": proj4,
  139. "2": proj2,
  140. "1": proj1,
  141. })
  142. displacement_dropout_p = 0.0
  143. gm_warp_dropout_p = 0.0
  144. decoder = Decoder(coordinate_decoder,
  145. gps,
  146. proj,
  147. conv_refiner,
  148. detach=True,
  149. scales=["16", "8", "4", "2", "1"],
  150. displacement_dropout_p = displacement_dropout_p,
  151. gm_warp_dropout_p = gm_warp_dropout_p)
  152. h,w = resolutions[resolution]
  153. encoder = CNNandDinov2(
  154. cnn_kwargs = dict(
  155. pretrained=pretrained_backbone,
  156. amp = True),
  157. amp = True,
  158. use_vgg = True,
  159. )
  160. matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
  161. return matcher
  162. def train(args):
  163. dist.init_process_group('nccl')
  164. #torch._dynamo.config.verbose=True
  165. gpus = int(os.environ['WORLD_SIZE'])
  166. # create model and move it to GPU with id rank
  167. rank = dist.get_rank()
  168. print(f"Start running DDP on rank {rank}")
  169. device_id = rank % torch.cuda.device_count()
  170. romatch.LOCAL_RANK = device_id
  171. torch.cuda.set_device(device_id)
  172. resolution = args.train_resolution
  173. wandb_log = not args.dont_log_wandb
  174. experiment_name = os.path.splitext(os.path.basename(__file__))[0]
  175. wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
  176. wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
  177. checkpoint_dir = "workspace/checkpoints/"
  178. h,w = resolutions[resolution]
  179. model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
  180. # Num steps
  181. global_step = 0
  182. batch_size = args.gpu_batch_size
  183. step_size = gpus*batch_size
  184. romatch.STEP_SIZE = step_size
  185. N = (32 * 250000) # 250k steps of batch size 32
  186. # checkpoint every
  187. k = 25000 // romatch.STEP_SIZE
  188. # Data
  189. mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
  190. use_horizontal_flip_aug = True
  191. rot_prob = 0
  192. depth_interpolation_mode = "bilinear"
  193. megadepth_train1 = mega.build_scenes(
  194. split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
  195. ht=h,wt=w,
  196. )
  197. megadepth_train2 = mega.build_scenes(
  198. split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
  199. ht=h,wt=w,
  200. )
  201. megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
  202. mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
  203. scannet = ScanNetBuilder(data_root="data/scannet")
  204. scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
  205. scannet_train = ConcatDataset(scannet_train)
  206. scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)
  207. # Loss and optimizer
  208. depth_loss_scannet = RobustLosses(
  209. ce_weight=0.0,
  210. local_dist={1:4, 2:4, 4:8, 8:8},
  211. local_largest_scale=8,
  212. depth_interpolation_mode=depth_interpolation_mode,
  213. alpha = 0.5,
  214. c = 1e-4,)
  215. # Loss and optimizer
  216. depth_loss_mega = RobustLosses(
  217. ce_weight=0.01,
  218. local_dist={1:4, 2:4, 4:8, 8:8},
  219. local_largest_scale=8,
  220. depth_interpolation_mode=depth_interpolation_mode,
  221. alpha = 0.5,
  222. c = 1e-4,)
  223. parameters = [
  224. {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
  225. {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
  226. ]
  227. optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
  228. lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
  229. optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
  230. megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
  231. checkpointer = CheckPoint(checkpoint_dir, experiment_name)
  232. model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
  233. romatch.GLOBAL_STEP = global_step
  234. ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
  235. grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
  236. grad_clip_norm = 0.01
  237. for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
  238. mega_sampler = torch.utils.data.WeightedRandomSampler(
  239. mega_ws, num_samples = batch_size * k, replacement=False
  240. )
  241. mega_dataloader = iter(
  242. torch.utils.data.DataLoader(
  243. megadepth_train,
  244. batch_size = batch_size,
  245. sampler = mega_sampler,
  246. num_workers = 8,
  247. )
  248. )
  249. scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
  250. scannet_ws, num_samples=batch_size * k, replacement=False
  251. )
  252. scannet_dataloader = iter(
  253. torch.utils.data.DataLoader(
  254. scannet_train,
  255. batch_size=batch_size,
  256. sampler=scannet_ws_sampler,
  257. num_workers=gpus * 8,
  258. )
  259. )
  260. for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
  261. train_k_steps(
  262. n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
  263. )
  264. train_k_steps(
  265. n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
  266. )
  267. checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
  268. wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
  269. def test_scannet(model, name, resolution, sample_mode):
  270. scannet_benchmark = ScanNetBenchmark("data/scannet")
  271. scannet_results = scannet_benchmark.benchmark(model)
  272. json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
  273. if __name__ == "__main__":
  274. import warnings
  275. warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
  276. warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
  277. os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
  278. os.environ["OMP_NUM_THREADS"] = "16"
  279. import romatch
  280. parser = ArgumentParser()
  281. parser.add_argument("--test", action='store_true')
  282. parser.add_argument("--debug_mode", action='store_true')
  283. parser.add_argument("--dont_log_wandb", action='store_true')
  284. parser.add_argument("--train_resolution", default='medium')
  285. parser.add_argument("--gpu_batch_size", default=4, type=int)
  286. parser.add_argument("--wandb_entity", required = False)
  287. args, _ = parser.parse_known_args()
  288. romatch.DEBUG_MODE = args.debug_mode
  289. if not args.test:
  290. train(args)
  291. experiment_name = os.path.splitext(os.path.basename(__file__))[0]
  292. checkpoint_dir = "workspace/"
  293. checkpoint_name = checkpoint_dir + experiment_name + ".pth"
  294. test_resolution = "medium"
  295. sample_mode = "threshold_balanced"
  296. symmetric = True
  297. upsample_preds = False
  298. attenuate_cert = True
  299. model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
  300. model = model.cuda()
  301. states = torch.load(checkpoint_name)
  302. model.load_state_dict(states["model"])
  303. test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)