| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262 |
- import os
- import datetime
- from contextlib import nullcontext
- import argparse
- import torch
- import torch.nn as nn
- import torch.optim as optim
def _torch_version_tuple(version=None):
    """Return the leading ``(major, minor, patch)`` of a torch version string as ints.

    Args:
        version: version string to parse; defaults to ``torch.__version__``.

    Returns:
        A 3-tuple of ints. Build/pre-release suffixes such as ``+cu121``, ``a0``,
        ``rc1`` or ``.dev20240901`` are ignored; a missing patch component counts
        as 0; a string with no leading ``X.Y`` yields ``(0, 0, 0)``.
    """
    import re

    if version is None:
        version = torch.__version__
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', str(version))
    if match is None:
        return (0, 0, 0)
    return tuple(int(part or 0) for part in match.groups())


# Enable expandable CUDA allocator segments only on torch >= 2.5.0.
if _torch_version_tuple() >= (2, 5, 0):
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
- from config import Config
- from loss import PixLoss, ClsLoss
- from dataset import MyData
- from models.birefnet import BiRefNet
- from utils import Logger, AverageMeter, set_seed, check_state_dict
- from torch.utils.data.distributed import DistributedSampler
- from torch.nn.parallel import DistributedDataParallel as DDP
- from torch.distributed import init_process_group, destroy_process_group
def _parse_true_literal(value):
    """Interpret a CLI string as a boolean: only the exact text 'True' is truthy."""
    return value == 'True'


# Command-line options. `--dist True` turns on the manual torch DDP path;
# any other value (including 'true') leaves it off.
parser = argparse.ArgumentParser(description='')
parser.add_argument('--resume', type=str, default=None,
                    help='path to latest checkpoint')
parser.add_argument('--epochs', type=int, default=120)
parser.add_argument('--ckpt_dir', help='Temporary folder', default='ckpts/tmp')
parser.add_argument('--dist', type=_parse_true_literal, default=False)
parser.add_argument(
    '--use_accelerate',
    action='store_true',
    help='`accelerate launch --multi_gpu train.py --use_accelerate`. Use accelerate for training, good for FP16/BF16/...',
)
args = parser.parse_args()
config = Config()

if args.use_accelerate:
    from accelerate import Accelerator, utils

    mixed_precision = config.mixed_precision
    # NCCL process group with a 10-hour timeout, no unused-parameter search in DDP,
    # and a grad scaler that halves the scale on overflow.
    kwargs_handlers = [
        utils.InitProcessGroupKwargs(backend="nccl", timeout=datetime.timedelta(hours=10)),
        utils.DistributedDataParallelKwargs(find_unused_parameters=False),
        utils.GradScalerKwargs(backoff_factor=0.5),
    ]
    if mixed_precision == 'fp8':
        # fp8 additionally needs an explicit recipe handler.
        kwargs_handlers.append(utils.AORecipeKwargs())

    accelerator = Accelerator(
        mixed_precision=mixed_precision,
        gradient_accumulation_steps=1,
        kwargs_handlers=kwargs_handlers,
    )
    accelerator.print(accelerator.state)
    accelerator.print('backbone:', config.bb, ', freeze_bb:', config.freeze_bb)

    # Accelerate sets up its own process group, so the manual DDP path below is disabled.
    args.dist = False
# DDP: pick the device index for this process and seed the RNGs.
to_be_distributed = args.dist
if to_be_distributed:
    init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*10))
    # LOCAL_RANK is provided by the distributed launcher.
    device = int(os.environ["LOCAL_RANK"])
elif args.use_accelerate:
    device = accelerator.local_process_index
else:
    device = config.device

if config.rand_seed:
    # Offset the seed by the local process index so that processes do not share one
    # random stream. config.device may be a non-integer (e.g. 'cpu'), for which
    # `rand_seed + device` would raise TypeError; such single-process runs use no offset.
    seed_offset = device if isinstance(device, int) else 0
    set_seed(config.rand_seed + seed_offset)
# First epoch to run; init_models_optimizers raises it when resuming from a checkpoint.
epoch_st = 1

