train_roma_outdoor.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. import os
  2. import torch
  3. from argparse import ArgumentParser
  4. from torch import nn
  5. from torch.utils.data import ConcatDataset
  6. import torch.distributed as dist
  7. from torch.nn.parallel import DistributedDataParallel as DDP
  8. import json
  9. import wandb
  10. from romatch.benchmarks import MegadepthDenseBenchmark
  11. from romatch.datasets.megadepth import MegadepthBuilder
  12. from romatch.losses.robust_loss import RobustLosses
  13. from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
  14. from romatch.train.train import train_k_steps
  15. from romatch.models.matcher import *
  16. from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
  17. from romatch.models.encoders import *
  18. from romatch.checkpointing import CheckPoint
  19. resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
  20. def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
  21. import warnings
  22. warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
  23. gp_dim = 512
  24. feat_dim = 512
  25. decoder_dim = gp_dim + feat_dim
  26. cls_to_coord_res = 64
  27. coordinate_decoder = TransformerDecoder(
  28. nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
  29. decoder_dim,
  30. cls_to_coord_res**2 + 1,
  31. is_classifier=True,
  32. amp = True,
  33. pos_enc = False,)
  34. dw = True
  35. hidden_blocks = 8
  36. kernel_size = 5
  37. displacement_emb = "linear"
  38. disable_local_corr_grad = True
  39. conv_refiner = nn.ModuleDict(
  40. {
  41. "16": ConvRefiner(
  42. 2 * 512+128+(2*7+1)**2,
  43. 2 * 512+128+(2*7+1)**2,
  44. 2 + 1,
  45. kernel_size=kernel_size,
  46. dw=dw,
  47. hidden_blocks=hidden_blocks,
  48. displacement_emb=displacement_emb,
  49. displacement_emb_dim=128,
  50. local_corr_radius = 7,
  51. corr_in_other = True,
  52. amp = True,
  53. disable_local_corr_grad = disable_local_corr_grad,
  54. bn_momentum = 0.01,
  55. ),
  56. "8": ConvRefiner(
  57. 2 * 512+64+(2*3+1)**2,
  58. 2 * 512+64+(2*3+1)**2,
  59. 2 + 1,
  60. kernel_size=kernel_size,
  61. dw=dw,
  62. hidden_blocks=hidden_blocks,
  63. displacement_emb=displacement_emb,
  64. displacement_emb_dim=64,
  65. local_corr_radius = 3,
  66. corr_in_other = True,
  67. amp = True,
  68. disable_local_corr_grad = disable_local_corr_grad,
  69. bn_momentum = 0.01,
  70. ),
  71. "4": ConvRefiner(
  72. 2 * 256+32+(2*2+1)**2,
  73. 2 * 256+32+(2*2+1)**2,
  74. 2 + 1,
  75. kernel_size=kernel_size,
  76. dw=dw,
  77. hidden_blocks=hidden_blocks,
  78. displacement_emb=displacement_emb,
  79. displacement_emb_dim=32,
  80. local_corr_radius = 2,
  81. corr_in_other = True,
  82. amp = True,
  83. disable_local_corr_grad = disable_local_corr_grad,
  84. bn_momentum = 0.01,
  85. ),
  86. "2": ConvRefiner(
  87. 2 * 64+16,
  88. 128+16,
  89. 2 + 1,
  90. kernel_size=kernel_size,
  91. dw=dw,
  92. hidden_blocks=hidden_blocks,
  93. displacement_emb=displacement_emb,
  94. displacement_emb_dim=16,
  95. amp = True,
  96. disable_local_corr_grad = disable_local_corr_grad,
  97. bn_momentum = 0.01,
  98. ),
  99. "1": ConvRefiner(
  100. 2 * 9 + 6,
  101. 24,
  102. 2 + 1,
  103. kernel_size=kernel_size,
  104. dw=dw,
  105. hidden_blocks = hidden_blocks,
  106. displacement_emb = displacement_emb,
  107. displacement_emb_dim = 6,
  108. amp = True,
  109. disable_local_corr_grad = disable_local_corr_grad,
  110. bn_momentum = 0.01,
  111. ),
  112. }
  113. )
  114. kernel_temperature = 0.2
  115. learn_temperature = False
  116. no_cov = True
  117. kernel = CosKernel
  118. only_attention = False
  119. basis = "fourier"
  120. gp16 = GP(
  121. kernel,
  122. T=kernel_temperature,
  123. learn_temperature=learn_temperature,
  124. only_attention=only_attention,
  125. gp_dim=gp_dim,
  126. basis=basis,
  127. no_cov=no_cov,
  128. )
  129. gps = nn.ModuleDict({"16": gp16})
  130. proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
  131. proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
  132. proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
  133. proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
  134. proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
  135. proj = nn.ModuleDict({
  136. "16": proj16,
  137. "8": proj8,
  138. "4": proj4,
  139. "2": proj2,
  140. "1": proj1,
  141. })
  142. displacement_dropout_p = 0.0
  143. gm_warp_dropout_p = 0.0
  144. decoder = Decoder(coordinate_decoder,
  145. gps,
  146. proj,
  147. conv_refiner,
  148. detach=True,
  149. scales=["16", "8", "4", "2", "1"],
  150. displacement_dropout_p = displacement_dropout_p,
  151. gm_warp_dropout_p = gm_warp_dropout_p)
  152. h,w = resolutions[resolution]
  153. encoder = CNNandDinov2(
  154. cnn_kwargs = dict(
  155. pretrained=pretrained_backbone,
  156. amp = True),
  157. amp = True,
  158. use_vgg = True,
  159. )
  160. matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs)
  161. return matcher
  162. def train(args):
  163. dist.init_process_group('nccl')
  164. #torch._dynamo.config.verbose=True
  165. gpus = int(os.environ['WORLD_SIZE'])
  166. # create model and move it to GPU with id rank
  167. rank = dist.get_rank()
  168. print(f"Start running DDP on rank {rank}")
  169. device_id = rank % torch.cuda.device_count()
  170. romatch.LOCAL_RANK = device_id
  171. torch.cuda.set_device(device_id)
  172. resolution = args.train_resolution
  173. wandb_log = not args.dont_log_wandb
  174. experiment_name = os.path.splitext(os.path.basename(__file__))[0]
  175. wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
  176. wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
  177. checkpoint_dir = "workspace/checkpoints/"
  178. h,w = resolutions[resolution]
  179. model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
  180. # Num steps
  181. global_step = 0
  182. batch_size = args.gpu_batch_size
  183. step_size = gpus*batch_size
  184. romatch.STEP_SIZE = step_size
  185. N = (32 * 250000) # 250k steps of batch size 32
  186. # checkpoint every
  187. k = 25000 // romatch.STEP_SIZE
  188. # Data
  189. mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
  190. use_horizontal_flip_aug = True
  191. rot_prob = 0
  192. depth_interpolation_mode = "bilinear"
  193. megadepth_train1 = mega.build_scenes(
  194. split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
  195. ht=h,wt=w,
  196. )
  197. megadepth_train2 = mega.build_scenes(
  198. split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
  199. ht=h,wt=w,
  200. )
  201. megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
  202. mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
  203. # Loss and optimizer
  204. depth_loss = RobustLosses(
  205. ce_weight=0.01,
  206. local_dist={1:4, 2:4, 4:8, 8:8},
  207. local_largest_scale=8,
  208. depth_interpolation_mode=depth_interpolation_mode,
  209. alpha = 0.5,
  210. c = 1e-4,)
  211. parameters = [
  212. {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
  213. {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
  214. ]
  215. optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
  216. lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
  217. optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
  218. megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
  219. checkpointer = CheckPoint(checkpoint_dir, experiment_name)
  220. model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
  221. romatch.GLOBAL_STEP = global_step
  222. ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
  223. grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
  224. grad_clip_norm = 0.01
  225. for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
  226. mega_sampler = torch.utils.data.WeightedRandomSampler(
  227. mega_ws, num_samples = batch_size * k, replacement=False
  228. )
  229. mega_dataloader = iter(
  230. torch.utils.data.DataLoader(
  231. megadepth_train,
  232. batch_size = batch_size,
  233. sampler = mega_sampler,
  234. num_workers = 8,
  235. )
  236. )
  237. train_k_steps(
  238. n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
  239. )
  240. checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
  241. wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
  242. def test_mega_8_scenes(model, name):
  243. mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
  244. scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
  245. 'mega_8_scenes_0025_0.1_0.3.npz',
  246. 'mega_8_scenes_0021_0.1_0.3.npz',
  247. 'mega_8_scenes_0008_0.1_0.3.npz',
  248. 'mega_8_scenes_0032_0.1_0.3.npz',
  249. 'mega_8_scenes_1589_0.1_0.3.npz',
  250. 'mega_8_scenes_0063_0.1_0.3.npz',
  251. 'mega_8_scenes_0024_0.1_0.3.npz',
  252. 'mega_8_scenes_0019_0.3_0.5.npz',
  253. 'mega_8_scenes_0025_0.3_0.5.npz',
  254. 'mega_8_scenes_0021_0.3_0.5.npz',
  255. 'mega_8_scenes_0008_0.3_0.5.npz',
  256. 'mega_8_scenes_0032_0.3_0.5.npz',
  257. 'mega_8_scenes_1589_0.3_0.5.npz',
  258. 'mega_8_scenes_0063_0.3_0.5.npz',
  259. 'mega_8_scenes_0024_0.3_0.5.npz'])
  260. mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
  261. print(mega_8_scenes_results)
  262. json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
  263. def test_mega1500(model, name):
  264. mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
  265. mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
  266. json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
  267. def test_mega_dense(model, name):
  268. megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
  269. megadense_results = megadense_benchmark.benchmark(model)
  270. json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
  271. def test_hpatches(model, name):
  272. hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
  273. hpatches_results = hpatches_benchmark.benchmark(model)
  274. json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
  275. if __name__ == "__main__":
  276. os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
  277. os.environ["OMP_NUM_THREADS"] = "16"
  278. torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
  279. import romatch
  280. parser = ArgumentParser()
  281. parser.add_argument("--only_test", action='store_true')
  282. parser.add_argument("--debug_mode", action='store_true')
  283. parser.add_argument("--dont_log_wandb", action='store_true')
  284. parser.add_argument("--train_resolution", default='medium')
  285. parser.add_argument("--gpu_batch_size", default=8, type=int)
  286. parser.add_argument("--wandb_entity", required = False)
  287. args, _ = parser.parse_known_args()
  288. romatch.DEBUG_MODE = args.debug_mode
  289. if not args.only_test:
  290. train(args)