config.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import logging
  2. from .constants import *
  3. _logger = logging.getLogger(__name__)
  4. def resolve_data_config(
  5. args=None,
  6. pretrained_cfg=None,
  7. model=None,
  8. use_test_size=False,
  9. verbose=False
  10. ):
  11. assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
  12. args = args or {}
  13. pretrained_cfg = pretrained_cfg or {}
  14. if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
  15. pretrained_cfg = model.pretrained_cfg
  16. data_config = {}
  17. # Resolve input/image size
  18. in_chans = 3
  19. if args.get('in_chans', None) is not None:
  20. in_chans = args['in_chans']
  21. elif args.get('chans', None) is not None:
  22. in_chans = args['chans']
  23. input_size = (in_chans, 224, 224)
  24. if args.get('input_size', None) is not None:
  25. assert isinstance(args['input_size'], (tuple, list))
  26. assert len(args['input_size']) == 3
  27. input_size = tuple(args['input_size'])
  28. in_chans = input_size[0] # input_size overrides in_chans
  29. elif args.get('img_size', None) is not None:
  30. assert isinstance(args['img_size'], int)
  31. input_size = (in_chans, args['img_size'], args['img_size'])
  32. else:
  33. if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
  34. input_size = pretrained_cfg['test_input_size']
  35. elif pretrained_cfg.get('input_size', None) is not None:
  36. input_size = pretrained_cfg['input_size']
  37. data_config['input_size'] = input_size
  38. # resolve interpolation method
  39. data_config['interpolation'] = 'bicubic'
  40. if args.get('interpolation', None):
  41. data_config['interpolation'] = args['interpolation']
  42. elif pretrained_cfg.get('interpolation', None):
  43. data_config['interpolation'] = pretrained_cfg['interpolation']
  44. # resolve dataset + model mean for normalization
  45. data_config['mean'] = IMAGENET_DEFAULT_MEAN
  46. if args.get('mean', None) is not None:
  47. mean = tuple(args['mean'])
  48. if len(mean) == 1:
  49. mean = tuple(list(mean) * in_chans)
  50. else:
  51. assert len(mean) == in_chans
  52. data_config['mean'] = mean
  53. elif pretrained_cfg.get('mean', None):
  54. data_config['mean'] = pretrained_cfg['mean']
  55. # resolve dataset + model std deviation for normalization
  56. data_config['std'] = IMAGENET_DEFAULT_STD
  57. if args.get('std', None) is not None:
  58. std = tuple(args['std'])
  59. if len(std) == 1:
  60. std = tuple(list(std) * in_chans)
  61. else:
  62. assert len(std) == in_chans
  63. data_config['std'] = std
  64. elif pretrained_cfg.get('std', None):
  65. data_config['std'] = pretrained_cfg['std']
  66. # resolve default inference crop
  67. crop_pct = DEFAULT_CROP_PCT
  68. if args.get('crop_pct', None):
  69. crop_pct = args['crop_pct']
  70. else:
  71. if use_test_size and pretrained_cfg.get('test_crop_pct', None):
  72. crop_pct = pretrained_cfg['test_crop_pct']
  73. elif pretrained_cfg.get('crop_pct', None):
  74. crop_pct = pretrained_cfg['crop_pct']
  75. data_config['crop_pct'] = crop_pct
  76. # resolve default crop percentage
  77. crop_mode = DEFAULT_CROP_MODE
  78. if args.get('crop_mode', None):
  79. crop_mode = args['crop_mode']
  80. elif pretrained_cfg.get('crop_mode', None):
  81. crop_mode = pretrained_cfg['crop_mode']
  82. data_config['crop_mode'] = crop_mode
  83. if verbose:
  84. _logger.info('Data processing configuration for current model + dataset:')
  85. for n, v in data_config.items():
  86. _logger.info('\t%s: %s' % (n, str(v)))
  87. return data_config
  88. def resolve_model_data_config(
  89. model,
  90. args=None,
  91. pretrained_cfg=None,
  92. use_test_size=False,
  93. verbose=False,
  94. ):
  95. """ Resolve Model Data Config
  96. This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
  97. Args:
  98. model (nn.Module): the model instance
  99. args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
  100. pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
  101. use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
  102. verbose (bool): enable extra logging of resolved values
  103. Returns:
  104. dictionary of config
  105. """
  106. return resolve_data_config(
  107. args=args,
  108. pretrained_cfg=pretrained_cfg,
  109. model=model,
  110. use_test_size=use_test_size,
  111. verbose=verbose,
  112. )