train.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import os
  2. import datetime
  3. from contextlib import nullcontext
  4. import argparse
  5. import torch
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. if tuple(map(int, torch.__version__.split('+')[0].split(".")[:3])) >= (2, 5, 0):
  9. os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
  10. from config import Config
  11. from loss import PixLoss, ClsLoss
  12. from dataset import MyData
  13. from models.birefnet import BiRefNet
  14. from utils import Logger, AverageMeter, set_seed, check_state_dict
  15. from torch.utils.data.distributed import DistributedSampler
  16. from torch.nn.parallel import DistributedDataParallel as DDP
  17. from torch.distributed import init_process_group, destroy_process_group
  18. parser = argparse.ArgumentParser(description='')
  19. parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
  20. parser.add_argument('--epochs', default=120, type=int)
  21. parser.add_argument('--ckpt_dir', default='ckpts/tmp', help='Temporary folder')
  22. parser.add_argument('--dist', default=False, type=lambda x: x == 'True')
  23. parser.add_argument('--use_accelerate', action='store_true', help='`accelerate launch --multi_gpu train.py --use_accelerate`. Use accelerate for training, good for FP16/BF16/...')
  24. args = parser.parse_args()
  25. config = Config()
  26. if args.use_accelerate:
  27. from accelerate import Accelerator, utils
  28. mixed_precision = config.mixed_precision
  29. kwargs_handlers = [
  30. utils.InitProcessGroupKwargs(backend="nccl", timeout=datetime.timedelta(seconds=3600*10)),
  31. utils.DistributedDataParallelKwargs(find_unused_parameters=False),
  32. utils.GradScalerKwargs(backoff_factor=0.5),
  33. ]
  34. if mixed_precision == 'fp8':
  35. kwargs_handlers.append(utils.AORecipeKwargs())
  36. accelerator = Accelerator(
  37. mixed_precision=mixed_precision,
  38. gradient_accumulation_steps=1,
  39. kwargs_handlers=kwargs_handlers,
  40. )
  41. accelerator.print(accelerator.state)
  42. accelerator.print('backbone:', config.bb, ', freeze_bb:', config.freeze_bb)
  43. args.dist = False
  44. # DDP
  45. to_be_distributed = args.dist
  46. if to_be_distributed:
  47. init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*10))
  48. device = int(os.environ["LOCAL_RANK"])
  49. else:
  50. if args.use_accelerate:
  51. device = accelerator.local_process_index
  52. else:
  53. device = config.device
  54. if config.rand_seed:
  55. set_seed(config.rand_seed + device)
  56. epoch_st = 1
  57. # make dir for ckpt
  58. os.makedirs(args.ckpt_dir, exist_ok=True)
  59. # Init log file
  60. logger = Logger(os.path.join(args.ckpt_dir, "log.txt"))
  61. logger_loss_idx = 1
  62. # log model and optimizer params
  63. # logger.info("Model details:"); logger.info(model)
  64. # if args.use_accelerate and accelerator.mixed_precision != 'no':
  65. # config.compile = False
  66. logger.info("datasets: load_all={}, compile={}.".format(config.load_all, config.compile))
  67. logger.info("Other hyperparameters:"); logger.info(args)
  68. print('batch size:', config.batch_size)
  69. from dataset import custom_collate_fn
  70. def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, to_be_distributed=False, is_train=True):
  71. # Prepare dataloaders
  72. if to_be_distributed:
  73. return torch.utils.data.DataLoader(
  74. dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True,
  75. shuffle=False, sampler=DistributedSampler(dataset), drop_last=True, collate_fn=custom_collate_fn if is_train and config.dynamic_size else None
  76. )
  77. else:
  78. return torch.utils.data.DataLoader(
  79. dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True,
  80. shuffle=is_train, sampler=None, drop_last=True, collate_fn=custom_collate_fn if is_train and config.dynamic_size else None
  81. )
  82. def init_data_loaders(to_be_distributed):
  83. # Prepare datasets
  84. train_loader = prepare_dataloader(
  85. MyData(datasets=config.training_set, data_size=None if config.dynamic_size else config.size, is_train=True),
  86. config.batch_size, to_be_distributed=to_be_distributed, is_train=True
  87. )
  88. print(len(train_loader), "batches of train dataloader {} have been created.".format(config.training_set))
  89. return train_loader
  90. def init_models_optimizers(epochs, to_be_distributed):
  91. # Init models
  92. if config.model == 'BiRefNet':
  93. model = BiRefNet(bb_pretrained=True and not os.path.isfile(str(args.resume)))
  94. else:
  95. print('Undefined model: {}.'.format(config.model))
  96. return None
  97. if args.resume:
  98. if os.path.isfile(args.resume):
  99. logger.info("=> loading checkpoint '{}'".format(args.resume))
  100. state_dict = torch.load(args.resume, map_location='cpu', weights_only=True)
  101. state_dict = check_state_dict(state_dict)
  102. model.load_state_dict(state_dict)
  103. global epoch_st
  104. epoch_st = int(args.resume.rstrip('.pth').split('epoch_')[-1]) + 1
  105. else:
  106. logger.info("=> no checkpoint found at '{}'".format(args.resume))
  107. if not args.use_accelerate:
  108. if to_be_distributed:
  109. model = model.to(device)
  110. model = DDP(model, device_ids=[device])
  111. else:
  112. model = model.to(device)
  113. if config.compile:
  114. model = torch.compile(model, mode=['default', 'reduce-overhead', 'max-autotune'][0])
  115. if config.precisionHigh:
  116. torch.set_float32_matmul_precision('high')
  117. # Setting optimizer
  118. if config.optimizer == 'AdamW':
  119. optimizer = optim.AdamW(params=[p for p in model.parameters() if p.requires_grad], lr=config.lr, weight_decay=1e-2)
  120. elif config.optimizer == 'Adam':
  121. optimizer = optim.Adam(params=[p for p in model.parameters() if p.requires_grad], lr=config.lr, weight_decay=0)
  122. lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
  123. optimizer,
  124. milestones=[lde if lde > 0 else epochs + lde + 1 for lde in config.lr_decay_epochs],
  125. gamma=config.lr_decay_rate
  126. )
  127. # logger.info("Optimizer details:"); logger.info(optimizer)
  128. return model, optimizer, lr_scheduler
  129. class Trainer:
  130. def __init__(
  131. self, data_loaders, model_opt_lrsch,
  132. ):
  133. self.model, self.optimizer, self.lr_scheduler = model_opt_lrsch
  134. self.train_loader = data_loaders
  135. if args.use_accelerate:
  136. self.train_loader, self.model, self.optimizer = accelerator.prepare(self.train_loader, self.model, self.optimizer)
  137. if config.out_ref:
  138. self.criterion_gdt = nn.BCELoss()
  139. # Setting Losses
  140. self.pix_loss = PixLoss()
  141. self.cls_loss = ClsLoss()
  142. # Others
  143. self.loss_log = AverageMeter()
  144. def _train_batch(self, batch):
  145. if args.use_accelerate:
  146. inputs = batch[0]#.to(device)
  147. gts = batch[1]#.to(device)
  148. class_labels = batch[2]#.to(device)
  149. else:
  150. inputs = batch[0].to(device)
  151. gts = batch[1].to(device)
  152. class_labels = batch[2].to(device)
  153. self.optimizer.zero_grad()
  154. scaled_preds, class_preds_lst = self.model(inputs)
  155. if config.out_ref:
  156. (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds
  157. for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)):
  158. _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True).sigmoid()
  159. _gdt_label = _gdt_label.sigmoid()
  160. loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt
  161. # self.loss_dict['loss_gdt'] = loss_gdt.item()
  162. if None in class_preds_lst:
  163. loss_cls = 0.
  164. else:
  165. loss_cls = self.cls_loss(class_preds_lst, class_labels)
  166. self.loss_dict['loss_cls'] = loss_cls.item()
  167. # Loss
  168. loss_pix, loss_dict_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1), pix_loss_lambda=1.0)
  169. self.loss_dict.update(loss_dict_pix)
  170. self.loss_dict['loss_pix'] = loss_pix.item()
  171. # since there may be several losses for sal, the lambdas for them (lambdas_pix) are inside the loss.py
  172. loss = loss_pix + loss_cls
  173. if config.out_ref:
  174. loss = loss + loss_gdt * 1.0
  175. self.loss_log.update(loss.item(), inputs.size(0))
  176. if args.use_accelerate:
  177. loss = loss / accelerator.gradient_accumulation_steps
  178. accelerator.backward(loss)
  179. else:
  180. loss.backward()
  181. self.optimizer.step()
  182. def train_epoch(self, epoch):
  183. global logger_loss_idx
  184. self.model.train()
  185. self.loss_dict = {}
  186. if epoch > args.epochs + config.finetune_last_epochs:
  187. if config.task == 'Matting':
  188. self.pix_loss.lambdas_pix_last['mae'] *= 1
  189. self.pix_loss.lambdas_pix_last['mse'] *= 0.9
  190. self.pix_loss.lambdas_pix_last['ssim'] *= 0.9
  191. else:
  192. self.pix_loss.lambdas_pix_last['bce'] *= 0
  193. self.pix_loss.lambdas_pix_last['ssim'] *= 1
  194. self.pix_loss.lambdas_pix_last['iou'] *= 0.5
  195. self.pix_loss.lambdas_pix_last['mae'] *= 0.9
  196. for batch_idx, batch in enumerate(self.train_loader):
  197. # with nullcontext if not args.use_accelerate or accelerator.gradient_accumulation_steps <= 1 else accelerator.accumulate(self.model):
  198. self._train_batch(batch)
  199. # Logger
  200. if (epoch < 2 and batch_idx < 100 and batch_idx % 20 == 0) or batch_idx % max(100, len(self.train_loader) / 100 // 100 * 100) == 0:
  201. info_progress = f'Epoch[{epoch}/{args.epochs}] Iter[{batch_idx}/{len(self.train_loader)}].'
  202. info_loss = 'Training Losses:'
  203. for loss_name, loss_value in self.loss_dict.items():
  204. info_loss += f' {loss_name}: {loss_value:.5g} |'
  205. logger.info(' '.join((info_progress, info_loss)))
  206. info_loss = f'@==Final== Epoch[{epoch}/{args.epochs}] Training Loss: {self.loss_log.avg:.5g} '
  207. logger.info(info_loss)
  208. self.lr_scheduler.step()
  209. return self.loss_log.avg
  210. def main():
  211. trainer = Trainer(
  212. data_loaders=init_data_loaders(to_be_distributed),
  213. model_opt_lrsch=init_models_optimizers(args.epochs, to_be_distributed)
  214. )
  215. for epoch in range(epoch_st, args.epochs+1):
  216. train_loss = trainer.train_epoch(epoch)
  217. # Save checkpoint
  218. if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0:
  219. if args.use_accelerate:
  220. state_dict = trainer.model.state_dict()
  221. else:
  222. state_dict = trainer.model.module.state_dict() if to_be_distributed else trainer.model.state_dict()
  223. torch.save(state_dict, os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch)))
  224. if to_be_distributed:
  225. destroy_process_group()
  226. if __name__ == '__main__':
  227. main()