hub.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Hub utilities: utilities related to downloading and caching models
  16. """
  17. import json
  18. import os
  19. import re
  20. import sys
  21. import tempfile
  22. from concurrent import futures
  23. from pathlib import Path
  24. from typing import TypedDict
  25. from uuid import uuid4
  26. import httpx
  27. from huggingface_hub import (
  28. _CACHED_NO_EXIST,
  29. CommitOperationAdd,
  30. ModelCard,
  31. ModelCardData,
  32. constants,
  33. create_branch,
  34. create_commit,
  35. create_repo,
  36. hf_hub_download,
  37. hf_hub_url,
  38. is_offline_mode,
  39. list_repo_tree,
  40. snapshot_download,
  41. try_to_load_from_cache,
  42. )
  43. from huggingface_hub.file_download import REGEX_COMMIT_HASH
  44. from huggingface_hub.utils import (
  45. EntryNotFoundError,
  46. GatedRepoError,
  47. HfHubHTTPError,
  48. LocalEntryNotFoundError,
  49. OfflineModeIsEnabled,
  50. RepositoryNotFoundError,
  51. RevisionNotFoundError,
  52. build_hf_headers,
  53. get_session,
  54. hf_raise_for_status,
  55. )
  56. from . import __version__, logging
  57. from .import_utils import (
  58. ENV_VARS_TRUE_VALUES,
  59. get_torch_version,
  60. is_torch_available,
  61. is_training_run_on_sagemaker,
  62. )
  63. LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE = "chat_template.json"
  64. CHAT_TEMPLATE_FILE = "chat_template.jinja"
  65. CHAT_TEMPLATE_DIR = "additional_chat_templates"
  66. logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  67. class DownloadKwargs(TypedDict, total=False):
  68. cache_dir: str | os.PathLike | None
  69. force_download: bool
  70. proxies: dict[str, str] | None
  71. local_files_only: bool
  72. token: str | bool | None
  73. revision: str | None
  74. subfolder: str
  75. commit_hash: str | None
  76. tqdm_class: type | None
  77. # Determine default cache directory.
  78. # The best way to set the cache path is with the environment variable HF_HOME. For more details, check out this
  79. # documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
  80. HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
  81. TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
  82. SESSION_ID = uuid4().hex
  83. S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
  84. CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
  85. def _get_cache_file_to_return(
  86. path_or_repo_id: str,
  87. full_filename: str,
  88. cache_dir: str | Path | None = None,
  89. revision: str | None = None,
  90. repo_type: str | None = None,
  91. ):
  92. # We try to see if we have a cached version (not up to date):
  93. resolved_file = try_to_load_from_cache(
  94. path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision, repo_type=repo_type
  95. )
  96. if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
  97. return resolved_file
  98. return None
  99. def list_repo_templates(
  100. repo_id: str,
  101. *,
  102. local_files_only: bool,
  103. revision: str | None = None,
  104. cache_dir: str | None = None,
  105. token: str | bool | None = None,
  106. ) -> list[str]:
  107. """List template files from a repo.
  108. A template is a jinja file located under the `additional_chat_templates/` folder.
  109. If working in offline mode or if internet is down, the method will list jinja template from the local cache - if any.
  110. """
  111. if not local_files_only:
  112. try:
  113. return [
  114. entry.path.removeprefix(f"{CHAT_TEMPLATE_DIR}/")
  115. for entry in list_repo_tree(
  116. repo_id=repo_id,
  117. revision=revision,
  118. path_in_repo=CHAT_TEMPLATE_DIR,
  119. recursive=False,
  120. token=token,
  121. )
  122. if entry.path.endswith(".jinja")
  123. ]
  124. except (GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError):
  125. raise # valid errors => do not catch
  126. except (HfHubHTTPError, OfflineModeIsEnabled, httpx.NetworkError):
  127. pass # offline mode, internet down, etc. => try local files
  128. # check local files
  129. try:
  130. snapshot_dir = snapshot_download(
  131. repo_id=repo_id, revision=revision, cache_dir=cache_dir, local_files_only=True
  132. )
  133. except LocalEntryNotFoundError: # No local repo means no local files
  134. return []
  135. templates_dir = Path(snapshot_dir, CHAT_TEMPLATE_DIR)
  136. if not templates_dir.is_dir():
  137. return []
  138. return [entry.stem for entry in templates_dir.iterdir() if entry.is_file() and entry.name.endswith(".jinja")]
  139. def define_sagemaker_information():
  140. try:
  141. instance_data = httpx.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
  142. dlc_container_used = instance_data["Image"]
  143. dlc_tag = instance_data["Image"].split(":")[1]
  144. except Exception:
  145. dlc_container_used = None
  146. dlc_tag = None
  147. sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
  148. runs_distributed_training = "sagemaker_distributed_dataparallel_enabled" in sagemaker_params
  149. training_job_arn = os.getenv("TRAINING_JOB_ARN")
  150. account_id = training_job_arn.split(":")[4] if training_job_arn is not None else None
  151. sagemaker_object = {
  152. "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
  153. "sm_region": os.getenv("AWS_REGION", None),
  154. "sm_number_gpu": os.getenv("SM_NUM_GPUS", "0"),
  155. "sm_number_cpu": os.getenv("SM_NUM_CPUS", "0"),
  156. "sm_distributed_training": runs_distributed_training,
  157. "sm_deep_learning_container": dlc_container_used,
  158. "sm_deep_learning_container_tag": dlc_tag,
  159. "sm_account_id": account_id,
  160. }
  161. return sagemaker_object
  162. def http_user_agent(user_agent: dict | str | None = None) -> str:
  163. """
  164. Formats a user-agent string with basic info about a request.
  165. """
  166. ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
  167. if is_torch_available():
  168. ua += f"; torch/{get_torch_version()}"
  169. if constants.HF_HUB_DISABLE_TELEMETRY:
  170. return ua + "; telemetry/off"
  171. if is_training_run_on_sagemaker():
  172. ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
  173. # CI will set this value to True
  174. if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
  175. ua += "; is_ci/true"
  176. if isinstance(user_agent, dict):
  177. ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
  178. elif isinstance(user_agent, str):
  179. ua += "; " + user_agent
  180. return ua
  181. def extract_commit_hash(resolved_file: str | None, commit_hash: str | None) -> str | None:
  182. """
  183. Extracts the commit hash from a resolved filename toward a cache file.
  184. """
  185. if resolved_file is None or commit_hash is not None:
  186. return commit_hash
  187. resolved_file = str(Path(resolved_file).as_posix())
  188. search = re.search(r"snapshots/([^/]+)/", resolved_file)
  189. if search is None:
  190. return None
  191. commit_hash = search.groups()[0]
  192. return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
  193. def cached_file(
  194. path_or_repo_id: str | os.PathLike,
  195. filename: str,
  196. **kwargs,
  197. ) -> str | None:
  198. """
  199. Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
  200. Args:
  201. path_or_repo_id (`str` or `os.PathLike`):
  202. This can be either:
  203. - a string, the *model id* of a model repo on huggingface.co.
  204. - a path to a *directory* potentially containing the file.
  205. filename (`str`):
  206. The name of the file to locate in `path_or_repo`.
  207. cache_dir (`str` or `os.PathLike`, *optional*):
  208. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  209. cache should not be used.
  210. force_download (`bool`, *optional*, defaults to `False`):
  211. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  212. exist.
  213. proxies (`dict[str, str]`, *optional*):
  214. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  215. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  216. token (`str` or *bool*, *optional*):
  217. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  218. when running `hf auth login` (stored in `~/.huggingface`).
  219. revision (`str`, *optional*, defaults to `"main"`):
  220. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  221. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  222. identifier allowed by git.
  223. local_files_only (`bool`, *optional*, defaults to `False`):
  224. If `True`, will only try to load the tokenizer configuration from local files.
  225. subfolder (`str`, *optional*, defaults to `""`):
  226. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  227. specify the folder name here.
  228. repo_type (`str`, *optional*):
  229. Specify the repo type (useful when downloading from a space for instance).
  230. <Tip>
  231. Passing `token=True` is required when you want to use a private model.
  232. </Tip>
  233. Returns:
  234. `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).
  235. Examples:
  236. ```python
  237. # Download a model weight from the Hub and cache it.
  238. model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin")
  239. ```
  240. """
  241. file = cached_files(path_or_repo_id=path_or_repo_id, filenames=[filename], **kwargs)
  242. file = file[0] if file is not None else file
  243. return file
  244. def cached_files(
  245. path_or_repo_id: str | os.PathLike,
  246. filenames: list[str],
  247. cache_dir: str | os.PathLike | None = None,
  248. force_download: bool = False,
  249. proxies: dict[str, str] | None = None,
  250. token: bool | str | None = None,
  251. revision: str | None = None,
  252. local_files_only: bool = False,
  253. subfolder: str = "",
  254. repo_type: str | None = None,
  255. user_agent: str | dict[str, str] | None = None,
  256. _raise_exceptions_for_gated_repo: bool = True,
  257. _raise_exceptions_for_missing_entries: bool = True,
  258. _raise_exceptions_for_connection_errors: bool = True,
  259. _commit_hash: str | None = None,
  260. tqdm_class: type | None = None,
  261. **deprecated_kwargs,
  262. ) -> list[str] | None:
  263. """
  264. Tries to locate several files in a local folder and repo, downloads and cache them if necessary.
  265. Args:
  266. path_or_repo_id (`str` or `os.PathLike`):
  267. This can be either:
  268. - a string, the *model id* of a model repo on huggingface.co.
  269. - a path to a *directory* potentially containing the file.
  270. filenames (`list[str]`):
  271. The name of all the files to locate in `path_or_repo`.
  272. cache_dir (`str` or `os.PathLike`, *optional*):
  273. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  274. cache should not be used.
  275. force_download (`bool`, *optional*, defaults to `False`):
  276. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  277. exist.
  278. proxies (`dict[str, str]`, *optional*):
  279. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  280. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  281. token (`str` or *bool*, *optional*):
  282. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  283. when running `hf auth login` (stored in `~/.huggingface`).
  284. revision (`str`, *optional*, defaults to `"main"`):
  285. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  286. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  287. identifier allowed by git.
  288. local_files_only (`bool`, *optional*, defaults to `False`):
  289. If `True`, will only try to load the tokenizer configuration from local files.
  290. subfolder (`str`, *optional*, defaults to `""`):
  291. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  292. specify the folder name here.
  293. repo_type (`str`, *optional*):
  294. Specify the repo type (useful when downloading from a space for instance).
  295. Private args:
  296. _raise_exceptions_for_gated_repo (`bool`):
  297. if False, do not raise an exception for gated repo error but return None.
  298. _raise_exceptions_for_missing_entries (`bool`):
  299. if False, do not raise an exception for missing entries but return None.
  300. _raise_exceptions_for_connection_errors (`bool`):
  301. if False, do not raise an exception for connection errors but return None.
  302. _commit_hash (`str`, *optional*):
  303. passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
  304. a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
  305. <Tip>
  306. Passing `token=True` is required when you want to use a private model.
  307. </Tip>
  308. Returns:
  309. `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).
  310. Examples:
  311. ```python
  312. # Download a model weight from the Hub and cache it.
  313. model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin")
  314. ```
  315. """
  316. if is_offline_mode() and not local_files_only:
  317. logger.info("Offline mode: forcing local_files_only=True")
  318. local_files_only = True
  319. if subfolder is None:
  320. subfolder = ""
  321. # Add folder to filenames
  322. full_filenames = [os.path.join(subfolder, file) for file in filenames]
  323. path_or_repo_id = str(path_or_repo_id)
  324. existing_files = []
  325. for filename in full_filenames:
  326. if os.path.isdir(path_or_repo_id):
  327. resolved_file = os.path.join(path_or_repo_id, filename)
  328. if not os.path.isfile(resolved_file):
  329. if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"):
  330. revision_ = "main" if revision is None else revision
  331. raise OSError(
  332. f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout "
  333. f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files."
  334. )
  335. else:
  336. continue
  337. existing_files.append(resolved_file)
  338. if os.path.isdir(path_or_repo_id):
  339. return existing_files if existing_files else None
  340. if cache_dir is None:
  341. cache_dir = constants.HF_HUB_CACHE
  342. if isinstance(cache_dir, Path):
  343. cache_dir = str(cache_dir)
  344. existing_files = []
  345. file_counter = 0
  346. if _commit_hash is not None and not force_download:
  347. for filename in full_filenames:
  348. # If the file is cached under that commit hash, we return it directly.
  349. resolved_file = try_to_load_from_cache(
  350. path_or_repo_id, filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
  351. )
  352. if resolved_file is not None:
  353. if resolved_file is not _CACHED_NO_EXIST:
  354. file_counter += 1
  355. existing_files.append(resolved_file)
  356. elif not _raise_exceptions_for_missing_entries:
  357. file_counter += 1
  358. else:
  359. raise OSError(f"Could not locate {filename} inside {path_or_repo_id}.")
  360. # Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries
  361. if file_counter == len(full_filenames):
  362. return existing_files if len(existing_files) > 0 else None
  363. user_agent = http_user_agent(user_agent)
  364. # download the files if needed
  365. try:
  366. if len(full_filenames) == 1:
  367. # This is slightly better for only 1 file
  368. hf_hub_download(
  369. path_or_repo_id,
  370. filenames[0],
  371. subfolder=None if len(subfolder) == 0 else subfolder,
  372. repo_type=repo_type,
  373. revision=revision,
  374. cache_dir=cache_dir,
  375. user_agent=user_agent,
  376. force_download=force_download,
  377. proxies=proxies,
  378. token=token,
  379. local_files_only=local_files_only,
  380. tqdm_class=tqdm_class,
  381. )
  382. else:
  383. snapshot_download(
  384. path_or_repo_id,
  385. allow_patterns=full_filenames,
  386. repo_type=repo_type,
  387. revision=revision,
  388. cache_dir=cache_dir,
  389. user_agent=user_agent,
  390. force_download=force_download,
  391. proxies=proxies,
  392. token=token,
  393. local_files_only=local_files_only,
  394. tqdm_class=tqdm_class,
  395. )
  396. except Exception as e:
  397. # We cannot recover from them
  398. if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError):
  399. raise OSError(
  400. f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
  401. "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
  402. "having permission to this repo either by logging in with `hf auth login` or by passing "
  403. "`token=<your_token>`"
  404. ) from e
  405. elif isinstance(e, RevisionNotFoundError):
  406. raise OSError(
  407. f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
  408. "for this model name. Check the model page at "
  409. f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
  410. ) from e
  411. elif isinstance(e, PermissionError):
  412. raise OSError(
  413. f"PermissionError at {e.filename} when downloading {path_or_repo_id}. "
  414. "Check cache directory permissions. Common causes: 1) another user is downloading the same model (please wait); "
  415. "2) a previous download was canceled and the lock file needs manual removal."
  416. ) from e
  417. elif isinstance(e, ValueError):
  418. raise OSError(f"{e}") from e
  419. # Now we try to recover if we can find all files correctly in the cache
  420. resolved_files = [
  421. _get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision, repo_type)
  422. for filename in full_filenames
  423. ]
  424. if all(file is not None for file in resolved_files):
  425. return resolved_files
  426. # Raise based on the flags. Note that we will raise for missing entries at the very end, even when
  427. # not entering this Except block, as it may also happen when `snapshot_download` does not raise
  428. if isinstance(e, GatedRepoError):
  429. if not _raise_exceptions_for_gated_repo:
  430. return None
  431. raise OSError(
  432. "You are trying to access a gated repo.\nMake sure to have access to it at "
  433. f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
  434. ) from e
  435. elif isinstance(e, LocalEntryNotFoundError):
  436. if not _raise_exceptions_for_connection_errors:
  437. return None
  438. # Here we only raise if both flags for missing entry and connection errors are True (because it can be raised
  439. # even when `local_files_only` is True, in which case raising for connections errors only would not make sense)
  440. elif _raise_exceptions_for_missing_entries:
  441. raise OSError(
  442. f"We couldn't connect to '{constants.ENDPOINT}' to load the files, and couldn't find them in the"
  443. f" cached files.\nCheck your internet connection or see how to run the library in offline mode at"
  444. " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
  445. ) from e
  446. # snapshot_download will not raise EntryNotFoundError, but hf_hub_download can. If this is the case, it will be treated
  447. # later on anyway and re-raised if needed
  448. elif isinstance(e, HfHubHTTPError) and not isinstance(e, EntryNotFoundError):
  449. if not _raise_exceptions_for_connection_errors:
  450. return None
  451. raise OSError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}") from e
  452. # Any other Exception type should now be re-raised, in order to provide helpful error messages and break the execution flow
  453. # (EntryNotFoundError will be treated outside this block and correctly re-raised if needed)
  454. elif not isinstance(e, EntryNotFoundError):
  455. raise e
  456. resolved_files = [
  457. _get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
  458. ]
  459. # If there are any missing file and the flag is active, raise
  460. if any(file is None for file in resolved_files) and _raise_exceptions_for_missing_entries:
  461. missing_entries = [original for original, resolved in zip(full_filenames, resolved_files) if resolved is None]
  462. # Last escape
  463. if len(resolved_files) == 1 and missing_entries[0] == os.path.join(subfolder, "config.json"):
  464. return None
  465. # Now we raise for missing entries
  466. revision_ = "main" if revision is None else revision
  467. msg = (
  468. f"a file named {missing_entries[0]}" if len(missing_entries) == 1 else f"files named {(*missing_entries,)}"
  469. )
  470. raise OSError(
  471. f"{path_or_repo_id} does not appear to have {msg}. Checkout 'https://huggingface.co/{path_or_repo_id}/tree/{revision_}'"
  472. " for available files."
  473. )
  474. # Remove potential missing entries (we can silently remove them at this point based on the flags)
  475. resolved_files = [file for file in resolved_files if file is not None]
  476. # Return `None` if the list is empty, coherent with other Exception when the flag is not active
  477. resolved_files = None if len(resolved_files) == 0 else resolved_files
  478. return resolved_files
  479. def has_file(
  480. path_or_repo: str | os.PathLike,
  481. filename: str,
  482. revision: str | None = None,
  483. proxies: dict[str, str] | None = None,
  484. token: bool | str | None = None,
  485. *,
  486. local_files_only: bool = False,
  487. cache_dir: str | Path | None = None,
  488. repo_type: str | None = None,
  489. **deprecated_kwargs,
  490. ):
  491. """
  492. Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.
  493. If offline mode is enabled, checks if the file exists in the cache.
  494. <Tip warning={false}>
  495. This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist for
  496. this repo, but will return False for regular connection errors.
  497. </Tip>
  498. """
  499. # If path to local directory, check if the file exists
  500. if os.path.isdir(path_or_repo):
  501. return os.path.isfile(os.path.join(path_or_repo, filename))
  502. # Else it's a repo => let's check if the file exists in local cache or on the Hub
  503. # Check if file exists in cache
  504. # This information might be outdated so it's best to also make a HEAD call (if allowed).
  505. cached_path = try_to_load_from_cache(
  506. repo_id=path_or_repo,
  507. filename=filename,
  508. revision=revision,
  509. repo_type=repo_type,
  510. cache_dir=cache_dir,
  511. )
  512. has_file_in_cache = isinstance(cached_path, str)
  513. # If local_files_only, don't try the HEAD call
  514. if local_files_only:
  515. return has_file_in_cache
  516. # Check if the file exists
  517. try:
  518. response = get_session().head(
  519. hf_hub_url(path_or_repo, filename=filename, revision=revision, repo_type=repo_type),
  520. headers=build_hf_headers(token=token, user_agent=http_user_agent()),
  521. follow_redirects=False,
  522. timeout=10,
  523. )
  524. except httpx.ProxyError:
  525. # Actually raise for those subclasses of ConnectionError
  526. raise
  527. except (httpx.ConnectError, httpx.TimeoutException, OfflineModeIsEnabled):
  528. return has_file_in_cache
  529. try:
  530. hf_raise_for_status(response)
  531. return True
  532. except GatedRepoError as e:
  533. logger.error(e)
  534. raise OSError(
  535. f"{path_or_repo} is a gated repository. Make sure to request access at "
  536. f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by "
  537. "logging in with `hf auth login` or by passing `token=<your_token>`."
  538. ) from e
  539. except RepositoryNotFoundError as e:
  540. logger.error(e)
  541. raise OSError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.") from e
  542. except RevisionNotFoundError as e:
  543. logger.error(e)
  544. raise OSError(
  545. f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
  546. f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
  547. ) from e
  548. except EntryNotFoundError:
  549. return False # File does not exist
  550. except HfHubHTTPError:
  551. # Any authentication/authorization error will be caught here => default to cache
  552. return has_file_in_cache
  553. class PushToHubMixin:
  554. """
  555. A Mixin containing the functionality to push a model or tokenizer to the hub.
  556. """
  557. def _get_files_timestamps(self, working_dir: str | os.PathLike):
  558. """
  559. Returns the list of files with their last modification timestamp.
  560. """
  561. return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}
  562. def _upload_modified_files(
  563. self,
  564. working_dir: str | os.PathLike,
  565. repo_id: str,
  566. files_timestamps: dict[str, float],
  567. commit_message: str | None = None,
  568. token: bool | str | None = None,
  569. create_pr: bool = False,
  570. revision: str | None = None,
  571. commit_description: str | None = None,
  572. ):
  573. """
  574. Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
  575. """
  576. if commit_message is None:
  577. if "Model" in self.__class__.__name__:
  578. commit_message = "Upload model"
  579. elif "Config" in self.__class__.__name__:
  580. commit_message = "Upload config"
  581. elif "Tokenizer" in self.__class__.__name__:
  582. commit_message = "Upload tokenizer"
  583. elif "FeatureExtractor" in self.__class__.__name__:
  584. commit_message = "Upload feature extractor"
  585. elif "Processor" in self.__class__.__name__:
  586. commit_message = "Upload processor"
  587. else:
  588. commit_message = f"Upload {self.__class__.__name__}"
  589. modified_files = [
  590. f
  591. for f in os.listdir(working_dir)
  592. if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
  593. ]
  594. # filter for actual files + folders at the root level
  595. modified_files = [
  596. f
  597. for f in modified_files
  598. if os.path.isfile(os.path.join(working_dir, f)) or os.path.isdir(os.path.join(working_dir, f))
  599. ]
  600. operations = []
  601. # upload standalone files
  602. for file in modified_files:
  603. if os.path.isdir(os.path.join(working_dir, file)):
  604. # go over individual files of folder
  605. for f in os.listdir(os.path.join(working_dir, file)):
  606. operations.append(
  607. CommitOperationAdd(
  608. path_or_fileobj=os.path.join(working_dir, file, f), path_in_repo=os.path.join(file, f)
  609. )
  610. )
  611. else:
  612. operations.append(
  613. CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file)
  614. )
  615. if revision is not None and not revision.startswith("refs/pr"):
  616. try:
  617. create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)
  618. except HfHubHTTPError as e:
  619. if e.response.status_code == 403 and create_pr:
  620. # If we are creating a PR on a repo we don't have access to, we can't create the branch.
  621. # so let's assume the branch already exists. If it's not the case, an error will be raised when
  622. # calling `create_commit` below.
  623. pass
  624. else:
  625. raise
  626. logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
  627. return create_commit(
  628. repo_id=repo_id,
  629. operations=operations,
  630. commit_message=commit_message,
  631. commit_description=commit_description,
  632. token=token,
  633. create_pr=create_pr,
  634. revision=revision,
  635. )
  636. def save_pretrained(self, *args, **kwargs):
  637. # explicit contract
  638. raise NotImplementedError(f"{self.__class__.__name__} must implement `save_pretrained` to use `push_to_hub`.")
  639. def push_to_hub(
  640. self,
  641. repo_id: str,
  642. *,
  643. # Commit details
  644. commit_message: str | None = None,
  645. commit_description: str | None = None,
  646. # Repo / upload details
  647. private: bool | None = None,
  648. token: bool | str | None = None,
  649. revision: str | None = None,
  650. create_pr: bool = False,
  651. # Serialization details
  652. max_shard_size: int | str | None = "50GB",
  653. tags: list[str] | None = None,
  654. ) -> str:
  655. """
  656. Upload the {object_files} to the 🤗 Model Hub.
  657. Parameters:
  658. repo_id (`str`):
  659. The name of the repository you want to push your {object} to. It should contain your organization name
  660. when pushing to a given organization.
  661. commit_message (`str`, *optional*):
  662. Message to commit while pushing. Will default to `"Upload {object}"`.
  663. commit_description (`str`, *optional*):
  664. The description of the commit that will be created
  665. private (`bool`, *optional*):
  666. Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
  667. token (`bool` or `str`, *optional*):
  668. The token to use as HTTP bearer authorization for remote files. If `True` (default), will use the token generated
  669. when running `hf auth login` (stored in `~/.huggingface`).
  670. revision (`str`, *optional*):
  671. Branch to push the uploaded files to.
  672. create_pr (`bool`, *optional*, defaults to `False`):
  673. Whether or not to create a PR with the uploaded files or directly commit.
  674. max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
  675. Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
  676. will then be each of size lower than this size. If expressed as a string, needs to be digits followed
  677. by a unit (like `"5MB"`).
  678. tags (`list[str]`, *optional*):
  679. List of tags to push on the Hub.
  680. Examples:
  681. ```python
  682. from transformers import {object_class}
  683. {object} = {object_class}.from_pretrained("google-bert/bert-base-cased")
  684. # Push the {object} to your namespace with the name "my-finetuned-bert".
  685. {object}.push_to_hub("my-finetuned-bert")
  686. # Push the {object} to an organization with the name "my-finetuned-bert".
  687. {object}.push_to_hub("huggingface/my-finetuned-bert")
  688. ```
  689. """
  690. # Create repo if it doesn't exist yet
  691. repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
  692. # Load model card or create a new one + eventually tag it
  693. model_card = create_and_tag_model_card(repo_id, tags, token=token)
  694. with tempfile.TemporaryDirectory() as tmp_dir:
  695. # Save all files.
  696. self.save_pretrained(tmp_dir, max_shard_size=max_shard_size)
  697. # Update model card
  698. model_card.save(os.path.join(tmp_dir, "README.md"))
  699. # Upload
  700. return self._upload_modified_files(
  701. tmp_dir,
  702. repo_id,
  703. files_timestamps={},
  704. commit_message=commit_message,
  705. token=token,
  706. create_pr=create_pr,
  707. revision=revision,
  708. commit_description=commit_description,
  709. )
  710. def convert_file_size_to_int(size: int | str):
  711. """
  712. Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
  713. Args:
  714. size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
  715. Example:
  716. ```py
  717. >>> convert_file_size_to_int("1MiB")
  718. 1048576
  719. ```
  720. """
  721. if isinstance(size, int):
  722. return size
  723. if size.upper().endswith("GIB"):
  724. return int(size[:-3]) * (2**30)
  725. if size.upper().endswith("MIB"):
  726. return int(size[:-3]) * (2**20)
  727. if size.upper().endswith("KIB"):
  728. return int(size[:-3]) * (2**10)
  729. if size.upper().endswith("GB"):
  730. int_size = int(size[:-2]) * (10**9)
  731. return int_size // 8 if size.endswith("b") else int_size
  732. if size.upper().endswith("MB"):
  733. int_size = int(size[:-2]) * (10**6)
  734. return int_size // 8 if size.endswith("b") else int_size
  735. if size.upper().endswith("KB"):
  736. int_size = int(size[:-2]) * (10**3)
  737. return int_size // 8 if size.endswith("b") else int_size
  738. raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
  739. def get_checkpoint_shard_files(
  740. pretrained_model_name_or_path,
  741. index_filename,
  742. cache_dir=None,
  743. force_download=False,
  744. proxies=None,
  745. local_files_only=False,
  746. token=None,
  747. user_agent=None,
  748. revision=None,
  749. subfolder="",
  750. _commit_hash=None,
  751. tqdm_class=None,
  752. **deprecated_kwargs,
  753. ):
  754. """
  755. For a given model:
  756. - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
  757. Hub
  758. - returns the list of paths to all the shards, as well as some metadata.
  759. For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
  760. index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
  761. """
  762. if not os.path.isfile(index_filename):
  763. raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
  764. with open(index_filename) as f:
  765. index = json.loads(f.read())
  766. shard_filenames = sorted(set(index["weight_map"].values()))
  767. sharded_metadata = index["metadata"]
  768. sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
  769. sharded_metadata["weight_map"] = index["weight_map"].copy()
  770. # First, let's deal with local folder.
  771. if os.path.isdir(pretrained_model_name_or_path):
  772. shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
  773. return shard_filenames, sharded_metadata
  774. # At this stage pretrained_model_name_or_path is a model identifier on the Hub. Try to get everything from cache,
  775. # or download the files
  776. cached_filenames = cached_files(
  777. pretrained_model_name_or_path,
  778. shard_filenames,
  779. cache_dir=cache_dir,
  780. force_download=force_download,
  781. proxies=proxies,
  782. local_files_only=local_files_only,
  783. token=token,
  784. user_agent=user_agent,
  785. revision=revision,
  786. subfolder=subfolder,
  787. _commit_hash=_commit_hash,
  788. tqdm_class=tqdm_class,
  789. )
  790. return cached_filenames, sharded_metadata
  791. def create_and_tag_model_card(repo_id: str, tags: list[str] | None = None, token: str | None = None) -> ModelCard:
  792. """
  793. Creates or loads an existing model card and tags it.
  794. Args:
  795. repo_id (`str`):
  796. The repo_id where to look for the model card.
  797. tags (`list[str]`, *optional*):
  798. The list of tags to add in the model card
  799. token (`str`, *optional*):
  800. Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
  801. """
  802. try:
  803. # Check if the model card is present on the remote repo
  804. model_card = ModelCard.load(repo_id, token=token)
  805. except EntryNotFoundError:
  806. # Otherwise create a simple model card from template
  807. model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
  808. card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
  809. model_card = ModelCard.from_template(card_data, model_description=model_description)
  810. if tags is not None:
  811. # Ensure model_card.data.tags is a list and not None
  812. if model_card.data.tags is None:
  813. model_card.data.tags = []
  814. for model_tag in tags:
  815. if model_tag not in model_card.data.tags:
  816. model_card.data.tags.append(model_tag)
  817. return model_card
  818. class PushInProgress:
  819. """
  820. Internal class to keep track of a push in progress (which might contain multiple `Future` jobs).
  821. """
  822. def __init__(self, jobs: futures.Future | None = None) -> None:
  823. self.jobs = [] if jobs is None else jobs
  824. def is_done(self):
  825. return all(job.done() for job in self.jobs)
  826. def wait_until_done(self):
  827. futures.wait(self.jobs)
  828. def cancel(self) -> None:
  829. self.jobs = [
  830. job
  831. for job in self.jobs
  832. # Cancel the job if it wasn't started yet and remove cancelled/done jobs from the list
  833. if not (job.cancel() or job.done())
  834. ]