| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- import copy
- from collections import deque, defaultdict
- from dataclasses import dataclass, field, replace, asdict
- from typing import Any, Deque, Dict, Tuple, Optional, Union
- __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
@dataclass
class PretrainedCfg:
    """Configuration record for one set of pretrained weights.

    Groups, in order: where the weights can be obtained from, the input /
    preprocessing settings, the classifier head settings and label meta-data,
    model attributes needed to adapt the weights, and free-form descriptive
    meta-data (license, paper, notes).

    Every field has a default, so an instance can be created with no
    arguments and selectively overridden.
    """
    # weight source locations
    url: Optional[Union[str, Tuple[str, str]]] = None  # remote URL
    file: Optional[str] = None  # local / shared filesystem path
    state_dict: Optional[Dict[str, Any]] = None  # in-memory state dict
    hf_hub_id: Optional[str] = None  # Hugging Face Hub model id ('organization/model')
    hf_hub_filename: Optional[str] = None  # Hugging Face Hub filename (overrides default)

    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
    architecture: Optional[str] = None  # architecture variant can be set when not implicit
    tag: Optional[str] = None  # pretrained tag of source
    custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

    # input / data config
    input_size: Tuple[int, int, int] = (3, 224, 224)
    test_input_size: Optional[Tuple[int, int, int]] = None
    min_input_size: Optional[Tuple[int, int, int]] = None
    fixed_input_size: bool = False
    interpolation: str = 'bicubic'
    crop_pct: float = 0.875
    test_crop_pct: Optional[float] = None
    crop_mode: str = 'center'
    mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
    std: Tuple[float, ...] = (0.229, 0.224, 0.225)

    # head / classifier config and meta-data
    num_classes: int = 1000
    label_offset: Optional[int] = None
    # NOTE: variadic Tuple[str, ...] -- a bare Tuple[str] would mean "exactly one string"
    label_names: Optional[Tuple[str, ...]] = None
    label_descriptions: Optional[Dict[str, str]] = None

    # model attributes that vary with above or required for pretrained adaptation
    pool_size: Optional[Tuple[int, ...]] = None
    test_pool_size: Optional[Tuple[int, ...]] = None
    first_conv: Optional[str] = None
    classifier: Optional[str] = None

    # descriptive meta-data
    license: Optional[str] = None
    description: Optional[str] = None
    origin_url: Optional[str] = None
    paper_name: Optional[str] = None
    paper_ids: Optional[Union[str, Tuple[str, ...]]] = None
    notes: Optional[Tuple[str, ...]] = None

    @property
    def has_weights(self) -> bool:
        """Whether a weight location (``url``, ``file`` or ``hf_hub_id``) is set.

        Returns a real ``bool`` rather than the first truthy location value,
        so the result can be stored in boolean flags without leaking a string.
        An in-memory ``state_dict`` is not taken into account here.
        """
        return bool(self.url or self.file or self.hf_hub_id)

    def to_dict(self, remove_source: bool = False, remove_null: bool = True) -> Dict[str, Any]:
        """Convert this config to a plain dict, filtered by ``filter_pretrained_cfg``.

        Args:
            remove_source: Forwarded to ``filter_pretrained_cfg``; drops the
                weight-source keys when True.
            remove_null: Forwarded to ``filter_pretrained_cfg``; drops
                ``None``-valued keys (except those it always keeps) when True.

        Returns:
            The filtered dict built from ``dataclasses.asdict(self)``.
        """
        return filter_pretrained_cfg(
            asdict(self),
            remove_source=remove_source,
            remove_null=remove_null,
        )
def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
    """Return a filtered copy of a pretrained cfg dict.

    Args:
        cfg: Mapping of config keys to values.
        remove_source: When True, drop the weight-source keys
            (``url``, ``file``, ``hf_hub_id``, ``hf_hub_filename``, ``source``).
        remove_null: When True, drop keys whose value is ``None``, except
            ``pool_size``, ``first_conv`` and ``classifier`` which are always kept.

    Returns:
        A new dict with the surviving items in their original order.
    """
    source_keys = {'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'}
    always_kept = {'pool_size', 'first_conv', 'classifier'}  # retained even when None

    def _survives(key, value):
        if remove_source and key in source_keys:
            return False
        is_droppable_null = value is None and key not in always_kept
        return not (remove_null and is_droppable_null)

    return {key: value for key, value in cfg.items() if _survives(key, value)}
@dataclass
class DefaultCfg:
    """Pretrained configs keyed by tag, with an ordered queue of those tags.

    The first entry of ``tags`` is treated as the default tag. Both accessors
    below index ``tags[0]``, so they raise ``IndexError`` while ``tags`` is
    empty and ``KeyError`` if that tag has no entry in ``cfgs``.
    """
    # ordered tags; the head of the deque is the default
    tags: Deque[str] = field(default_factory=deque)
    # tag -> pretrained cfg
    cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict)
    # at least one of the configs has a pretrained source set
    is_pretrained: bool = False

    @property
    def default(self):
        """The cfg stored under the first tag."""
        _, head_cfg = self.default_with_tag
        return head_cfg

    @property
    def default_with_tag(self):
        """A ``(tag, cfg)`` pair for the first tag."""
        head_tag = self.tags[0]
        return head_tag, self.cfgs[head_tag]
|