_factory.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import os
  2. from pathlib import Path
  3. from typing import Any, Dict, Optional, Tuple, Union
  4. from urllib.parse import urlsplit
  5. from torch import nn
  6. from timm.layers import set_layer_config
  7. from ._helpers import load_checkpoint
  8. from ._hub import load_model_config_from_hf, load_model_config_from_path
  9. from ._pretrained import PretrainedCfg
  10. from ._registry import is_model, model_entrypoint, split_model_name_tag
  11. __all__ = ['parse_model_name', 'safe_model_name', 'create_model']
  12. def parse_model_name(model_name: str) -> Tuple[Optional[str], str]:
  13. """Parse source and name from potentially prefixed model name."""
  14. if model_name.startswith('hf_hub'):
  15. # NOTE for backwards compat, deprecate hf_hub use
  16. model_name = model_name.replace('hf_hub', 'hf-hub')
  17. parsed = urlsplit(model_name)
  18. assert parsed.scheme in ('', 'hf-hub', 'local-dir')
  19. if parsed.scheme == 'hf-hub':
  20. # FIXME may use fragment as revision, currently `@` in URI path
  21. return parsed.scheme, parsed.path
  22. elif parsed.scheme == 'local-dir':
  23. return parsed.scheme, parsed.path
  24. else:
  25. model_name = os.path.split(parsed.path)[-1]
  26. return None, model_name
  27. def safe_model_name(model_name: str, remove_source: bool = True) -> str:
  28. """Return a filename / path safe model name."""
  29. def make_safe(name: str) -> str:
  30. return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
  31. if remove_source:
  32. model_name = parse_model_name(model_name)[-1]
  33. return make_safe(model_name)
  34. def create_model(
  35. model_name: str,
  36. pretrained: bool = False,
  37. pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
  38. pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
  39. checkpoint_path: Optional[Union[str, Path]] = None,
  40. cache_dir: Optional[Union[str, Path]] = None,
  41. scriptable: Optional[bool] = None,
  42. exportable: Optional[bool] = None,
  43. no_jit: Optional[bool] = None,
  44. **kwargs: Any,
  45. ) -> nn.Module:
  46. """Create a model.
  47. Lookup model's entrypoint function and pass relevant args to create a new model.
  48. Tip:
  49. **kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
  50. and then the model class __init__(). kwargs values set to None are pruned before passing.
  51. Args:
  52. model_name: Name of model to instantiate.
  53. pretrained: If set to `True`, load pretrained ImageNet-1k weights.
  54. pretrained_cfg: Pass in an external pretrained_cfg for model.
  55. pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
  56. checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
  57. cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints.
  58. scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
  59. exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
  60. no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
  61. Keyword Args:
  62. drop_rate (float): Classifier dropout rate for training.
  63. drop_path_rate (float): Stochastic depth drop rate for training.
  64. global_pool (str): Classifier global pooling type.
  65. Example:
  66. ```py
  67. >>> from timm import create_model
  68. >>> # Create a MobileNetV3-Large model with no pretrained weights.
  69. >>> model = create_model('mobilenetv3_large_100')
  70. >>> # Create a MobileNetV3-Large model with pretrained weights.
  71. >>> model = create_model('mobilenetv3_large_100', pretrained=True)
  72. >>> model.num_classes
  73. 1000
  74. >>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
  75. >>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
  76. >>> model.num_classes
  77. 10
  78. >>> # Create a Dinov2 small model with pretrained weights and save weights in a custom directory.
  79. >>> model = create_model('vit_small_patch14_dinov2.lvd142m', pretrained=True, cache_dir="/data/my-models")
  80. >>> # Data will be stored at `/data/my-models/models--timm--vit_small_patch14_dinov2.lvd142m/`
  81. ```
  82. """
  83. # Parameters that aren't supported by all models or are intended to only override model defaults if set
  84. # should default to None in command line args/cfg. Remove them if they are present and not set so that
  85. # non-supporting models don't break and default args remain in effect.
  86. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  87. model_source, model_id = parse_model_name(model_name)
  88. if model_source:
  89. assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
  90. if model_source == 'hf-hub':
  91. # For model names specified in the form `hf-hub:path/architecture_name@revision`,
  92. # load model weights + pretrained_cfg from Hugging Face hub.
  93. pretrained_cfg, model_name, model_args = load_model_config_from_hf(
  94. model_id,
  95. cache_dir=cache_dir,
  96. )
  97. elif model_source == 'local-dir':
  98. pretrained_cfg, model_name, model_args = load_model_config_from_path(
  99. model_id,
  100. )
  101. else:
  102. assert False, f'Unknown model_source {model_source}'
  103. if model_args:
  104. for k, v in model_args.items():
  105. kwargs.setdefault(k, v)
  106. else:
  107. model_name, pretrained_tag = split_model_name_tag(model_id)
  108. if pretrained_tag and not pretrained_cfg:
  109. # a valid pretrained_cfg argument takes priority over tag in model name
  110. pretrained_cfg = pretrained_tag
  111. if not is_model(model_name):
  112. raise RuntimeError('Unknown model (%s)' % model_name)
  113. create_fn = model_entrypoint(model_name)
  114. with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
  115. model = create_fn(
  116. pretrained=pretrained,
  117. pretrained_cfg=pretrained_cfg,
  118. pretrained_cfg_overlay=pretrained_cfg_overlay,
  119. cache_dir=cache_dir,
  120. **kwargs,
  121. )
  122. if checkpoint_path:
  123. load_checkpoint(model, checkpoint_path)
  124. return model