_hub.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. import hashlib
  2. import json
  3. import logging
  4. import os
  5. from functools import partial
  6. from pathlib import Path
  7. from tempfile import TemporaryDirectory
  8. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
  9. import torch
  10. from torch.hub import HASH_REGEX, download_url_to_file, urlparse
  11. try:
  12. from torch.hub import get_dir
  13. except ImportError:
  14. from torch.hub import _get_torch_home as get_dir
  15. try:
  16. import safetensors.torch
  17. _has_safetensors = True
  18. except ImportError:
  19. _has_safetensors = False
  20. try:
  21. from typing import Literal
  22. except ImportError:
  23. from typing_extensions import Literal
  24. from timm import __version__
  25. from ._helpers import _torch_load, load_state_dict
  26. from ._pretrained import filter_pretrained_cfg
  27. try:
  28. from huggingface_hub import HfApi, hf_hub_download, model_info
  29. from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
  30. hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
  31. _has_hf_hub = True
  32. except ImportError:
  33. hf_hub_download = None
  34. _has_hf_hub = False
  35. _logger = logging.getLogger(__name__)
  36. __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
  37. 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
  38. # Default name for a weights file hosted on the Huggingface Hub.
  39. HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
  40. HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
  41. HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
  42. HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
  43. def get_cache_dir(child_dir: str = ''):
  44. """
  45. Returns the location of the directory where models are cached (and creates it if necessary).
  46. """
  47. # Issue warning to move data if old env is set
  48. if os.getenv('TORCH_MODEL_ZOO'):
  49. _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
  50. hub_dir = get_dir()
  51. child_dir = () if not child_dir else (child_dir,)
  52. model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
  53. os.makedirs(model_dir, exist_ok=True)
  54. return model_dir
  55. def download_cached_file(
  56. url: Union[str, List[str], Tuple[str, str]],
  57. check_hash: bool = True,
  58. progress: bool = False,
  59. cache_dir: Optional[Union[str, Path]] = None,
  60. ):
  61. if isinstance(url, (list, tuple)):
  62. url, filename = url
  63. else:
  64. parts = urlparse(url)
  65. filename = os.path.basename(parts.path)
  66. if cache_dir:
  67. os.makedirs(cache_dir, exist_ok=True)
  68. else:
  69. cache_dir = get_cache_dir()
  70. cached_file = os.path.join(cache_dir, filename)
  71. if not os.path.exists(cached_file):
  72. _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
  73. hash_prefix = None
  74. if check_hash:
  75. r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
  76. hash_prefix = r.group(1) if r else None
  77. download_url_to_file(url, cached_file, hash_prefix, progress=progress)
  78. return cached_file
  79. def check_cached_file(
  80. url: Union[str, List[str], Tuple[str, str]],
  81. check_hash: bool = True,
  82. cache_dir: Optional[Union[str, Path]] = None,
  83. ):
  84. if isinstance(url, (list, tuple)):
  85. url, filename = url
  86. else:
  87. parts = urlparse(url)
  88. filename = os.path.basename(parts.path)
  89. if not cache_dir:
  90. cache_dir = get_cache_dir()
  91. cached_file = os.path.join(cache_dir, filename)
  92. if os.path.exists(cached_file):
  93. if check_hash:
  94. r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
  95. hash_prefix = r.group(1) if r else None
  96. if hash_prefix:
  97. with open(cached_file, 'rb') as f:
  98. hd = hashlib.sha256(f.read()).hexdigest()
  99. if hd[:len(hash_prefix)] != hash_prefix:
  100. return False
  101. return True
  102. return False
  103. def has_hf_hub(necessary: bool = False):
  104. if not _has_hf_hub and necessary:
  105. # if no HF Hub module installed, and it is necessary to continue, raise error
  106. raise RuntimeError(
  107. 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
  108. return _has_hf_hub
  109. def hf_split(hf_id: str):
  110. # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
  111. rev_split = hf_id.split('@')
  112. assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
  113. hf_model_id = rev_split[0]
  114. hf_revision = rev_split[-1] if len(rev_split) > 1 else None
  115. return hf_model_id, hf_revision
  116. def load_cfg_from_json(json_file: Union[str, Path]):
  117. with open(json_file, "r", encoding="utf-8") as reader:
  118. text = reader.read()
  119. return json.loads(text)
  120. def download_from_hf(
  121. model_id: str,
  122. filename: str,
  123. cache_dir: Optional[Union[str, Path]] = None,
  124. ):
  125. hf_model_id, hf_revision = hf_split(model_id)
  126. return hf_hub_download(
  127. hf_model_id,
  128. filename,
  129. revision=hf_revision,
  130. cache_dir=cache_dir,
  131. )
  132. def _parse_model_cfg(
  133. cfg: Dict[str, Any],
  134. extra_fields: Dict[str, Any],
  135. ) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
  136. """"""
  137. # legacy "single‑dict" → split
  138. if "pretrained_cfg" not in cfg:
  139. pretrained_cfg = cfg
  140. cfg = {
  141. "architecture": pretrained_cfg.pop("architecture"),
  142. "num_features": pretrained_cfg.pop("num_features", None),
  143. "pretrained_cfg": pretrained_cfg,
  144. }
  145. if "labels" in pretrained_cfg: # rename ‑‑> label_names
  146. pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
  147. pretrained_cfg = cfg["pretrained_cfg"]
  148. pretrained_cfg.update(extra_fields)
  149. # top‑level overrides
  150. if "num_classes" in cfg:
  151. pretrained_cfg["num_classes"] = cfg["num_classes"]
  152. if "label_names" in cfg:
  153. pretrained_cfg["label_names"] = cfg.pop("label_names")
  154. if "label_descriptions" in cfg:
  155. pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions")
  156. model_args = cfg.get("model_args", {})
  157. model_name = cfg["architecture"]
  158. return pretrained_cfg, model_name, model_args
  159. def load_model_config_from_hf(
  160. model_id: str,
  161. cache_dir: Optional[Union[str, Path]] = None,
  162. ):
  163. """Original HF‑Hub loader (unchanged download, shared parsing)."""
  164. assert has_hf_hub(True)
  165. cfg_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
  166. cfg = load_cfg_from_json(cfg_path)
  167. return _parse_model_cfg(cfg, {"hf_hub_id": model_id, "source": "hf-hub"})
  168. def load_model_config_from_path(
  169. model_path: Union[str, Path],
  170. ):
  171. """Load from ``<model_path>/config.json`` on the local filesystem."""
  172. model_path = Path(model_path)
  173. cfg_file = model_path / "config.json"
  174. if not cfg_file.is_file():
  175. raise FileNotFoundError(f"Config file not found: {cfg_file}")
  176. cfg = load_cfg_from_json(cfg_file)
  177. extra_fields = {"file": str(model_path), "source": "local-dir"}
  178. return _parse_model_cfg(cfg, extra_fields=extra_fields)
  179. def load_state_dict_from_hf(
  180. model_id: str,
  181. filename: str = HF_WEIGHTS_NAME,
  182. weights_only: bool = True,
  183. cache_dir: Optional[Union[str, Path]] = None,
  184. ):
  185. assert has_hf_hub(True)
  186. hf_model_id, hf_revision = hf_split(model_id)
  187. # Look for .safetensors alternatives and load from it if it exists
  188. if _has_safetensors:
  189. for safe_filename in _get_safe_alternatives(filename):
  190. try:
  191. cached_safe_file = hf_hub_download(
  192. repo_id=hf_model_id,
  193. filename=safe_filename,
  194. revision=hf_revision,
  195. cache_dir=cache_dir,
  196. )
  197. _logger.info(
  198. f"[{model_id}] Safe alternative available for '{filename}' "
  199. f"(as '{safe_filename}'). Loading weights using safetensors.")
  200. return safetensors.torch.load_file(cached_safe_file, device="cpu")
  201. except EntryNotFoundError:
  202. pass
  203. # Otherwise, load using pytorch.load
  204. cached_file = hf_hub_download(
  205. hf_model_id,
  206. filename=filename,
  207. revision=hf_revision,
  208. cache_dir=cache_dir,
  209. )
  210. _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
  211. state_dict = _torch_load(cached_file, map_location='cpu', weights_only=weights_only)
  212. return state_dict
  213. _PREFERRED_FILES = (
  214. "model.safetensors",
  215. "pytorch_model.bin",
  216. "pytorch_model.pth",
  217. "model.pth",
  218. "open_clip_model.safetensors",
  219. "open_clip_pytorch_model.safetensors",
  220. "open_clip_pytorch_model.bin",
  221. "open_clip_pytorch_model.pth",
  222. )
  223. _EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
  224. def load_state_dict_from_path(
  225. path: str,
  226. weights_only: bool = True,
  227. ):
  228. found_file = None
  229. for fname in _PREFERRED_FILES:
  230. p = path / fname
  231. if p.exists():
  232. logging.info(f"Found preferred checkpoint: {p.name}")
  233. found_file = p
  234. break
  235. # fallback: first match per‑extension class
  236. for ext in _EXT_PRIORITY:
  237. files = sorted(path.glob(f"*{ext}"))
  238. if files:
  239. if len(files) > 1:
  240. logging.warning(
  241. f"Multiple {ext} checkpoints in {path}: {names}. "
  242. f"Using '{files[0].name}'."
  243. )
  244. found_file = files[0]
  245. if not found_file:
  246. raise RuntimeError(f"No suitable checkpoints found in {path}.")
  247. state_dict = load_state_dict(found_file, weights_only=weights_only)
  248. return state_dict
  249. def load_custom_from_hf(
  250. model_id: str,
  251. filename: str,
  252. model: torch.nn.Module,
  253. cache_dir: Optional[Union[str, Path]] = None,
  254. ):
  255. assert has_hf_hub(True)
  256. hf_model_id, hf_revision = hf_split(model_id)
  257. cached_file = hf_hub_download(
  258. hf_model_id,
  259. filename=filename,
  260. revision=hf_revision,
  261. cache_dir=cache_dir,
  262. )
  263. return model.load_pretrained(cached_file)
  264. def save_config_for_hf(
  265. model: torch.nn.Module,
  266. config_path: str,
  267. model_config: Optional[dict] = None,
  268. model_args: Optional[dict] = None
  269. ):
  270. model_config = model_config or {}
  271. hf_config = {}
  272. pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
  273. # set some values at root config level
  274. hf_config['architecture'] = pretrained_cfg.pop('architecture')
  275. hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
  276. # NOTE these attr saved for informational purposes, do not impact model build
  277. hf_config['num_features'] = model_config.pop('num_features', model.num_features)
  278. global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
  279. if isinstance(global_pool_type, str) and global_pool_type:
  280. hf_config['global_pool'] = global_pool_type
  281. # Save class label info
  282. if 'labels' in model_config:
  283. _logger.warning(
  284. "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
  285. " Renaming provided 'labels' field to 'label_names'.")
  286. model_config.setdefault('label_names', model_config.pop('labels'))
  287. label_names = model_config.pop('label_names', None)
  288. if label_names:
  289. assert isinstance(label_names, (dict, list, tuple))
  290. # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
  291. # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
  292. hf_config['label_names'] = label_names
  293. label_descriptions = model_config.pop('label_descriptions', None)
  294. if label_descriptions:
  295. assert isinstance(label_descriptions, dict)
  296. # maps label names -> descriptions
  297. hf_config['label_descriptions'] = label_descriptions
  298. if model_args:
  299. hf_config['model_args'] = model_args
  300. hf_config['pretrained_cfg'] = pretrained_cfg
  301. hf_config.update(model_config)
  302. with config_path.open('w') as f:
  303. json.dump(hf_config, f, indent=2)
  304. def save_for_hf(
  305. model: torch.nn.Module,
  306. save_directory: str,
  307. model_config: Optional[dict] = None,
  308. model_args: Optional[dict] = None,
  309. safe_serialization: Union[bool, Literal["both"]] = False,
  310. ):
  311. assert has_hf_hub(True)
  312. save_directory = Path(save_directory)
  313. save_directory.mkdir(exist_ok=True, parents=True)
  314. # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
  315. tensors = model.state_dict()
  316. if safe_serialization is True or safe_serialization == "both":
  317. assert _has_safetensors, "`pip install safetensors` to use .safetensors"
  318. safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
  319. if safe_serialization is False or safe_serialization == "both":
  320. torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
  321. config_path = save_directory / 'config.json'
  322. save_config_for_hf(
  323. model,
  324. config_path,
  325. model_config=model_config,
  326. model_args=model_args,
  327. )
  328. def push_to_hf_hub(
  329. model: torch.nn.Module,
  330. repo_id: str,
  331. commit_message: str = 'Add model',
  332. token: Optional[str] = None,
  333. revision: Optional[str] = None,
  334. private: bool = False,
  335. create_pr: bool = False,
  336. model_config: Optional[dict] = None,
  337. model_card: Optional[dict] = None,
  338. model_args: Optional[dict] = None,
  339. task_name: str = 'image-classification',
  340. safe_serialization: Union[bool, Literal["both"]] = 'both',
  341. ):
  342. """
  343. Arguments:
  344. (...)
  345. safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
  346. Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
  347. Can be set to `"both"` in order to push both safe and unsafe weights.
  348. """
  349. api = HfApi(token=token, library_name="timm", library_version=__version__)
  350. # Create repo if it doesn't exist yet
  351. repo_url = api.create_repo(repo_id, private=private, exist_ok=True)
  352. # Can be different from the input `repo_id` if repo_owner was implicit
  353. repo_id = repo_url.repo_id
  354. # Check if README file already exist in repo
  355. has_readme = api.file_exists(repo_id=repo_id, filename="README.md", revision=revision)
  356. # Dump model and push to Hub
  357. with TemporaryDirectory() as tmpdir:
  358. # Save model weights and config.
  359. save_for_hf(
  360. model,
  361. tmpdir,
  362. model_config=model_config,
  363. model_args=model_args,
  364. safe_serialization=safe_serialization,
  365. )
  366. # Add readme if it does not exist
  367. if not has_readme:
  368. model_card = model_card or {}
  369. model_name = repo_id.split('/')[-1]
  370. readme_path = Path(tmpdir) / "README.md"
  371. readme_text = generate_readme(model_card, model_name, task_name=task_name)
  372. readme_path.write_text(readme_text)
  373. # Upload model and return
  374. return api.upload_folder(
  375. repo_id=repo_id,
  376. folder_path=tmpdir,
  377. revision=revision,
  378. create_pr=create_pr,
  379. commit_message=commit_message,
  380. )
  381. def generate_readme(
  382. model_card: dict,
  383. model_name: str,
  384. task_name: str = 'image-classification',
  385. ):
  386. tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
  387. readme_text = "---\n"
  388. if tags:
  389. readme_text += "tags:\n"
  390. for t in tags:
  391. readme_text += f"- {t}\n"
  392. readme_text += f"pipeline_tag: {task_name}\n"
  393. readme_text += f"library_name: {model_card.get('library_name', 'timm')}\n"
  394. readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
  395. if 'license_name' in model_card:
  396. readme_text += f"license_name: {model_card.get('license_name')}\n"
  397. if 'license_link' in model_card:
  398. readme_text += f"license_link: {model_card.get('license_link')}\n"
  399. if 'details' in model_card and 'Dataset' in model_card['details']:
  400. readme_text += 'datasets:\n'
  401. if isinstance(model_card['details']['Dataset'], (tuple, list)):
  402. for d in model_card['details']['Dataset']:
  403. readme_text += f"- {d.lower()}\n"
  404. else:
  405. readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
  406. if 'Pretrain Dataset' in model_card['details']:
  407. if isinstance(model_card['details']['Pretrain Dataset'], (tuple, list)):
  408. for d in model_card['details']['Pretrain Dataset']:
  409. readme_text += f"- {d.lower()}\n"
  410. else:
  411. readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
  412. readme_text += "---\n"
  413. readme_text += f"# Model card for {model_name}\n"
  414. if 'description' in model_card:
  415. readme_text += f"\n{model_card['description']}\n"
  416. if 'details' in model_card:
  417. readme_text += f"\n## Model Details\n"
  418. for k, v in model_card['details'].items():
  419. if isinstance(v, (list, tuple)):
  420. readme_text += f"- **{k}:**\n"
  421. for vi in v:
  422. readme_text += f" - {vi}\n"
  423. elif isinstance(v, dict):
  424. readme_text += f"- **{k}:**\n"
  425. for ki, vi in v.items():
  426. readme_text += f" - {ki}: {vi}\n"
  427. else:
  428. readme_text += f"- **{k}:** {v}\n"
  429. if 'usage' in model_card:
  430. readme_text += f"\n## Model Usage\n"
  431. readme_text += model_card['usage']
  432. readme_text += '\n'
  433. if 'comparison' in model_card:
  434. readme_text += f"\n## Model Comparison\n"
  435. readme_text += model_card['comparison']
  436. readme_text += '\n'
  437. if 'citation' in model_card:
  438. readme_text += f"\n## Citation\n"
  439. if not isinstance(model_card['citation'], (list, tuple)):
  440. citations = [model_card['citation']]
  441. else:
  442. citations = model_card['citation']
  443. for c in citations:
  444. readme_text += f"```bibtex\n{c}\n```\n"
  445. return readme_text
  446. def _get_safe_alternatives(filename: str) -> Iterable[str]:
  447. """Returns potential safetensors alternatives for a given filename.
  448. Use case:
  449. When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
  450. Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
  451. """
  452. if filename == HF_WEIGHTS_NAME:
  453. yield HF_SAFE_WEIGHTS_NAME
  454. if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
  455. yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
  456. if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
  457. yield filename[:-4] + ".safetensors"
  458. def _get_license_from_hf_hub(model_id: Optional[str], hf_hub_id: Optional[str]) -> Optional[str]:
  459. """Retrieve license information for a model from Hugging Face Hub.
  460. Fetches the license field from the model card metadata on Hugging Face Hub
  461. for the specified model. Returns None if the model is not found, if
  462. huggingface_hub is not installed, or if the model is marked as "untrained".
  463. Args:
  464. model_id: The model identifier/name. In the case of None we assume an untrained model.
  465. hf_hub_id: The Hugging Face Hub organization/user ID. If it is None,
  466. we will return None as we cannot infer the license terms.
  467. Returns:
  468. The license string in lowercase if found, None otherwise.
  469. Note:
  470. Requires huggingface_hub package to be installed. Will log a warning
  471. and return None if the package is not available.
  472. """
  473. if not has_hf_hub(True):
  474. msg = "For updated license information run `pip install huggingface_hub`."
  475. _logger.warning(msg=msg)
  476. return None
  477. if not (model_id and hf_hub_id):
  478. return None
  479. repo_id: str = hf_hub_id + model_id
  480. try:
  481. info = model_info(repo_id=repo_id)
  482. except RepositoryNotFoundError:
  483. msg = f"Repository {repo_id} was not found. Manual inspection of license needed."
  484. _logger.warning(msg=msg)
  485. return None
  486. except Exception as _:
  487. msg = f"Error for {repo_id}. Manual inspection of license needed."
  488. _logger.warning(msg=msg)
  489. return None
  490. license = info.card_data.get("license").lower() if info.card_data else None
  491. if license == 'other':
  492. name = info.card_data.get("license_name", None)
  493. if name is not None:
  494. return name
  495. return license