_pretrained.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import copy
  2. from collections import deque, defaultdict
  3. from dataclasses import dataclass, field, replace, asdict
  4. from typing import Any, Deque, Dict, Tuple, Optional, Union
  5. __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
  6. @dataclass
  7. class PretrainedCfg:
  8. """
  9. """
  10. # weight source locations
  11. url: Optional[Union[str, Tuple[str, str]]] = None # remote URL
  12. file: Optional[str] = None # local / shared filesystem path
  13. state_dict: Optional[Dict[str, Any]] = None # in-memory state dict
  14. hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model')
  15. hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default)
  16. source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
  17. architecture: Optional[str] = None # architecture variant can be set when not implicit
  18. tag: Optional[str] = None # pretrained tag of source
  19. custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files)
  20. # input / data config
  21. input_size: Tuple[int, int, int] = (3, 224, 224)
  22. test_input_size: Optional[Tuple[int, int, int]] = None
  23. min_input_size: Optional[Tuple[int, int, int]] = None
  24. fixed_input_size: bool = False
  25. interpolation: str = 'bicubic'
  26. crop_pct: float = 0.875
  27. test_crop_pct: Optional[float] = None
  28. crop_mode: str = 'center'
  29. mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
  30. std: Tuple[float, ...] = (0.229, 0.224, 0.225)
  31. # head / classifier config and meta-data
  32. num_classes: int = 1000
  33. label_offset: Optional[int] = None
  34. label_names: Optional[Tuple[str]] = None
  35. label_descriptions: Optional[Dict[str, str]] = None
  36. # model attributes that vary with above or required for pretrained adaptation
  37. pool_size: Optional[Tuple[int, ...]] = None
  38. test_pool_size: Optional[Tuple[int, ...]] = None
  39. first_conv: Optional[str] = None
  40. classifier: Optional[str] = None
  41. license: Optional[str] = None
  42. description: Optional[str] = None
  43. origin_url: Optional[str] = None
  44. paper_name: Optional[str] = None
  45. paper_ids: Optional[Union[str, Tuple[str]]] = None
  46. notes: Optional[Tuple[str]] = None
  47. @property
  48. def has_weights(self):
  49. return self.url or self.file or self.hf_hub_id
  50. def to_dict(self, remove_source=False, remove_null=True):
  51. return filter_pretrained_cfg(
  52. asdict(self),
  53. remove_source=remove_source,
  54. remove_null=remove_null
  55. )
  56. def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
  57. filtered_cfg = {}
  58. keep_null = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none
  59. for k, v in cfg.items():
  60. if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
  61. continue
  62. if remove_null and v is None and k not in keep_null:
  63. continue
  64. filtered_cfg[k] = v
  65. return filtered_cfg
  66. @dataclass
  67. class DefaultCfg:
  68. tags: Deque[str] = field(default_factory=deque) # priority queue of tags (first is default)
  69. cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict) # pretrained cfgs by tag
  70. is_pretrained: bool = False # at least one of the configs has a pretrained source set
  71. @property
  72. def default(self):
  73. return self.cfgs[self.tags[0]]
  74. @property
  75. def default_with_tag(self):
  76. tag = self.tags[0]
  77. return tag, self.cfgs[tag]