# train.py
  1. import math
  2. import argparse
  3. import pprint
  4. from distutils.util import strtobool
  5. from pathlib import Path
  6. from loguru import logger as loguru_logger
  7. import pytorch_lightning as pl
  8. from pytorch_lightning.utilities import rank_zero_only
  9. from pytorch_lightning.loggers import TensorBoardLogger
  10. from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
  11. from pytorch_lightning.plugins import DDPPlugin
  12. from src.config.default import get_cfg_defaults
  13. from src.utils.misc import get_rank_zero_only_logger, setup_gpus
  14. from src.utils.profiler import build_profiler
  15. from src.lightning.data import MultiSceneDataModule
  16. from src.lightning.lightning_loftr import PL_LoFTR
# Rebind the module-level loguru logger to whatever get_rank_zero_only_logger returns.
# NOTE(review): judging by the helper's name, this presumably limits log output to
# the rank-zero process — confirm in src.utils.misc.
loguru_logger = get_rank_zero_only_logger(loguru_logger)
  18. def parse_args():
  19. # init a costum parser which will be added into pl.Trainer parser
  20. # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
  21. parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  22. parser.add_argument(
  23. 'data_cfg_path', type=str, help='data config path')
  24. parser.add_argument(
  25. 'main_cfg_path', type=str, help='main config path')
  26. parser.add_argument(
  27. '--exp_name', type=str, default='default_exp_name')
  28. parser.add_argument(
  29. '--batch_size', type=int, default=4, help='batch_size per gpu')
  30. parser.add_argument(
  31. '--num_workers', type=int, default=4)
  32. parser.add_argument(
  33. '--pin_memory', type=lambda x: bool(strtobool(x)),
  34. nargs='?', default=True, help='whether loading data to pinned memory or not')
  35. parser.add_argument(
  36. '--ckpt_path', type=str, default=None,
  37. help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR')
  38. parser.add_argument(
  39. '--disable_ckpt', action='store_true',
  40. help='disable checkpoint saving (useful for debugging).')
  41. parser.add_argument(
  42. '--profiler_name', type=str, default=None,
  43. help='options: [inference, pytorch], or leave it unset')
  44. parser.add_argument(
  45. '--parallel_load_data', action='store_true',
  46. help='load datasets in with multiple processes.')
  47. parser = pl.Trainer.add_argparse_args(parser)
  48. return parser.parse_args()
  49. def main():
  50. # parse arguments
  51. args = parse_args()
  52. rank_zero_only(pprint.pprint)(vars(args))
  53. # init default-cfg and merge it with the main- and data-cfg
  54. config = get_cfg_defaults()
  55. config.merge_from_file(args.main_cfg_path)
  56. config.merge_from_file(args.data_cfg_path)
  57. pl.seed_everything(config.TRAINER.SEED) # reproducibility
  58. # TODO: Use different seeds for each dataloader workers
  59. # This is needed for data augmentation
  60. # scale lr and warmup-step automatically
  61. args.gpus = _n_gpus = setup_gpus(args.gpus)
  62. config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
  63. config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
  64. _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
  65. config.TRAINER.SCALING = _scaling
  66. config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
  67. config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
  68. # lightning module
  69. profiler = build_profiler(args.profiler_name)
  70. model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
  71. loguru_logger.info(f"LoFTR LightningModule initialized!")
  72. # lightning data
  73. data_module = MultiSceneDataModule(args, config)
  74. loguru_logger.info(f"LoFTR DataModule initialized!")
  75. # TensorBoard Logger
  76. logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
  77. ckpt_dir = Path(logger.log_dir) / 'checkpoints'
  78. # Callbacks
  79. # TODO: update ModelCheckpoint to monitor multiple metrics
  80. ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
  81. save_last=True,
  82. dirpath=str(ckpt_dir),
  83. filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
  84. lr_monitor = LearningRateMonitor(logging_interval='step')
  85. callbacks = [lr_monitor]
  86. if not args.disable_ckpt:
  87. callbacks.append(ckpt_callback)
  88. # Lightning Trainer
  89. trainer = pl.Trainer.from_argparse_args(
  90. args,
  91. plugins=DDPPlugin(find_unused_parameters=False,
  92. num_nodes=args.num_nodes,
  93. sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
  94. gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
  95. callbacks=callbacks,
  96. logger=logger,
  97. sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
  98. replace_sampler_ddp=False, # use custom sampler
  99. reload_dataloaders_every_epoch=False, # avoid repeated samples!
  100. weights_summary='full',
  101. profiler=profiler)
  102. loguru_logger.info(f"Trainer initialized!")
  103. loguru_logger.info(f"Start training!")
  104. trainer.fit(model, datamodule=data_module)
# Run training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()