_snapshot_download.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. import os
  2. from collections.abc import Iterable
  3. from pathlib import Path
  4. from typing import Literal, overload
  5. import httpx
  6. from tqdm.auto import tqdm as base_tqdm
  7. from tqdm.contrib.concurrent import thread_map
  8. from . import constants
  9. from .errors import (
  10. DryRunError,
  11. GatedRepoError,
  12. HfHubHTTPError,
  13. LocalEntryNotFoundError,
  14. RepositoryNotFoundError,
  15. RevisionNotFoundError,
  16. )
  17. from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name
  18. from .hf_api import DatasetInfo, HfApi, KernelInfo, ModelInfo, RepoFile, SpaceInfo
  19. from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
  20. from .utils.tqdm import _create_progress_bar
  21. from .utils.tqdm import tqdm as hf_tqdm
  22. logger = logging.get_logger(__name__)
  23. LARGE_REPO_THRESHOLD = 1000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough
  24. @overload
  25. def snapshot_download(
  26. repo_id: str,
  27. *,
  28. repo_type: str | None = None,
  29. revision: str | None = None,
  30. cache_dir: str | Path | None = None,
  31. local_dir: str | Path | None = None,
  32. library_name: str | None = None,
  33. library_version: str | None = None,
  34. user_agent: dict | str | None = None,
  35. etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
  36. force_download: bool = False,
  37. token: bool | str | None = None,
  38. local_files_only: bool = False,
  39. allow_patterns: list[str] | str | None = None,
  40. ignore_patterns: list[str] | str | None = None,
  41. max_workers: int = 8,
  42. tqdm_class: type[base_tqdm] | None = None,
  43. headers: dict[str, str] | None = None,
  44. endpoint: str | None = None,
  45. dry_run: Literal[False] = False,
  46. ) -> str: ...
  47. @overload
  48. def snapshot_download(
  49. repo_id: str,
  50. *,
  51. repo_type: str | None = None,
  52. revision: str | None = None,
  53. cache_dir: str | Path | None = None,
  54. local_dir: str | Path | None = None,
  55. library_name: str | None = None,
  56. library_version: str | None = None,
  57. user_agent: dict | str | None = None,
  58. etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
  59. force_download: bool = False,
  60. token: bool | str | None = None,
  61. local_files_only: bool = False,
  62. allow_patterns: list[str] | str | None = None,
  63. ignore_patterns: list[str] | str | None = None,
  64. max_workers: int = 8,
  65. tqdm_class: type[base_tqdm] | None = None,
  66. headers: dict[str, str] | None = None,
  67. endpoint: str | None = None,
  68. dry_run: Literal[True] = True,
  69. ) -> list[DryRunFileInfo]: ...
  70. @overload
  71. def snapshot_download(
  72. repo_id: str,
  73. *,
  74. repo_type: str | None = None,
  75. revision: str | None = None,
  76. cache_dir: str | Path | None = None,
  77. local_dir: str | Path | None = None,
  78. library_name: str | None = None,
  79. library_version: str | None = None,
  80. user_agent: dict | str | None = None,
  81. etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
  82. force_download: bool = False,
  83. token: bool | str | None = None,
  84. local_files_only: bool = False,
  85. allow_patterns: list[str] | str | None = None,
  86. ignore_patterns: list[str] | str | None = None,
  87. max_workers: int = 8,
  88. tqdm_class: type[base_tqdm] | None = None,
  89. headers: dict[str, str] | None = None,
  90. endpoint: str | None = None,
  91. dry_run: bool = False,
  92. ) -> str | list[DryRunFileInfo]: ...
  93. @validate_hf_hub_args
  94. def snapshot_download(
  95. repo_id: str,
  96. *,
  97. repo_type: str | None = None,
  98. revision: str | None = None,
  99. cache_dir: str | Path | None = None,
  100. local_dir: str | Path | None = None,
  101. library_name: str | None = None,
  102. library_version: str | None = None,
  103. user_agent: dict | str | None = None,
  104. etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
  105. force_download: bool = False,
  106. token: bool | str | None = None,
  107. local_files_only: bool = False,
  108. allow_patterns: list[str] | str | None = None,
  109. ignore_patterns: list[str] | str | None = None,
  110. max_workers: int = 8,
  111. tqdm_class: type[base_tqdm] | None = None,
  112. headers: dict[str, str] | None = None,
  113. endpoint: str | None = None,
  114. dry_run: bool = False,
  115. ) -> str | list[DryRunFileInfo]:
  116. """Download repo files.
  117. Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
  118. a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
  119. to keep their actual filename relative to that folder. You can also filter which files to download using
  120. `allow_patterns` and `ignore_patterns`.
  121. If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
  122. option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir`
  123. to store some metadata related to the downloaded files. While this mechanism is not as robust as the main
  124. cache-system, it's optimized for regularly pulling the latest version of a repository.
  125. An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
  126. configured. It is also not possible to filter which files to download when cloning a repository using git.
  127. Args:
  128. repo_id (`str`):
  129. A user or an organization name and a repo name separated by a `/`.
  130. repo_type (`str`, *optional*):
  131. Set to `"dataset"`, `"space"` or `"kernel"` if downloading from a dataset, space or kernel repo,
  132. `None` or `"model"` if downloading from a model. Default is `None`.
  133. revision (`str`, *optional*):
  134. An optional Git revision id which can be a branch name, a tag, or a
  135. commit hash.
  136. cache_dir (`str`, `Path`, *optional*):
  137. Path to the folder where cached files are stored.
  138. local_dir (`str` or `Path`, *optional*):
  139. If provided, the downloaded files will be placed under this directory.
  140. library_name (`str`, *optional*):
  141. The name of the library to which the object corresponds.
  142. library_version (`str`, *optional*):
  143. The version of the library.
  144. user_agent (`str`, `dict`, *optional*):
  145. The user-agent info in the form of a dictionary or a string.
  146. etag_timeout (`float`, *optional*, defaults to `10`):
  147. When fetching ETag, how many seconds to wait for the server to send
  148. data before giving up which is passed to `httpx.request`.
  149. force_download (`bool`, *optional*, defaults to `False`):
  150. Whether the file should be downloaded even if it already exists in the local cache.
  151. token (`str`, `bool`, *optional*):
  152. A token to be used for the download.
  153. - If `True`, the token is read from the HuggingFace config
  154. folder.
  155. - If a string, it's used as the authentication token.
  156. headers (`dict`, *optional*):
  157. Additional headers to include in the request. Those headers take precedence over the others.
  158. local_files_only (`bool`, *optional*, defaults to `False`):
  159. If `True`, avoid downloading the file and return the path to the
  160. local cached file if it exists.
  161. allow_patterns (`list[str]` or `str`, *optional*):
  162. If provided, only files matching at least one pattern are downloaded.
  163. ignore_patterns (`list[str]` or `str`, *optional*):
  164. If provided, files matching any of the patterns are not downloaded.
  165. max_workers (`int`, *optional*):
  166. Number of concurrent threads to download files (1 thread = 1 file download).
  167. Defaults to 8.
  168. tqdm_class (`tqdm`, *optional*):
  169. If provided, overwrites the default behavior for the progress bar. Passed
  170. argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
  171. Note that the `tqdm_class` is not passed to each individual download.
  172. Defaults to the custom HF progress bar that can be disabled by setting
  173. `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
  174. dry_run (`bool`, *optional*, defaults to `False`):
  175. If `True`, perform a dry run without actually downloading the files. Returns a list of
  176. [`DryRunFileInfo`] objects containing information about what would be downloaded.
  177. Returns:
  178. `str` or list of [`DryRunFileInfo`]:
  179. - If `dry_run=False`: Local snapshot path.
  180. - If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information.
  181. Raises:
  182. [`~utils.RepositoryNotFoundError`]
  183. If the repository to download from cannot be found. This may be because it doesn't exist,
  184. or because it is set to `private` and you do not have access.
  185. [`~utils.RevisionNotFoundError`]
  186. If the revision to download from cannot be found.
  187. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  188. If `token=True` and the token cannot be found.
  189. [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
  190. ETag cannot be determined.
  191. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  192. if some parameter value is invalid.
  193. """
  194. if cache_dir is None:
  195. cache_dir = constants.HF_HUB_CACHE
  196. if revision is None:
  197. revision = constants.DEFAULT_REVISION
  198. if isinstance(cache_dir, Path):
  199. cache_dir = str(cache_dir)
  200. if repo_type is None:
  201. repo_type = "model"
  202. if repo_type not in constants.REPO_TYPES_WITH_KERNEL:
  203. raise ValueError(
  204. f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES_WITH_KERNEL)}"
  205. )
  206. storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
  207. api = HfApi(
  208. library_name=library_name,
  209. library_version=library_version,
  210. user_agent=user_agent,
  211. endpoint=endpoint,
  212. headers=headers,
  213. token=token,
  214. )
  215. repo_info: ModelInfo | DatasetInfo | SpaceInfo | KernelInfo | None = None
  216. api_call_error: Exception | None = None
  217. if not local_files_only:
  218. # try/except logic to handle different errors => taken from `hf_hub_download`
  219. try:
  220. # if we have internet connection we want to list files to download
  221. repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision)
  222. except httpx.ProxyError:
  223. # Actually raise on proxy error
  224. raise
  225. except (httpx.ConnectError, httpx.TimeoutException, OfflineModeIsEnabled) as error:
  226. # Internet connection is down
  227. # => will try to use local files only
  228. api_call_error = error
  229. pass
  230. except RevisionNotFoundError:
  231. # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
  232. raise
  233. except HfHubHTTPError as error:
  234. # Multiple reasons for an http error:
  235. # - Repository is private and invalid/missing token sent
  236. # - Repository is gated and invalid/missing token sent
  237. # - Hub is down (error 500 or 504)
  238. # => let's switch to 'local_files_only=True' to check if the files are already cached.
  239. # (if it's not the case, the error will be re-raised)
  240. api_call_error = error
  241. pass
  242. # At this stage, if `repo_info` is None it means either:
  243. # - internet connection is down
  244. # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
  245. # - repo is private/gated and invalid/missing token sent
  246. # - Hub is down
  247. # => let's look if we can find the appropriate folder in the cache:
  248. # - if the specified revision is a commit hash, look inside "snapshots".
  249. # - f the specified revision is a branch or tag, look inside "refs".
  250. # => if local_dir is not None, we will return the path to the local folder if it exists.
  251. if repo_info is None:
  252. if dry_run:
  253. raise DryRunError(
  254. "Dry run cannot be performed as the repository cannot be accessed. Please check your internet connection or authentication token."
  255. ) from api_call_error
  256. # Try to get which commit hash corresponds to the specified revision
  257. commit_hash = None
  258. if REGEX_COMMIT_HASH.match(revision):
  259. commit_hash = revision
  260. else:
  261. ref_path = os.path.join(storage_folder, "refs", revision)
  262. if os.path.exists(ref_path):
  263. # retrieve commit_hash from refs file
  264. with open(ref_path) as f:
  265. commit_hash = f.read()
  266. # Try to locate snapshot folder for this commit hash
  267. if commit_hash is not None and local_dir is None:
  268. snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
  269. if os.path.exists(snapshot_folder):
  270. # Snapshot folder exists => let's return it
  271. # (but we can't check if all the files are actually there)
  272. return snapshot_folder
  273. # If local_dir is not None, return it if it exists and is not empty
  274. if local_dir is not None:
  275. local_dir = Path(local_dir)
  276. if local_dir.is_dir() and any(local_dir.iterdir()):
  277. logger.warning(
  278. f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
  279. )
  280. return str(local_dir.resolve())
  281. # If we couldn't find the appropriate folder on disk, raise an error.
  282. if local_files_only:
  283. raise LocalEntryNotFoundError(
  284. "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
  285. "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
  286. "'local_files_only=False' as input."
  287. )
  288. elif isinstance(api_call_error, OfflineModeIsEnabled):
  289. raise LocalEntryNotFoundError(
  290. "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
  291. "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
  292. "'HF_HUB_OFFLINE=0' as environment variable."
  293. ) from api_call_error
  294. elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)) or (
  295. isinstance(api_call_error, HfHubHTTPError) and api_call_error.response.status_code == 401
  296. ):
  297. # Repo not found, gated, or specific authentication error => let's raise the actual error
  298. raise api_call_error
  299. else:
  300. # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
  301. raise LocalEntryNotFoundError(
  302. f"Got: {api_call_error.__class__.__name__}: {api_call_error}"
  303. "\nAn error happened while trying to locate the files on the Hub and we cannot find the appropriate"
  304. " snapshot folder for the specified revision on the local disk. Please check your internet connection"
  305. " and try again."
  306. ) from api_call_error
  307. # At this stage, internet connection is up and running
  308. # => let's download the files!
  309. assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
  310. # Corner case: on very large repos, the siblings list in `repo_info` might not contain all files.
  311. # In that case, we need to use the `list_repo_tree` method to prevent caching issues.
  312. # Note: kernel repos don't expose siblings in their info response, so we always fall back to `list_repo_tree`.
  313. siblings = getattr(repo_info, "siblings", None)
  314. repo_files: Iterable[str] = [f.rfilename for f in siblings] if siblings is not None else []
  315. unreliable_nb_files = siblings is None or len(siblings) == 0 or len(siblings) > LARGE_REPO_THRESHOLD
  316. if unreliable_nb_files:
  317. logger.info(
  318. "Number of files in the repo is unreliable. Using `list_repo_tree` to ensure all files are listed."
  319. )
  320. repo_files = (
  321. f.rfilename
  322. for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type)
  323. if isinstance(f, RepoFile)
  324. )
  325. filtered_repo_files: Iterable[str] = filter_repo_objects(
  326. items=repo_files,
  327. allow_patterns=allow_patterns,
  328. ignore_patterns=ignore_patterns,
  329. )
  330. if not unreliable_nb_files:
  331. filtered_repo_files = list(filtered_repo_files)
  332. tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
  333. else:
  334. tqdm_desc = "Fetching ... files"
  335. if dry_run:
  336. tqdm_desc = "[dry-run] " + tqdm_desc
  337. commit_hash = repo_info.sha
  338. snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
  339. # if passed revision is not identical to commit_hash
  340. # then revision has to be a branch name or tag name.
  341. # In that case store a ref.
  342. if revision != commit_hash:
  343. ref_path = os.path.join(storage_folder, "refs", revision)
  344. try:
  345. os.makedirs(os.path.dirname(ref_path), exist_ok=True)
  346. with open(ref_path, "w") as f:
  347. f.write(commit_hash)
  348. except OSError as e:
  349. logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")
  350. results: list[str | DryRunFileInfo] = []
  351. # User can use its own tqdm class or the default one from `huggingface_hub.utils`
  352. tqdm_class = tqdm_class or hf_tqdm
  353. # Create a progress bar for the bytes downloaded
  354. # This progress bar is shared across threads/files and gets updated each time we fetch
  355. # metadata for a file.
  356. bytes_progress = _create_progress_bar(
  357. cls=tqdm_class,
  358. log_level=logger.getEffectiveLevel(),
  359. name="huggingface_hub.snapshot_download",
  360. desc="Downloading (incomplete total...)",
  361. total=0,
  362. initial=0,
  363. unit="B",
  364. unit_scale=True,
  365. )
  366. class _AggregatedTqdm:
  367. """Fake tqdm object to aggregate progress into the parent `bytes_progress` bar.
  368. In practice the `_AggregatedTqdm` object won't be displayed, it's just used to update
  369. the `bytes_progress` bar from each thread/file download.
  370. """
  371. def __init__(self, *args, **kwargs):
  372. # Adjust the total of the parent progress bar
  373. total = kwargs.pop("total", None)
  374. if total is not None:
  375. bytes_progress.total += total
  376. bytes_progress.refresh()
  377. # Adjust initial of the parent progress bar
  378. initial = kwargs.pop("initial", 0)
  379. if initial:
  380. bytes_progress.update(initial)
  381. def __enter__(self):
  382. return self
  383. def __exit__(self, exc_type, exc_value, traceback):
  384. pass
  385. def update(self, n: int | float | None = 1) -> None:
  386. bytes_progress.update(n)
  387. # we pass the commit_hash to hf_hub_download
  388. # so no network call happens if we already
  389. # have the file locally.
  390. def _inner_hf_hub_download(repo_file: str) -> None:
  391. results.append(
  392. hf_hub_download( # type: ignore
  393. repo_id,
  394. filename=repo_file,
  395. repo_type=repo_type,
  396. revision=commit_hash,
  397. endpoint=endpoint,
  398. cache_dir=cache_dir,
  399. local_dir=local_dir,
  400. library_name=library_name,
  401. library_version=library_version,
  402. user_agent=user_agent,
  403. etag_timeout=etag_timeout,
  404. force_download=force_download,
  405. token=token,
  406. headers=headers,
  407. tqdm_class=_AggregatedTqdm, # type: ignore
  408. dry_run=dry_run,
  409. )
  410. )
  411. thread_map(
  412. _inner_hf_hub_download,
  413. filtered_repo_files,
  414. desc=tqdm_desc,
  415. max_workers=max_workers,
  416. tqdm_class=tqdm_class,
  417. )
  418. bytes_progress.set_description("Download complete")
  419. if dry_run:
  420. assert all(isinstance(r, DryRunFileInfo) for r in results)
  421. return results # type: ignore
  422. if local_dir is not None:
  423. return str(os.path.realpath(local_dir))
  424. return snapshot_folder