inference.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import os
  2. import argparse
  3. from glob import glob
  4. from tqdm import tqdm
  5. import cv2
  6. import torch
  7. from contextlib import nullcontext
  8. from dataset import MyData
  9. from models.birefnet import BiRefNet
  10. from utils import save_tensor_img, check_state_dict
  11. from config import Config
  # Project configuration, loaded once when the module is imported.
  config = Config()
  mixed_precision = config.mixed_precision
  # Translate the configured precision mode into an autocast dtype;
  # any value other than 'fp16' / 'bf16' leaves mixed_dtype as None.
  mixed_dtype = (
      torch.float16 if mixed_precision == 'fp16'
      else torch.bfloat16 if mixed_precision == 'bf16'
      else None
  )
  # A single context-manager object reused around every forward pass in `inference`;
  # a no-op nullcontext is used when no mixed-precision dtype was selected.
  # NOTE(review): device_type is fixed to 'cuda' regardless of config.device — confirm
  # this is intended for non-CUDA runs.
  if mixed_dtype:
      autocast_ctx = torch.amp.autocast(device_type='cuda', dtype=mixed_dtype)
  else:
      autocast_ctx = nullcontext()
  def inference(model, data_loader_test, pred_root, method, testset, device=0):
      """Predict a mask for every sample of ``data_loader_test`` and save it as an image.

      For each batch, ``batch[0]`` is moved to ``device`` and fed to ``model``; the
      last element of the model output is passed through a sigmoid, resized to the
      size of the label image found at the matching path in ``batch[-1]``, and
      written to ``<pred_root>/<method>/<testset>/<label file name>`` via
      ``save_tensor_img``.

      Args:
          model: Network to evaluate. It is put in eval mode for the run and, if it
              was in training mode, switched back to training mode afterwards —
              also when an exception interrupts the run.
          data_loader_test: Iterable of batches (``batch[0]`` inputs, ``batch[-1]``
              label image paths); ``len()`` is used for the progress bar when
              ``config.verbose_eval`` is set.
          pred_root: Root folder for predictions.
          method: Sub-folder name identifying the checkpoint / resolution.
          testset: Sub-folder name of the test set.
          device: Target passed to ``Tensor.to`` for the inputs.

      Returns:
          None.

      Raises:
          FileNotFoundError: If OpenCV cannot read a label image.
      """
      model_training = model.training
      if model_training:
          model.eval()
      # The output directory is the same for every batch, so create it once.
      out_dir = os.path.join(pred_root, method, testset)
      os.makedirs(out_dir, exist_ok=True)
      batches = tqdm(data_loader_test, total=len(data_loader_test)) if config.verbose_eval else data_loader_test
      try:
          for batch in batches:
              inputs = batch[0].to(device)
              label_paths = batch[-1]
              with autocast_ctx, torch.no_grad():
                  # Cast back to float32: under autocast the output may be fp16/bf16.
                  scaled_preds = model(inputs)[-1].sigmoid().to(torch.float32)
              for idx_sample, pred in enumerate(scaled_preds):
                  label_path = label_paths[idx_sample]
                  label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
                  if label is None:
                      # cv2.imread reports failure by returning None instead of raising.
                      raise FileNotFoundError('Cannot read label image: {}'.format(label_path))
                  # Resize the prediction to the label's (height, width).
                  res = torch.nn.functional.interpolate(
                      pred.unsqueeze(0),
                      size=label.shape[:2],
                      mode='bilinear',
                      align_corners=True,
                  )
                  # Accept both '/' and '\\' separators when extracting the file name.
                  file_name = label_path.replace('\\', '/').split('/')[-1]
                  save_tensor_img(res, os.path.join(out_dir, file_name))
      finally:
          # Restore the caller's mode even if inference fails part-way.
          if model_training:
              model.train()
      return None
  def _strip_suffix(text, suffix):
      """Return ``text`` without a trailing ``suffix`` (``str.rstrip`` would strip a character set)."""
      if suffix and text.endswith(suffix):
          return text[:-len(suffix)]
      return text


  def _checkpoint_epoch(path):
      """Return N for a checkpoint path like '.../epoch_N.pth', or None when no integer epoch is present."""
      try:
          return int(path.split('epoch_')[-1].split('.pth')[0])
      except ValueError:
          return None


  def _checkpoint_sort_key(path):
      """Sort key: checkpoints with an epoch number (ordered by epoch) rank above those without one."""
      epoch = _checkpoint_epoch(path)
      return (epoch is not None, epoch or 0)


  def _resolve_data_size(resolution, default_size):
      """Translate the ``--resolution`` CLI value into the ``data_size`` passed to ``MyData``.

      Returns None (keep the original image resolution) for None / 'None' / 0 / '',
      ``default_size`` for 'default' / 'config.size', otherwise the list of ints
      parsed from an 'AxB' string. Unparseable values fall back to ``default_size``
      with a printed warning.
      """
      if resolution in [None, 'None', 0, '']:
          return None
      if resolution in ['default', 'config.size']:
          return default_size
      try:
          return [int(part) for part in resolution.split('x')]
      except (AttributeError, ValueError) as err:
          print('Invalid --resolution {!r} ({}: {}); falling back to config.size.'.format(
              resolution, type(err).__name__, err))
          return default_size


  def _method_name(weights, data_size):
      """Build the output sub-folder name from the last two path components of ``weights`` and the resolution."""
      parts = [_strip_suffix(part, '.pth') for part in weights.split(os.sep)[-2:]]
      reso = 'x'.join(str(s) for s in data_size) if data_size is not None else 'None'
      return '--'.join(parts) + '-reso_{}'.format(reso)


  def main(args):
      """Run inference for every selected checkpoint on every test set in ``args.testsets``.

      Checkpoints are all ``*.pth`` files in ``args.ckpt_folder`` when it is set,
      otherwise the single file ``args.ckpt``; they are processed in descending
      epoch order. Predictions are written under ``args.pred_root`` by ``inference``.

      Returns:
          None. Returns early (after printing a message) when no checkpoint is
          given or found, or when ``config.model`` is not supported.
      """
      device = config.device
      if args.ckpt_folder:
          print('Testing with models in {}'.format(args.ckpt_folder))
          candidates = glob(os.path.join(args.ckpt_folder, '*.pth'))
      elif args.ckpt:
          print('Testing with model {}'.format(args.ckpt))
          candidates = [args.ckpt]
      else:
          print('No checkpoint given: pass --ckpt or --ckpt_folder.')
          return None
      if not candidates:
          print('No .pth checkpoints found in {}.'.format(args.ckpt_folder))
          return None
      if config.model == 'BiRefNet':
          model = BiRefNet(bb_pretrained=False)
      else:
          print('Undefined model: {}.'.format(config.model))
          return None
      weights_lst = sorted(candidates, key=_checkpoint_sort_key, reverse=True)
      data_size = _resolve_data_size(args.resolution, config.size)
      # Evaluate only checkpoints whose epoch is a multiple of this (1 = all of them);
      # checkpoints without a parseable epoch are always evaluated.
      epoch_stride = 1
      for testset in args.testsets.split('+'):
          print('>>>> Testset: {}...'.format(testset))
          data_loader_test = torch.utils.data.DataLoader(
              dataset=MyData(testset, data_size=data_size, is_train=False),
              batch_size=config.batch_size_valid, shuffle=False, num_workers=config.num_workers, pin_memory=True,
          )
          for weights in weights_lst:
              epoch = _checkpoint_epoch(weights)
              if epoch is not None and epoch % epoch_stride != 0:
                  continue
              print('\tInferencing {}...'.format(weights))
              state_dict = torch.load(weights, map_location='cpu', weights_only=True)
              state_dict = check_state_dict(state_dict)
              model.load_state_dict(state_dict)
              model = model.to(device)
              inference(
                  model, data_loader_test=data_loader_test, pred_root=args.pred_root,
                  method=_method_name(weights, data_size),
                  testset=testset, device=device,
              )
  if __name__ == '__main__':
      # Command-line parameters.
      parser = argparse.ArgumentParser(description='')
      parser.add_argument('--ckpt', type=str,
                          help='path to a single checkpoint (.pth); used when --ckpt_folder is not given')
      # Default is resolved after parsing so that a missing/empty ckpts/ folder does not
      # crash when --ckpt is given, and so that --ckpt is not shadowed by a default folder.
      parser.add_argument('--ckpt_folder', default=None, type=str,
                          help='model folder; defaults to the last folder under ckpts/ when neither --ckpt nor --ckpt_folder is given')
      parser.add_argument('--pred_root', default='e_preds', type=str, help='Output folder')
      parser.add_argument('--resolution', default='default', type=str,
                          help="WidthxHeight (e.g. 1024x1024); 'default' or 'config.size' for config.size; 'None' for the original resolution")
      parser.add_argument('--testsets',
                          default=config.testsets.replace(',', '+'),
                          type=str,
                          help="Test all sets: DIS5K -> 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'")
      args = parser.parse_args()
      if not args.ckpt and not args.ckpt_folder:
          ckpt_folders = sorted(glob(os.path.join('ckpts', '*')))
          if not ckpt_folders:
              parser.error('no --ckpt/--ckpt_folder given and no folders found under ckpts/')
          args.ckpt_folder = ckpt_folders[-1]
      if config.precisionHigh:
          torch.set_float32_matmul_precision('high')
      main(args)