# Checkpoints and the text log both live under args.ckpt_dir.
os.makedirs(args.ckpt_dir, exist_ok=True)
logger = Logger(os.path.join(args.ckpt_dir, "log.txt"))
logger_loss_idx = 1

# Record the run configuration at the top of the log.
logger.info(f"datasets: load_all={config.load_all}, compile={config.compile}.")
logger.info("Other hyperparameters:")
logger.info(args)
print('batch size:', config.batch_size)
- from dataset import custom_collate_fn
def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, to_be_distributed=False, is_train=True):
    """Wrap ``dataset`` in a DataLoader.

    In distributed mode a ``DistributedSampler`` is attached and loader-level
    shuffling is off. Otherwise the loader shuffles only when ``is_train``.
    The custom collate function is used only for training with
    ``config.dynamic_size``. Incomplete final batches are always dropped.
    """
    collate = custom_collate_fn if is_train and config.dynamic_size else None
    if to_be_distributed:
        sampler, shuffle = DistributedSampler(dataset), False
    else:
        sampler, shuffle = None, is_train
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=min(config.num_workers, batch_size),
        pin_memory=True,
        shuffle=shuffle,
        sampler=sampler,
        drop_last=True,
        collate_fn=collate,
    )
def init_data_loaders(to_be_distributed):
    """Build and return the training DataLoader over ``config.training_set``.

    Samples are resized to ``config.size`` unless ``config.dynamic_size`` is set,
    in which case ``data_size`` is ``None``.
    """
    train_set = MyData(
        datasets=config.training_set,
        data_size=None if config.dynamic_size else config.size,
        is_train=True,
    )
    train_loader = prepare_dataloader(
        train_set, config.batch_size, to_be_distributed=to_be_distributed, is_train=True,
    )
    print(len(train_loader), f"batches of train dataloader {config.training_set} have been created.")
    return train_loader
def _parse_resume_epoch(path):
    """Return N from an ``epoch_N`` token in a checkpoint file name, or None.

    Only the base name with its extension removed is inspected. If several
    ``epoch_N`` tokens are present, the last one wins
    (``ckpts/tmp/epoch_120.pth`` -> 120, ``model.pth`` -> None).
    """
    import re

    stem = os.path.splitext(os.path.basename(path))[0]
    matches = re.findall(r'epoch_(\d+)', stem)
    return int(matches[-1]) if matches else None


def init_models_optimizers(epochs, to_be_distributed):
    """Build the model, optionally resume weights, and create optimizer and LR scheduler.

    Args:
        epochs: total number of training epochs; negative entries of
            ``config.lr_decay_epochs`` are counted back from it.
        to_be_distributed: wrap the model in torch DDP (ignored when accelerate is used).

    Returns:
        ``(model, optimizer, lr_scheduler)``.

    Raises:
        ValueError: if ``config.model`` or ``config.optimizer`` is not supported.

    Side effect: when resuming from a file whose name contains ``epoch_N``, sets the
    module-level ``epoch_st`` to ``N + 1``.
    """
    global epoch_st

    # Init models. Pretrained backbone weights are requested only when no resume
    # checkpoint file exists, since load_state_dict below replaces all weights anyway.
    if config.model == 'BiRefNet':
        model = BiRefNet(bb_pretrained=not os.path.isfile(str(args.resume)))
    else:
        raise ValueError('Undefined model: {}.'.format(config.model))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            state_dict = torch.load(args.resume, map_location='cpu', weights_only=True)
            state_dict = check_state_dict(state_dict)
            model.load_state_dict(state_dict)
            # str.rstrip('.pth') strips a character set, not a suffix, so the epoch is
            # parsed with a regex. Names without `epoch_N` keep the default start epoch.
            resumed_epoch = _parse_resume_epoch(args.resume)
            if resumed_epoch is None:
                logger.info("=> could not infer the epoch from '{}'; starting at epoch {}".format(args.resume, epoch_st))
            else:
                epoch_st = resumed_epoch + 1
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # Device placement / DDP wrapping; accelerate does its own in Trainer.
    if not args.use_accelerate:
        if to_be_distributed:
            model = model.to(device)
            model = DDP(model, device_ids=[device])
        else:
            model = model.to(device)
    if config.compile:
        model = torch.compile(model, mode=['default', 'reduce-overhead', 'max-autotune'][0])
    if config.precisionHigh:
        torch.set_float32_matmul_precision('high')

    # Setting optimizer (frozen parameters are excluded).
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if config.optimizer == 'AdamW':
        optimizer = optim.AdamW(params=trainable_params, lr=config.lr, weight_decay=1e-2)
    elif config.optimizer == 'Adam':
        optimizer = optim.Adam(params=trainable_params, lr=config.lr, weight_decay=0)
    else:
        raise ValueError('Unsupported optimizer: {}.'.format(config.optimizer))

    # Non-positive milestones are relative to the end: lde -> epochs + lde + 1.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[lde if lde > 0 else epochs + lde + 1 for lde in config.lr_decay_epochs],
        gamma=config.lr_decay_rate
    )
    return model, optimizer, lr_scheduler
