train_roma_outdoor.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. import os
  2. import torch
  3. from argparse import ArgumentParser
  4. from warnings import warn
  5. from torch import nn
  6. from torch.utils.data import ConcatDataset
  7. import torch.distributed as dist
  8. from torch.nn.parallel import DistributedDataParallel as DDP
  9. import json
  10. import wandb
  11. from romatch.benchmarks import MegadepthDenseBenchmark
  12. from romatch.datasets.megadepth import MegadepthBuilder
  13. from romatch.losses.robust_loss import RobustLosses
  14. from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
  15. from romatch.train.train import train_k_steps
  16. from romatch.models.matcher import *
  17. from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
  18. from romatch.models.encoders import *
  19. from romatch.checkpointing import CheckPoint
  # Training/benchmark resolutions, keyed by name, as (h, w) pairs. All are square.
  # "medium" and "high" are written as multiples of 14 * 8 (presumably so the
  # size stays divisible by the backbone patch size and decoder stride — TODO confirm).
  resolutions = {
      "low": (448, 448),
      "medium": (14 * 8 * 5, 14 * 8 * 5),
      "high": (14 * 8 * 6, 14 * 8 * 6),
  }
  def get_model(pretrained_backbone=True, resolution="medium", **kwargs):
      """Assemble the RoMa outdoor ``RegressionMatcher`` (encoder + multi-scale decoder).

      Args:
          pretrained_backbone: forwarded as ``pretrained`` inside ``cnn_kwargs`` of the
              ``CNNandDinov2`` encoder.
          resolution: key into the module-level ``resolutions`` table; the resulting
              ``(h, w)`` is passed to ``RegressionMatcher``.
          **kwargs: forwarded unchanged to ``RegressionMatcher``.

      Returns:
          The constructed ``RegressionMatcher``.
      """
      import warnings
      warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

      gp_dim = 512
      feat_dim = 512
      decoder_dim = gp_dim + feat_dim
      cls_to_coord_res = 64
      # Submodules are created in a fixed order (coordinate decoder, refiners
      # 16 -> 1, GP, projections 16 -> 1, encoder) so random initialisation
      # consumes the RNG in the same sequence every time.
      coordinate_decoder = TransformerDecoder(
          nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
          decoder_dim,
          cls_to_coord_res ** 2 + 1,
          is_classifier=True,
          amp=True,
          pos_enc=False,
      )

      # Settings common to every ConvRefiner scale.
      shared_refiner_kwargs = dict(
          kernel_size=5,
          dw=True,
          hidden_blocks=8,
          displacement_emb="linear",
          amp=True,
          disable_local_corr_grad=True,
          bn_momentum=0.01,
      )

      def make_refiner(in_dim, hidden_dim, emb_dim, corr_radius=None):
          # Coarse scales additionally get a local-correlation radius and
          # corr_in_other=True; the two finest scales pass neither argument.
          if corr_radius is None:
              corr_kwargs = {}
          else:
              corr_kwargs = dict(local_corr_radius=corr_radius, corr_in_other=True)
          return ConvRefiner(
              in_dim,
              hidden_dim,
              2 + 1,
              displacement_emb_dim=emb_dim,
              **corr_kwargs,
              **shared_refiner_kwargs,
          )

      conv_refiner = nn.ModuleDict({
          "16": make_refiner(2 * 512 + 128 + (2 * 7 + 1) ** 2, 2 * 512 + 128 + (2 * 7 + 1) ** 2, 128, corr_radius=7),
          "8": make_refiner(2 * 512 + 64 + (2 * 3 + 1) ** 2, 2 * 512 + 64 + (2 * 3 + 1) ** 2, 64, corr_radius=3),
          "4": make_refiner(2 * 256 + 32 + (2 * 2 + 1) ** 2, 2 * 256 + 32 + (2 * 2 + 1) ** 2, 32, corr_radius=2),
          "2": make_refiner(2 * 64 + 16, 128 + 16, 16),
          "1": make_refiner(2 * 9 + 6, 24, 6),
      })

      gps = nn.ModuleDict({
          "16": GP(
              CosKernel,
              T=0.2,
              learn_temperature=False,
              only_attention=False,
              gp_dim=gp_dim,
              basis="fourier",
              no_cov=True,
          ),
      })

      def conv_bn(in_channels, out_channels):
          # 1x1 convolution followed by batch norm.
          return nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, 1), nn.BatchNorm2d(out_channels))

      proj = nn.ModuleDict({
          "16": conv_bn(1024, 512),
          "8": conv_bn(512, 512),
          "4": conv_bn(256, 256),
          "2": conv_bn(128, 64),
          "1": conv_bn(64, 9),
      })

      decoder = Decoder(
          coordinate_decoder,
          gps,
          proj,
          conv_refiner,
          detach=True,
          scales=["16", "8", "4", "2", "1"],
          displacement_dropout_p=0.0,
          gm_warp_dropout_p=0.0,
      )
      h, w = resolutions[resolution]
      encoder = CNNandDinov2(
          cnn_kwargs=dict(pretrained=pretrained_backbone, amp=True),
          amp=True,
          use_vgg=True,
      )
      return RegressionMatcher(encoder, decoder, h=h, w=w, **kwargs)
  def train(args):
      """Train the RoMa outdoor matcher on MegaDepth with DistributedDataParallel.

      Initialises an NCCL process group with PyTorch's default (environment
      variable) initialisation, so it is meant to run as one process per GPU
      under a launcher that sets RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT.

      Args:
          args: parsed CLI namespace; reads ``train_resolution``,
              ``dont_log_wandb``, ``wandb_entity`` and ``gpu_batch_size``.
      """
      # Import locally: the only other ``import romatch`` in this file lives
      # under the ``__main__`` guard, so calling train() after importing this
      # module would otherwise raise NameError. sys.modules returns the same
      # module object, so the attributes set below are shared globally.
      import romatch

      dist.init_process_group('nccl')
      gpus = dist.get_world_size()
      rank = dist.get_rank()
      print(f"Start running DDP on rank {rank}")
      device_id = rank % torch.cuda.device_count()
      romatch.LOCAL_RANK = device_id
      torch.cuda.set_device(device_id)

      resolution = args.train_resolution
      wandb_log = not args.dont_log_wandb
      experiment_name = os.path.splitext(os.path.basename(__file__))[0]
      # Only rank 0 logs online; the other ranks get a disabled wandb run.
      wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
      wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode=wandb_mode)
      checkpoint_dir = "workspace/checkpoints/"
      h, w = resolutions[resolution]
      model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert=False).to(device_id)

      # Progress appears to be counted in samples: N is a sample budget and the
      # outer loop below strides by k * STEP_SIZE samples.
      global_step = 0
      batch_size = args.gpu_batch_size
      step_size = gpus * batch_size
      romatch.STEP_SIZE = step_size
      N = 32 * 250000  # 250k steps of batch size 32
      # Optimizer steps between checkpoints (~25k samples). Clamped to >= 1 so
      # the range() stride below can never be zero for very large global batches.
      k = max(1, 25000 // romatch.STEP_SIZE)

      # Data
      mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True)
      use_horizontal_flip_aug = True
      rot_prob = 0
      depth_interpolation_mode = "bilinear"
      megadepth_train1 = mega.build_scenes(
          split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
          ht=h, wt=w,
      )
      megadepth_train2 = mega.build_scenes(
          split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
          ht=h, wt=w,
      )
      megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
      mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

      # Loss and optimizer
      depth_loss = RobustLosses(
          ce_weight=0.01,
          local_dist={1: 4, 2: 4, 4: 8, 8: 8},
          local_largest_scale=8,
          depth_interpolation_mode=depth_interpolation_mode,
          alpha=0.5,
          c=1e-4,
      )
      # Learning rates scale linearly with the global batch size.
      parameters = [
          {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
          {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
      ]
      optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
      # Single decay at 90% of N / STEP_SIZE scheduler steps (presumably one
      # scheduler step per optimizer step inside train_k_steps). Integer
      # arithmetic keeps the milestone an exact int.
      lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
          optimizer, milestones=[(9 * N // romatch.STEP_SIZE) // 10])
      megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=1000, h=h, w=w)
      checkpointer = CheckPoint(checkpoint_dir, experiment_name)
      model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
      romatch.GLOBAL_STEP = global_step
      ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=False, gradient_as_bucket_view=True)
      grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
      grad_clip_norm = 0.01

      for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
          # A fresh sampler/loader per chunk of k steps, drawing batch_size * k samples.
          mega_sampler = torch.utils.data.WeightedRandomSampler(
              mega_ws, num_samples=batch_size * k, replacement=False
          )
          mega_dataloader = iter(
              torch.utils.data.DataLoader(
                  megadepth_train,
                  batch_size=batch_size,
                  sampler=mega_sampler,
                  num_workers=8,
              )
          )
          train_k_steps(
              n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm=grad_clip_norm,
          )
          checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
          wandb.log(megadense_benchmark.benchmark(model), step=romatch.GLOBAL_STEP)

      # Tear down the NCCL process group explicitly; recent PyTorch versions
      # warn when a process group is still alive at interpreter exit.
      dist.destroy_process_group()
  def test_mega_8_scenes(model, name):
      """Run the MegaDepth "8 scenes" pose benchmark, print and save the results.

      Args:
          model: matcher handed to ``MegaDepthPoseEstimationBenchmark.benchmark``.
          name: label passed as ``model_name`` and used in the output path
              ``results/mega_8_scenes_{name}.json``.
      """
      scene_names = [
          'mega_8_scenes_0019_0.1_0.3.npz',
          'mega_8_scenes_0025_0.1_0.3.npz',
          'mega_8_scenes_0021_0.1_0.3.npz',
          'mega_8_scenes_0008_0.1_0.3.npz',
          'mega_8_scenes_0032_0.1_0.3.npz',
          'mega_8_scenes_1589_0.1_0.3.npz',
          'mega_8_scenes_0063_0.1_0.3.npz',
          'mega_8_scenes_0024_0.1_0.3.npz',
          'mega_8_scenes_0019_0.3_0.5.npz',
          'mega_8_scenes_0025_0.3_0.5.npz',
          'mega_8_scenes_0021_0.3_0.5.npz',
          'mega_8_scenes_0008_0.3_0.5.npz',
          'mega_8_scenes_0032_0.3_0.5.npz',
          'mega_8_scenes_1589_0.3_0.5.npz',
          'mega_8_scenes_0063_0.3_0.5.npz',
          'mega_8_scenes_0024_0.3_0.5.npz',
      ]
      mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth", scene_names=scene_names)
      mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
      print(mega_8_scenes_results)
      # open() alone raises FileNotFoundError when results/ does not exist yet.
      os.makedirs("results", exist_ok=True)
      # The context manager guarantees the JSON is flushed and the handle closed.
      with open(f"results/mega_8_scenes_{name}.json", "w") as f:
          json.dump(mega_8_scenes_results, f)
  def test_mega1500(model, name):
      """Run the default MegaDepth pose benchmark and save the results as JSON.

      Args:
          model: matcher handed to ``MegaDepthPoseEstimationBenchmark.benchmark``.
          name: label passed as ``model_name`` and used in the output path
              ``results/mega1500_{name}.json``.
      """
      mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
      mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
      # open() alone raises FileNotFoundError when results/ does not exist yet.
      os.makedirs("results", exist_ok=True)
      # The context manager guarantees the JSON is flushed and the handle closed.
      with open(f"results/mega1500_{name}.json", "w") as f:
          json.dump(mega1500_results, f)
  def test_mega_dense(model, name):
      """Run the MegaDepth dense benchmark (1000 samples) and save the results as JSON.

      Args:
          model: matcher handed to ``MegadepthDenseBenchmark.benchmark``.
          name: label used in the output path ``results/mega_dense_{name}.json``.
      """
      megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=1000)
      megadense_results = megadense_benchmark.benchmark(model)
      # open() alone raises FileNotFoundError when results/ does not exist yet.
      os.makedirs("results", exist_ok=True)
      # The context manager guarantees the JSON is flushed and the handle closed.
      with open(f"results/mega_dense_{name}.json", "w") as f:
          json.dump(megadense_results, f)
  def test_hpatches(model, name):
      """Run the HPatches homography benchmark and save the results as JSON.

      Args:
          model: matcher handed to ``HpatchesHomogBenchmark.benchmark``.
          name: label used in the output path ``results/hpatches_{name}.json``.
      """
      hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
      hpatches_results = hpatches_benchmark.benchmark(model)
      # open() alone raises FileNotFoundError when results/ does not exist yet.
      os.makedirs("results", exist_ok=True)
      # The context manager guarantees the JSON is flushed and the handle closed.
      with open(f"results/hpatches_{name}.json", "w") as f:
          json.dump(hpatches_results, f)
  if __name__ == "__main__":
      # Environment and backend settings are applied before ``romatch`` is
      # imported further down.
      os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"  # For BF16 computations
      # NOTE(review): torch is already imported at this point, so this may only
      # influence subprocesses (e.g. DataLoader workers) — verify.
      os.environ["OMP_NUM_THREADS"] = "16"
      torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
      warn('Current version of romatch is not tested for training, use at your own risk.')
      import romatch

      parser = ArgumentParser()
      parser.add_argument("--only_test", action='store_true')
      parser.add_argument("--debug_mode", action='store_true')
      parser.add_argument("--dont_log_wandb", action='store_true')
      # Validate against the resolution table up front: an unknown value would
      # otherwise only fail with a KeyError after DDP/process-group setup.
      parser.add_argument("--train_resolution", default='medium', choices=list(resolutions))
      parser.add_argument("--gpu_batch_size", default=8, type=int)
      parser.add_argument("--wandb_entity", required=False)
      # Unrecognised arguments are tolerated and discarded.
      args, _ = parser.parse_known_args()
      romatch.DEBUG_MODE = args.debug_mode
      if not args.only_test:
          train(args)
      else:
          # No test_* routine is invoked from here, so surface what would
          # otherwise be a silent no-op.
          warn("--only_test was given, but no test routine is wired up in this script; nothing was run.")