_builder.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. import dataclasses
  2. import logging
  3. import os
  4. from copy import deepcopy
  5. from pathlib import Path
  6. from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
  7. from torch import nn as nn
  8. from torch.hub import load_state_dict_from_url
  9. from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
  10. from timm.models._features_fx import FeatureGraphNet
  11. from timm.models._helpers import load_state_dict
  12. from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf, \
  13. load_state_dict_from_path, load_custom_from_hf
  14. from timm.models._manipulate import adapt_input_conv
  15. from timm.models._pretrained import PretrainedCfg
  16. from timm.models._prune import adapt_model_from_file
  17. from timm.models._registry import get_pretrained_cfg
  18. _logger = logging.getLogger(__name__)
  19. # Global variables for rarely used pretrained checkpoint download progress and hash check.
  20. # Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
  21. _DOWNLOAD_PROGRESS = False
  22. _CHECK_HASH = False
  23. _USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0
  24. __all__ = [
  25. 'set_pretrained_download_progress',
  26. 'set_pretrained_check_hash',
  27. 'load_custom_pretrained',
  28. 'load_pretrained',
  29. 'pretrained_cfg_for_features',
  30. 'resolve_pretrained_cfg',
  31. 'build_model_with_cfg',
  32. ]
  33. ModelT = TypeVar("ModelT", bound=nn.Module) # any subclass of nn.Module
  34. def _resolve_pretrained_source(pretrained_cfg: Dict[str, Any]) -> Tuple[str, str]:
  35. cfg_source = pretrained_cfg.get('source', '')
  36. pretrained_url = pretrained_cfg.get('url', None)
  37. pretrained_file = pretrained_cfg.get('file', None)
  38. pretrained_sd = pretrained_cfg.get('state_dict', None)
  39. hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
  40. # resolve where to load pretrained weights from
  41. load_from = ''
  42. pretrained_loc = ''
  43. if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
  44. # hf-hub specified as source via model identifier
  45. load_from = 'hf-hub'
  46. assert hf_hub_id
  47. pretrained_loc = hf_hub_id
  48. elif cfg_source == 'local-dir':
  49. load_from = 'local-dir'
  50. pretrained_loc = pretrained_file
  51. else:
  52. # default source == timm or unspecified
  53. if pretrained_sd:
  54. # direct state_dict pass through is the highest priority
  55. load_from = 'state_dict'
  56. pretrained_loc = pretrained_sd
  57. assert isinstance(pretrained_loc, dict)
  58. elif pretrained_file:
  59. # file load override is the second-highest priority if set
  60. load_from = 'file'
  61. pretrained_loc = pretrained_file
  62. else:
  63. old_cache_valid = False
  64. if _USE_OLD_CACHE:
  65. # prioritized old cached weights if exists and env var enabled
  66. old_cache_valid = check_cached_file(pretrained_url) if pretrained_url else False
  67. if not old_cache_valid and hf_hub_id and has_hf_hub(necessary=True):
  68. # hf-hub available as alternate weight source in default_cfg
  69. load_from = 'hf-hub'
  70. pretrained_loc = hf_hub_id
  71. elif pretrained_url:
  72. load_from = 'url'
  73. pretrained_loc = pretrained_url
  74. if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
  75. # if a filename override is set, return tuple for location w/ (hub_id, filename)
  76. pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
  77. return load_from, pretrained_loc
  78. def set_pretrained_download_progress(enable: bool = True) -> None:
  79. """ Set download progress for pretrained weights on/off (globally). """
  80. global _DOWNLOAD_PROGRESS
  81. _DOWNLOAD_PROGRESS = enable
  82. def set_pretrained_check_hash(enable: bool = True) -> None:
  83. """ Set hash checking for pretrained weights on/off (globally). """
  84. global _CHECK_HASH
  85. _CHECK_HASH = enable
  86. def load_custom_pretrained(
  87. model: nn.Module,
  88. pretrained_cfg: Optional[Dict[str, Any]] = None,
  89. load_fn: Optional[Callable] = None,
  90. cache_dir: Optional[Union[str, Path]] = None,
  91. ) -> None:
  92. """Loads a custom (read non .pth) weight file
  93. Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
  94. a passed in custom load fun, or the `load_pretrained` model member fn.
  95. If the object is already present in `model_dir`, it's deserialized and returned.
  96. The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
  97. `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
  98. Args:
  99. model: The instantiated model to load weights into
  100. pretrained_cfg: Default pretrained model cfg
  101. load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
  102. 'load_pretrained' on the model will be called if it exists
  103. cache_dir: Override model checkpoint cache dir for this load
  104. """
  105. pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
  106. if not pretrained_cfg:
  107. _logger.warning("Invalid pretrained config, cannot load weights.")
  108. return
  109. load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
  110. if not load_from:
  111. _logger.warning("No pretrained weights exist for this model. Using random initialization.")
  112. return
  113. if load_from == 'hf-hub':
  114. _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
  115. elif load_from == 'url':
  116. pretrained_loc = download_cached_file(
  117. pretrained_loc,
  118. check_hash=_CHECK_HASH,
  119. progress=_DOWNLOAD_PROGRESS,
  120. cache_dir=cache_dir,
  121. )
  122. if load_fn is not None:
  123. load_fn(model, pretrained_loc)
  124. elif hasattr(model, 'load_pretrained'):
  125. model.load_pretrained(pretrained_loc)
  126. else:
  127. _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
  128. def load_pretrained(
  129. model: nn.Module,
  130. pretrained_cfg: Optional[Dict[str, Any]] = None,
  131. num_classes: int = 1000,
  132. in_chans: int = 3,
  133. filter_fn: Optional[Callable] = None,
  134. strict: bool = True,
  135. cache_dir: Optional[Union[str, Path]] = None,
  136. ) -> None:
  137. """ Load pretrained checkpoint
  138. Args:
  139. model: PyTorch module
  140. pretrained_cfg: Configuration for pretrained weights / target dataset
  141. num_classes: Number of classes for target model. Will adapt pretrained if different.
  142. in_chans: Number of input chans for target model. Will adapt pretrained if different.
  143. filter_fn: state_dict filter fn for load (takes state_dict, model as args)
  144. strict: Strict load of checkpoint
  145. cache_dir: Override model checkpoint cache dir for this load
  146. """
  147. pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
  148. if not pretrained_cfg:
  149. raise RuntimeError("Invalid pretrained config, cannot load weights. Use `pretrained=False` for random init.")
  150. load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
  151. if load_from == 'state_dict':
  152. _logger.info(f'Loading pretrained weights from state dict')
  153. state_dict = pretrained_loc # pretrained_loc is the actual state dict for this override
  154. elif load_from == 'file':
  155. _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
  156. if pretrained_cfg.get('custom_load', False):
  157. model.load_pretrained(pretrained_loc)
  158. return
  159. else:
  160. state_dict = load_state_dict(pretrained_loc)
  161. elif load_from == 'url':
  162. _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
  163. if pretrained_cfg.get('custom_load', False):
  164. pretrained_loc = download_cached_file(
  165. pretrained_loc,
  166. progress=_DOWNLOAD_PROGRESS,
  167. check_hash=_CHECK_HASH,
  168. cache_dir=cache_dir,
  169. )
  170. model.load_pretrained(pretrained_loc)
  171. return
  172. else:
  173. try:
  174. state_dict = load_state_dict_from_url(
  175. pretrained_loc,
  176. map_location='cpu',
  177. progress=_DOWNLOAD_PROGRESS,
  178. check_hash=_CHECK_HASH,
  179. weights_only=True,
  180. model_dir=cache_dir,
  181. )
  182. except TypeError:
  183. state_dict = load_state_dict_from_url(
  184. pretrained_loc,
  185. map_location='cpu',
  186. progress=_DOWNLOAD_PROGRESS,
  187. check_hash=_CHECK_HASH,
  188. model_dir=cache_dir,
  189. )
  190. elif load_from == 'hf-hub':
  191. _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
  192. if isinstance(pretrained_loc, (list, tuple)):
  193. custom_load = pretrained_cfg.get('custom_load', False)
  194. if isinstance(custom_load, str) and custom_load == 'hf':
  195. load_custom_from_hf(*pretrained_loc, model, cache_dir=cache_dir)
  196. return
  197. else:
  198. state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
  199. else:
  200. state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
  201. elif load_from == 'local-dir':
  202. _logger.info(f'Loading pretrained weights from local directory ({pretrained_loc})')
  203. pretrained_path = Path(pretrained_loc)
  204. if pretrained_path.is_dir():
  205. state_dict = load_state_dict_from_path(pretrained_path)
  206. else:
  207. raise RuntimeError(f"Specified path is not a directory: {pretrained_loc}")
  208. else:
  209. model_name = pretrained_cfg.get('architecture', 'this model')
  210. raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
  211. if filter_fn is not None:
  212. try:
  213. state_dict = filter_fn(state_dict, model)
  214. except TypeError as e:
  215. # for backwards compat with filter fn that take one arg
  216. state_dict = filter_fn(state_dict)
  217. input_convs = pretrained_cfg.get('first_conv', None)
  218. if input_convs is not None and in_chans != 3:
  219. if isinstance(input_convs, str):
  220. input_convs = (input_convs,)
  221. for input_conv_name in input_convs:
  222. weight_name = input_conv_name + '.weight'
  223. try:
  224. state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
  225. _logger.info(
  226. f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
  227. except NotImplementedError as e:
  228. del state_dict[weight_name]
  229. strict = False
  230. _logger.warning(
  231. f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
  232. classifiers = pretrained_cfg.get('classifier', None)
  233. label_offset = pretrained_cfg.get('label_offset', 0)
  234. if classifiers is not None:
  235. if isinstance(classifiers, str):
  236. classifiers = (classifiers,)
  237. if num_classes != pretrained_cfg['num_classes']:
  238. for classifier_name in classifiers:
  239. # completely discard fully connected if model num_classes doesn't match pretrained weights
  240. state_dict.pop(classifier_name + '.weight', None)
  241. state_dict.pop(classifier_name + '.bias', None)
  242. strict = False
  243. elif label_offset > 0:
  244. for classifier_name in classifiers:
  245. # special case for pretrained weights with an extra background class in pretrained weights
  246. classifier_weight = state_dict[classifier_name + '.weight']
  247. state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
  248. classifier_bias = state_dict[classifier_name + '.bias']
  249. state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
  250. load_result = model.load_state_dict(state_dict, strict=strict)
  251. if load_result.missing_keys:
  252. _logger.info(
  253. f'Missing keys ({", ".join(load_result.missing_keys)}) discovered while loading pretrained weights.'
  254. f' This is expected if model is being adapted.')
  255. if load_result.unexpected_keys:
  256. _logger.warning(
  257. f'Unexpected keys ({", ".join(load_result.unexpected_keys)}) found while loading pretrained weights.'
  258. f' This may be expected if model is being adapted.')
  259. def pretrained_cfg_for_features(pretrained_cfg: Dict[str, Any]) -> Dict[str, Any]:
  260. pretrained_cfg = deepcopy(pretrained_cfg)
  261. # remove default pretrained cfg fields that don't have much relevance for feature backbone
  262. to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size?
  263. for tr in to_remove:
  264. pretrained_cfg.pop(tr, None)
  265. return pretrained_cfg
  266. def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None:
  267. if not kwargs or not names:
  268. return
  269. for n in names:
  270. kwargs.pop(n, None)
  271. def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None:
  272. """ Update the default_cfg and kwargs before passing to model
  273. Args:
  274. pretrained_cfg: input pretrained cfg (updated in-place)
  275. kwargs: keyword args passed to model build fn (updated in-place)
  276. kwargs_filter: keyword arg keys that must be removed before model __init__
  277. """
  278. # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
  279. default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
  280. if pretrained_cfg.get('fixed_input_size', False):
  281. # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
  282. default_kwarg_names += ('img_size',)
  283. for n in default_kwarg_names:
  284. # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
  285. # pretrained_cfg has one input_size=(C, H ,W) entry
  286. if n == 'img_size':
  287. input_size = pretrained_cfg.get('input_size', None)
  288. if input_size is not None:
  289. assert len(input_size) == 3
  290. kwargs.setdefault(n, input_size[-2:])
  291. elif n == 'in_chans':
  292. input_size = pretrained_cfg.get('input_size', None)
  293. if input_size is not None:
  294. assert len(input_size) == 3
  295. kwargs.setdefault(n, input_size[0])
  296. elif n == 'num_classes':
  297. default_val = pretrained_cfg.get(n, None)
  298. # if default is < 0, don't pass through to model
  299. if default_val is not None and default_val >= 0:
  300. kwargs.setdefault(n, pretrained_cfg[n])
  301. else:
  302. default_val = pretrained_cfg.get(n, None)
  303. if default_val is not None:
  304. kwargs.setdefault(n, pretrained_cfg[n])
  305. # Filter keyword args for task specific model variants (some 'features only' models, etc.)
  306. _filter_kwargs(kwargs, names=kwargs_filter)
  307. def resolve_pretrained_cfg(
  308. variant: str,
  309. pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
  310. pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
  311. ) -> PretrainedCfg:
  312. """Resolve pretrained configuration from various sources."""
  313. model_with_tag = variant
  314. pretrained_tag = None
  315. if pretrained_cfg:
  316. if isinstance(pretrained_cfg, dict):
  317. # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
  318. pretrained_cfg = PretrainedCfg(**pretrained_cfg)
  319. elif isinstance(pretrained_cfg, str):
  320. pretrained_tag = pretrained_cfg
  321. pretrained_cfg = None
  322. # fallback to looking up pretrained cfg in model registry by variant identifier
  323. if not pretrained_cfg:
  324. if pretrained_tag:
  325. model_with_tag = '.'.join([variant, pretrained_tag])
  326. pretrained_cfg = get_pretrained_cfg(model_with_tag)
  327. if not pretrained_cfg:
  328. _logger.warning(
  329. f"No pretrained configuration specified for {model_with_tag} model. Using a default."
  330. f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
  331. pretrained_cfg = PretrainedCfg() # instance with defaults
  332. pretrained_cfg_overlay = pretrained_cfg_overlay or {}
  333. if not pretrained_cfg.architecture:
  334. pretrained_cfg_overlay.setdefault('architecture', variant)
  335. pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)
  336. return pretrained_cfg
  337. def build_model_with_cfg(
  338. model_cls: Union[Type[ModelT], Callable[..., ModelT]],
  339. variant: str,
  340. pretrained: bool,
  341. pretrained_cfg: Optional[Dict] = None,
  342. pretrained_cfg_overlay: Optional[Dict] = None,
  343. model_cfg: Optional[Any] = None,
  344. feature_cfg: Optional[Dict] = None,
  345. pretrained_strict: bool = True,
  346. pretrained_filter_fn: Optional[Callable] = None,
  347. cache_dir: Optional[Union[str, Path]] = None,
  348. kwargs_filter: Optional[Tuple[str]] = None,
  349. **kwargs,
  350. ) -> ModelT:
  351. """ Build model with specified default_cfg and optional model_cfg
  352. This helper fn aids in the construction of a model including:
  353. * handling default_cfg and associated pretrained weight loading
  354. * passing through optional model_cfg for models with config based arch spec
  355. * features_only model adaptation
  356. * pruning config / model adaptation
  357. Args:
  358. model_cls: Model class
  359. variant: Model variant name
  360. pretrained: Load the pretrained weights
  361. pretrained_cfg: Model's pretrained weight/task config
  362. pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
  363. model_cfg: Model's architecture config
  364. feature_cfg: Feature extraction adapter config
  365. pretrained_strict: Load pretrained weights strictly
  366. pretrained_filter_fn: Filter callable for pretrained weights
  367. cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
  368. kwargs_filter: Kwargs keys to filter (remove) before passing to model
  369. **kwargs: Model args passed through to model __init__
  370. """
  371. pruned = kwargs.pop('pruned', False)
  372. features = False
  373. feature_cfg = feature_cfg or {}
  374. # resolve and update model pretrained config and model kwargs
  375. pretrained_cfg = resolve_pretrained_cfg(
  376. variant,
  377. pretrained_cfg=pretrained_cfg,
  378. pretrained_cfg_overlay=pretrained_cfg_overlay
  379. )
  380. pretrained_cfg = pretrained_cfg.to_dict()
  381. _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
  382. # Setup for feature extraction wrapper done at end of this fn
  383. if kwargs.pop('features_only', False):
  384. features = True
  385. feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
  386. if 'out_indices' in kwargs:
  387. feature_cfg['out_indices'] = kwargs.pop('out_indices')
  388. if 'feature_cls' in kwargs:
  389. feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
  390. # Instantiate the model
  391. if model_cfg is None:
  392. model = model_cls(**kwargs)
  393. else:
  394. model = model_cls(cfg=model_cfg, **kwargs)
  395. model.pretrained_cfg = pretrained_cfg
  396. model.default_cfg = model.pretrained_cfg # alias for backwards compat
  397. if pruned:
  398. model = adapt_model_from_file(model, variant)
  399. # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
  400. num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
  401. if pretrained:
  402. load_pretrained(
  403. model,
  404. pretrained_cfg=pretrained_cfg,
  405. num_classes=num_classes_pretrained,
  406. in_chans=kwargs.get('in_chans', 3),
  407. filter_fn=pretrained_filter_fn,
  408. strict=pretrained_strict,
  409. cache_dir=cache_dir,
  410. )
  411. # Wrap the model in a feature extraction module if enabled
  412. if features:
  413. use_getter = False
  414. if 'feature_cls' in feature_cfg:
  415. feature_cls = feature_cfg.pop('feature_cls')
  416. if isinstance(feature_cls, str):
  417. feature_cls = feature_cls.lower()
  418. # flatten_sequential only valid for some feature extractors
  419. if feature_cls not in ('dict', 'list', 'hook'):
  420. feature_cfg.pop('flatten_sequential', None)
  421. if 'hook' in feature_cls:
  422. feature_cls = FeatureHookNet
  423. elif feature_cls == 'list':
  424. feature_cls = FeatureListNet
  425. elif feature_cls == 'dict':
  426. feature_cls = FeatureDictNet
  427. elif feature_cls == 'fx':
  428. feature_cls = FeatureGraphNet
  429. elif feature_cls == 'getter':
  430. use_getter = True
  431. feature_cls = FeatureGetterNet
  432. else:
  433. assert False, f'Unknown feature class {feature_cls}'
  434. else:
  435. feature_cls = FeatureListNet
  436. output_fmt = getattr(model, 'output_fmt', None)
  437. if output_fmt is not None and not use_getter: # don't set default for intermediate feat getter
  438. feature_cfg.setdefault('output_fmt', output_fmt)
  439. model = feature_cls(model, **feature_cfg)
  440. model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg
  441. model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg)
  442. return model