class Trainer:
    """Run the epoch/batch training loop for a ``(model, optimizer, lr_scheduler)`` triple.

    Args:
        data_loaders: the training DataLoader (a single loader, as returned by
            ``init_data_loaders``).
        model_opt_lrsch: ``(model, optimizer, lr_scheduler)`` as returned by
            ``init_models_optimizers``.

    Uses the module-level ``args``, ``config``, ``device``, ``logger`` and, when
    ``--use_accelerate`` is set, ``accelerator``.
    """

    def __init__(
        self, data_loaders, model_opt_lrsch,
    ):
        self.model, self.optimizer, self.lr_scheduler = model_opt_lrsch
        self.train_loader = data_loaders
        if args.use_accelerate:
            # The LR scheduler is not prepared; it is stepped manually once per epoch.
            self.train_loader, self.model, self.optimizer = accelerator.prepare(self.train_loader, self.model, self.optimizer)
        if config.out_ref:
            self.criterion_gdt = nn.BCELoss()

        # Setting Losses
        self.pix_loss = PixLoss()
        self.cls_loss = ClsLoss()

        # Others
        self.loss_log = AverageMeter()
        # Per-term loss values of the most recent batch, used for progress logging.
        self.loss_dict = {}

    def _train_batch(self, batch):
        """Run one optimisation step on ``batch`` indexed as (inputs, gts, class_labels).

        Per-term loss values go into ``self.loss_dict``. The total loss is
        accumulated into ``self.loss_log``, weighted by the batch size.
        """
        if args.use_accelerate:
            # Batches from the accelerate-prepared loader are used without .to(device).
            inputs = batch[0]
            gts = batch[1]
            class_labels = batch[2]
        else:
            inputs = batch[0].to(device)
            gts = batch[1].to(device)
            class_labels = batch[2].to(device)
        self.optimizer.zero_grad()
        scaled_preds, class_preds_lst = self.model(inputs)
        if config.out_ref:
            # With out_ref, the first model output also carries (outs_gdt_pred, outs_gdt_label).
            # Each pair is compared with BCE after a sigmoid, at the label's resolution.
            (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds
            # Starting from 0. keeps loss_gdt defined even if there are no gdt outputs.
            loss_gdt = 0.
            for _gdt_pred, _gdt_label in zip(outs_gdt_pred, outs_gdt_label):
                _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True).sigmoid()
                _gdt_label = _gdt_label.sigmoid()
                loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt
        if None in class_preds_lst:
            loss_cls = 0.
        else:
            loss_cls = self.cls_loss(class_preds_lst, class_labels)
            self.loss_dict['loss_cls'] = loss_cls.item()

        # Loss
        loss_pix, loss_dict_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1), pix_loss_lambda=1.0)
        self.loss_dict.update(loss_dict_pix)
        self.loss_dict['loss_pix'] = loss_pix.item()

        # since there may be several losses for sal, the lambdas for them (lambdas_pix) are inside the loss.py
        loss = loss_pix + loss_cls
        if config.out_ref:
            loss = loss + loss_gdt * 1.0

        self.loss_log.update(loss.item(), inputs.size(0))
        if args.use_accelerate:
            loss = loss / accelerator.gradient_accumulation_steps
            accelerator.backward(loss)
        else:
            loss.backward()
        self.optimizer.step()

    def train_epoch(self, epoch):
        """Train for one epoch and return that epoch's average training loss."""
        self.model.train()
        self.loss_dict = {}
        # Use a fresh meter each epoch so the "@==Final==" line and the return value
        # describe this epoch only, not a running average over every epoch so far.
        self.loss_log = AverageMeter()
        # DistributedSampler derives its shuffle order from its epoch. Without
        # set_epoch, every epoch would iterate in the same order.
        sampler = getattr(self.train_loader, 'sampler', None)
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(epoch)

        # Re-weight the last-stage pixel losses once epoch > args.epochs + finetune_last_epochs
        # (presumably finetune_last_epochs <= 0, i.e. the final epochs of training).
        # NOTE(review): the factors are applied in place on every such epoch, so they
        # compound over the fine-tuning phase; confirm this is intended.
        if epoch > args.epochs + config.finetune_last_epochs:
            if config.task == 'Matting':
                self.pix_loss.lambdas_pix_last['mae'] *= 1
                self.pix_loss.lambdas_pix_last['mse'] *= 0.9
                self.pix_loss.lambdas_pix_last['ssim'] *= 0.9
            else:
                self.pix_loss.lambdas_pix_last['bce'] *= 0
                self.pix_loss.lambdas_pix_last['ssim'] *= 1
                self.pix_loss.lambdas_pix_last['iou'] *= 0.5
                self.pix_loss.lambdas_pix_last['mae'] *= 0.9

        for batch_idx, batch in enumerate(self.train_loader):
            self._train_batch(batch)
            # Logger: every 20 iterations early in epoch 1, otherwise every
            # max(100, ...) iterations.
            if (epoch < 2 and batch_idx < 100 and batch_idx % 20 == 0) or batch_idx % max(100, len(self.train_loader) / 100 // 100 * 100) == 0:
                info_progress = f'Epoch[{epoch}/{args.epochs}] Iter[{batch_idx}/{len(self.train_loader)}].'
                info_loss = 'Training Losses:'
                for loss_name, loss_value in self.loss_dict.items():
                    info_loss += f' {loss_name}: {loss_value:.5g} |'
                logger.info(' '.join((info_progress, info_loss)))
        info_loss = f'@==Final== Epoch[{epoch}/{args.epochs}] Training Loss: {self.loss_log.avg:.5g} '
        logger.info(info_loss)

        self.lr_scheduler.step()
        return self.loss_log.avg
def main():
    """Build the trainer and train for epochs ``epoch_st`` through ``args.epochs``.

    A checkpoint ``epoch_{N}.pth`` is written to ``args.ckpt_dir`` when
    ``N >= args.epochs - config.save_last`` and ``N % config.save_step == 0``.
    Only one process writes each file.
    """
    trainer = Trainer(
        data_loaders=init_data_loaders(to_be_distributed),
        model_opt_lrsch=init_models_optimizers(args.epochs, to_be_distributed)
    )

    # In multi-process runs every rank used to torch.save the same path concurrently.
    # Only the main process writes now.
    if args.use_accelerate:
        is_main_process = accelerator.is_main_process
    elif to_be_distributed:
        is_main_process = torch.distributed.get_rank() == 0
    else:
        is_main_process = True

    for epoch in range(epoch_st, args.epochs + 1):
        trainer.train_epoch(epoch)
        # Save checkpoint
        if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0:
            # state_dict() is still gathered on every rank, as before, in case the
            # wrapper needs all ranks to participate; only the write is gated.
            if args.use_accelerate:
                state_dict = trainer.model.state_dict()
            else:
                state_dict = trainer.model.module.state_dict() if to_be_distributed else trainer.model.state_dict()
            if is_main_process:
                torch.save(state_dict, os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch)))

    if to_be_distributed:
        destroy_process_group()


if __name__ == '__main__':
    main()
|