| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585 |
- import hashlib
- import json
- import logging
- import os
- from functools import partial
- from pathlib import Path
- from tempfile import TemporaryDirectory
- from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
- import torch
- from torch.hub import HASH_REGEX, download_url_to_file, urlparse
- try:
- from torch.hub import get_dir
- except ImportError:
- from torch.hub import _get_torch_home as get_dir
- try:
- import safetensors.torch
- _has_safetensors = True
- except ImportError:
- _has_safetensors = False
- try:
- from typing import Literal
- except ImportError:
- from typing_extensions import Literal
- from timm import __version__
- from ._helpers import _torch_load, load_state_dict
- from ._pretrained import filter_pretrained_cfg
- try:
- from huggingface_hub import HfApi, hf_hub_download, model_info
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
- hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
- _has_hf_hub = True
- except ImportError:
- hf_hub_download = None
- _has_hf_hub = False
- _logger = logging.getLogger(__name__)
- __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
- 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
- # Default name for a weights file hosted on the Huggingface Hub.
- HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
- HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
- HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
- HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
- def get_cache_dir(child_dir: str = ''):
- """
- Returns the location of the directory where models are cached (and creates it if necessary).
- """
- # Issue warning to move data if old env is set
- if os.getenv('TORCH_MODEL_ZOO'):
- _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
- hub_dir = get_dir()
- child_dir = () if not child_dir else (child_dir,)
- model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
- os.makedirs(model_dir, exist_ok=True)
- return model_dir
- def download_cached_file(
- url: Union[str, List[str], Tuple[str, str]],
- check_hash: bool = True,
- progress: bool = False,
- cache_dir: Optional[Union[str, Path]] = None,
- ):
- if isinstance(url, (list, tuple)):
- url, filename = url
- else:
- parts = urlparse(url)
- filename = os.path.basename(parts.path)
- if cache_dir:
- os.makedirs(cache_dir, exist_ok=True)
- else:
- cache_dir = get_cache_dir()
- cached_file = os.path.join(cache_dir, filename)
- if not os.path.exists(cached_file):
- _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
- hash_prefix = None
- if check_hash:
- r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
- hash_prefix = r.group(1) if r else None
- download_url_to_file(url, cached_file, hash_prefix, progress=progress)
- return cached_file
- def check_cached_file(
- url: Union[str, List[str], Tuple[str, str]],
- check_hash: bool = True,
- cache_dir: Optional[Union[str, Path]] = None,
- ):
- if isinstance(url, (list, tuple)):
- url, filename = url
- else:
- parts = urlparse(url)
- filename = os.path.basename(parts.path)
- if not cache_dir:
- cache_dir = get_cache_dir()
- cached_file = os.path.join(cache_dir, filename)
- if os.path.exists(cached_file):
- if check_hash:
- r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
- hash_prefix = r.group(1) if r else None
- if hash_prefix:
- with open(cached_file, 'rb') as f:
- hd = hashlib.sha256(f.read()).hexdigest()
- if hd[:len(hash_prefix)] != hash_prefix:
- return False
- return True
- return False
- def has_hf_hub(necessary: bool = False):
- if not _has_hf_hub and necessary:
- # if no HF Hub module installed, and it is necessary to continue, raise error
- raise RuntimeError(
- 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
- return _has_hf_hub
- def hf_split(hf_id: str):
- # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
- rev_split = hf_id.split('@')
- assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
- hf_model_id = rev_split[0]
- hf_revision = rev_split[-1] if len(rev_split) > 1 else None
- return hf_model_id, hf_revision
- def load_cfg_from_json(json_file: Union[str, Path]):
- with open(json_file, "r", encoding="utf-8") as reader:
- text = reader.read()
- return json.loads(text)
- def download_from_hf(
- model_id: str,
- filename: str,
- cache_dir: Optional[Union[str, Path]] = None,
- ):
- hf_model_id, hf_revision = hf_split(model_id)
- return hf_hub_download(
- hf_model_id,
- filename,
- revision=hf_revision,
- cache_dir=cache_dir,
- )
- def _parse_model_cfg(
- cfg: Dict[str, Any],
- extra_fields: Dict[str, Any],
- ) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
- """"""
- # legacy "single‑dict" → split
- if "pretrained_cfg" not in cfg:
- pretrained_cfg = cfg
- cfg = {
- "architecture": pretrained_cfg.pop("architecture"),
- "num_features": pretrained_cfg.pop("num_features", None),
- "pretrained_cfg": pretrained_cfg,
- }
- if "labels" in pretrained_cfg: # rename ‑‑> label_names
- pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
- pretrained_cfg = cfg["pretrained_cfg"]
- pretrained_cfg.update(extra_fields)
- # top‑level overrides
- if "num_classes" in cfg:
- pretrained_cfg["num_classes"] = cfg["num_classes"]
- if "label_names" in cfg:
- pretrained_cfg["label_names"] = cfg.pop("label_names")
- if "label_descriptions" in cfg:
- pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions")
- model_args = cfg.get("model_args", {})
- model_name = cfg["architecture"]
- return pretrained_cfg, model_name, model_args
- def load_model_config_from_hf(
- model_id: str,
- cache_dir: Optional[Union[str, Path]] = None,
- ):
- """Original HF‑Hub loader (unchanged download, shared parsing)."""
- assert has_hf_hub(True)
- cfg_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
- cfg = load_cfg_from_json(cfg_path)
- return _parse_model_cfg(cfg, {"hf_hub_id": model_id, "source": "hf-hub"})
- def load_model_config_from_path(
- model_path: Union[str, Path],
- ):
- """Load from ``<model_path>/config.json`` on the local filesystem."""
- model_path = Path(model_path)
- cfg_file = model_path / "config.json"
- if not cfg_file.is_file():
- raise FileNotFoundError(f"Config file not found: {cfg_file}")
- cfg = load_cfg_from_json(cfg_file)
- extra_fields = {"file": str(model_path), "source": "local-dir"}
- return _parse_model_cfg(cfg, extra_fields=extra_fields)
- def load_state_dict_from_hf(
- model_id: str,
- filename: str = HF_WEIGHTS_NAME,
- weights_only: bool = True,
- cache_dir: Optional[Union[str, Path]] = None,
- ):
- assert has_hf_hub(True)
- hf_model_id, hf_revision = hf_split(model_id)
- # Look for .safetensors alternatives and load from it if it exists
- if _has_safetensors:
- for safe_filename in _get_safe_alternatives(filename):
- try:
- cached_safe_file = hf_hub_download(
- repo_id=hf_model_id,
- filename=safe_filename,
- revision=hf_revision,
- cache_dir=cache_dir,
- )
- _logger.info(
- f"[{model_id}] Safe alternative available for '{filename}' "
- f"(as '{safe_filename}'). Loading weights using safetensors.")
- return safetensors.torch.load_file(cached_safe_file, device="cpu")
- except EntryNotFoundError:
- pass
- # Otherwise, load using pytorch.load
- cached_file = hf_hub_download(
- hf_model_id,
- filename=filename,
- revision=hf_revision,
- cache_dir=cache_dir,
- )
- _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
- state_dict = _torch_load(cached_file, map_location='cpu', weights_only=weights_only)
- return state_dict
- _PREFERRED_FILES = (
- "model.safetensors",
- "pytorch_model.bin",
- "pytorch_model.pth",
- "model.pth",
- "open_clip_model.safetensors",
- "open_clip_pytorch_model.safetensors",
- "open_clip_pytorch_model.bin",
- "open_clip_pytorch_model.pth",
- )
- _EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
- def load_state_dict_from_path(
- path: str,
- weights_only: bool = True,
- ):
- found_file = None
- for fname in _PREFERRED_FILES:
- p = path / fname
- if p.exists():
- logging.info(f"Found preferred checkpoint: {p.name}")
- found_file = p
- break
- # fallback: first match per‑extension class
- for ext in _EXT_PRIORITY:
- files = sorted(path.glob(f"*{ext}"))
- if files:
- if len(files) > 1:
- logging.warning(
- f"Multiple {ext} checkpoints in {path}: {names}. "
- f"Using '{files[0].name}'."
- )
- found_file = files[0]
- if not found_file:
- raise RuntimeError(f"No suitable checkpoints found in {path}.")
- state_dict = load_state_dict(found_file, weights_only=weights_only)
- return state_dict
- def load_custom_from_hf(
- model_id: str,
- filename: str,
- model: torch.nn.Module,
- cache_dir: Optional[Union[str, Path]] = None,
- ):
- assert has_hf_hub(True)
- hf_model_id, hf_revision = hf_split(model_id)
- cached_file = hf_hub_download(
- hf_model_id,
- filename=filename,
- revision=hf_revision,
- cache_dir=cache_dir,
- )
- return model.load_pretrained(cached_file)
- def save_config_for_hf(
- model: torch.nn.Module,
- config_path: str,
- model_config: Optional[dict] = None,
- model_args: Optional[dict] = None
- ):
- model_config = model_config or {}
- hf_config = {}
- pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
- # set some values at root config level
- hf_config['architecture'] = pretrained_cfg.pop('architecture')
- hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
- # NOTE these attr saved for informational purposes, do not impact model build
- hf_config['num_features'] = model_config.pop('num_features', model.num_features)
- global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
- if isinstance(global_pool_type, str) and global_pool_type:
- hf_config['global_pool'] = global_pool_type
- # Save class label info
- if 'labels' in model_config:
- _logger.warning(
- "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
- " Renaming provided 'labels' field to 'label_names'.")
- model_config.setdefault('label_names', model_config.pop('labels'))
- label_names = model_config.pop('label_names', None)
- if label_names:
- assert isinstance(label_names, (dict, list, tuple))
- # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
- # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
- hf_config['label_names'] = label_names
- label_descriptions = model_config.pop('label_descriptions', None)
- if label_descriptions:
- assert isinstance(label_descriptions, dict)
- # maps label names -> descriptions
- hf_config['label_descriptions'] = label_descriptions
- if model_args:
- hf_config['model_args'] = model_args
- hf_config['pretrained_cfg'] = pretrained_cfg
- hf_config.update(model_config)
- with config_path.open('w') as f:
- json.dump(hf_config, f, indent=2)
- def save_for_hf(
- model: torch.nn.Module,
- save_directory: str,
- model_config: Optional[dict] = None,
- model_args: Optional[dict] = None,
- safe_serialization: Union[bool, Literal["both"]] = False,
- ):
- assert has_hf_hub(True)
- save_directory = Path(save_directory)
- save_directory.mkdir(exist_ok=True, parents=True)
- # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
- tensors = model.state_dict()
- if safe_serialization is True or safe_serialization == "both":
- assert _has_safetensors, "`pip install safetensors` to use .safetensors"
- safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
- if safe_serialization is False or safe_serialization == "both":
- torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
- config_path = save_directory / 'config.json'
- save_config_for_hf(
- model,
- config_path,
- model_config=model_config,
- model_args=model_args,
- )
- def push_to_hf_hub(
- model: torch.nn.Module,
- repo_id: str,
- commit_message: str = 'Add model',
- token: Optional[str] = None,
- revision: Optional[str] = None,
- private: bool = False,
- create_pr: bool = False,
- model_config: Optional[dict] = None,
- model_card: Optional[dict] = None,
- model_args: Optional[dict] = None,
- task_name: str = 'image-classification',
- safe_serialization: Union[bool, Literal["both"]] = 'both',
- ):
- """
- Arguments:
- (...)
- safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
- Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
- Can be set to `"both"` in order to push both safe and unsafe weights.
- """
- api = HfApi(token=token, library_name="timm", library_version=__version__)
- # Create repo if it doesn't exist yet
- repo_url = api.create_repo(repo_id, private=private, exist_ok=True)
- # Can be different from the input `repo_id` if repo_owner was implicit
- repo_id = repo_url.repo_id
- # Check if README file already exist in repo
- has_readme = api.file_exists(repo_id=repo_id, filename="README.md", revision=revision)
- # Dump model and push to Hub
- with TemporaryDirectory() as tmpdir:
- # Save model weights and config.
- save_for_hf(
- model,
- tmpdir,
- model_config=model_config,
- model_args=model_args,
- safe_serialization=safe_serialization,
- )
- # Add readme if it does not exist
- if not has_readme:
- model_card = model_card or {}
- model_name = repo_id.split('/')[-1]
- readme_path = Path(tmpdir) / "README.md"
- readme_text = generate_readme(model_card, model_name, task_name=task_name)
- readme_path.write_text(readme_text)
- # Upload model and return
- return api.upload_folder(
- repo_id=repo_id,
- folder_path=tmpdir,
- revision=revision,
- create_pr=create_pr,
- commit_message=commit_message,
- )
- def generate_readme(
- model_card: dict,
- model_name: str,
- task_name: str = 'image-classification',
- ):
- tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
- readme_text = "---\n"
- if tags:
- readme_text += "tags:\n"
- for t in tags:
- readme_text += f"- {t}\n"
- readme_text += f"pipeline_tag: {task_name}\n"
- readme_text += f"library_name: {model_card.get('library_name', 'timm')}\n"
- readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
- if 'license_name' in model_card:
- readme_text += f"license_name: {model_card.get('license_name')}\n"
- if 'license_link' in model_card:
- readme_text += f"license_link: {model_card.get('license_link')}\n"
- if 'details' in model_card and 'Dataset' in model_card['details']:
- readme_text += 'datasets:\n'
- if isinstance(model_card['details']['Dataset'], (tuple, list)):
- for d in model_card['details']['Dataset']:
- readme_text += f"- {d.lower()}\n"
- else:
- readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
- if 'Pretrain Dataset' in model_card['details']:
- if isinstance(model_card['details']['Pretrain Dataset'], (tuple, list)):
- for d in model_card['details']['Pretrain Dataset']:
- readme_text += f"- {d.lower()}\n"
- else:
- readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
- readme_text += "---\n"
- readme_text += f"# Model card for {model_name}\n"
- if 'description' in model_card:
- readme_text += f"\n{model_card['description']}\n"
- if 'details' in model_card:
- readme_text += f"\n## Model Details\n"
- for k, v in model_card['details'].items():
- if isinstance(v, (list, tuple)):
- readme_text += f"- **{k}:**\n"
- for vi in v:
- readme_text += f" - {vi}\n"
- elif isinstance(v, dict):
- readme_text += f"- **{k}:**\n"
- for ki, vi in v.items():
- readme_text += f" - {ki}: {vi}\n"
- else:
- readme_text += f"- **{k}:** {v}\n"
- if 'usage' in model_card:
- readme_text += f"\n## Model Usage\n"
- readme_text += model_card['usage']
- readme_text += '\n'
- if 'comparison' in model_card:
- readme_text += f"\n## Model Comparison\n"
- readme_text += model_card['comparison']
- readme_text += '\n'
- if 'citation' in model_card:
- readme_text += f"\n## Citation\n"
- if not isinstance(model_card['citation'], (list, tuple)):
- citations = [model_card['citation']]
- else:
- citations = model_card['citation']
- for c in citations:
- readme_text += f"```bibtex\n{c}\n```\n"
- return readme_text
- def _get_safe_alternatives(filename: str) -> Iterable[str]:
- """Returns potential safetensors alternatives for a given filename.
- Use case:
- When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
- Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
- """
- if filename == HF_WEIGHTS_NAME:
- yield HF_SAFE_WEIGHTS_NAME
- if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
- yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
- if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
- yield filename[:-4] + ".safetensors"
- def _get_license_from_hf_hub(model_id: Optional[str], hf_hub_id: Optional[str]) -> Optional[str]:
- """Retrieve license information for a model from Hugging Face Hub.
- Fetches the license field from the model card metadata on Hugging Face Hub
- for the specified model. Returns None if the model is not found, if
- huggingface_hub is not installed, or if the model is marked as "untrained".
- Args:
- model_id: The model identifier/name. In the case of None we assume an untrained model.
- hf_hub_id: The Hugging Face Hub organization/user ID. If it is None,
- we will return None as we cannot infer the license terms.
- Returns:
- The license string in lowercase if found, None otherwise.
- Note:
- Requires huggingface_hub package to be installed. Will log a warning
- and return None if the package is not available.
- """
- if not has_hf_hub(True):
- msg = "For updated license information run `pip install huggingface_hub`."
- _logger.warning(msg=msg)
- return None
- if not (model_id and hf_hub_id):
- return None
- repo_id: str = hf_hub_id + model_id
- try:
- info = model_info(repo_id=repo_id)
- except RepositoryNotFoundError:
- msg = f"Repository {repo_id} was not found. Manual inspection of license needed."
- _logger.warning(msg=msg)
- return None
- except Exception as _:
- msg = f"Error for {repo_id}. Manual inspection of license needed."
- _logger.warning(msg=msg)
- return None
- license = info.card_data.get("license").lower() if info.card_data else None
- if license == 'other':
- name = info.card_data.get("license_name", None)
- if name is not None:
- return name
- return license
|