constants.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. import os
  2. import re
  3. import typing
  4. from typing import Literal
  # Strings accepted as "true" for an environment variable (compared after upper-casing).
  ENV_VARS_TRUE_VALUES = {"TRUE", "YES", "ON", "1"}
  # The same set, extended with "AUTO".
  ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {"AUTO"}
  def _is_true(value: str | None) -> bool:
      """Return whether `value` is one of the accepted "true" strings, ignoring case.

      `None` (i.e. the variable is not set) is treated as false.
      """
      return value is not None and value.upper() in ENV_VARS_TRUE_VALUES
  12. def _as_int(value: str | None) -> int | None:
  13. if value is None:
  14. return None
  15. return int(value)
  # Constants for file downloads: well-known file and folder names in a repo
  PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
  TF2_WEIGHTS_NAME = "tf_model.h5"
  TF_WEIGHTS_NAME = "model.ckpt"
  FLAX_WEIGHTS_NAME = "flax_model.msgpack"
  CONFIG_NAME = "config.json"
  REPOCARD_NAME = "README.md"
  EVAL_RESULTS_FOLDER = ".eval_results"

  # Default timeouts
  DEFAULT_ETAG_TIMEOUT = 10
  DEFAULT_DOWNLOAD_TIMEOUT = 10
  DEFAULT_REQUEST_TIMEOUT = 10

  # Size limits, in bytes
  DOWNLOAD_CHUNK_SIZE = 10 * 1024**2  # 10 MiB
  MAX_HTTP_DOWNLOAD_SIZE = 50 * 1000**3  # 50 GB
  # Constants for serialization: file-name patterns with a `{suffix}` placeholder
  PYTORCH_WEIGHTS_FILE_PATTERN = "pytorch_model{suffix}.bin"  # Unsafe pickle: use safetensors instead
  SAFETENSORS_WEIGHTS_FILE_PATTERN = "model{suffix}.safetensors"
  TF2_WEIGHTS_FILE_PATTERN = "tf_model{suffix}.h5"

  # Constants for safetensors repos
  SAFETENSORS_SINGLE_FILE = "model.safetensors"
  SAFETENSORS_INDEX_FILE = "model.safetensors.index.json"
  SAFETENSORS_MAX_HEADER_LENGTH = 25 * 1000 * 1000

  # Timeout of acquiring a file lock and logging the attempt
  FILELOCK_LOG_EVERY_SECONDS = 10
  # Git-related constants
  DEFAULT_REVISION = "main"
  # Matches a run of 5 to 40 hexadecimal characters (the pattern itself is not anchored).
  REGEX_COMMIT_OID = re.compile(pattern=r"[A-Fa-f0-9]{5,40}")

  HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/"
  # Staging mode is switched on through the `HUGGINGFACE_CO_STAGING` env variable.
  _staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING"))

  _HF_DEFAULT_ENDPOINT = "https://huggingface.co"
  _HF_DEFAULT_STAGING_ENDPOINT = "https://hub-ci.huggingface.co"

  if _staging_mode:
      # Staging always targets the staging Hub, whatever `HF_ENDPOINT` says.
      ENDPOINT = _HF_DEFAULT_STAGING_ENDPOINT
  else:
      # `HF_ENDPOINT` overrides the default; trailing slashes are removed.
      ENDPOINT = os.getenv("HF_ENDPOINT", _HF_DEFAULT_ENDPOINT).rstrip("/")
  HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
  DATASETS_SERVER_ENDPOINT = "https://datasets-server.huggingface.co"

  # HTTP header names
  HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit"
  HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag"
  HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size"
  HUGGINGFACE_HEADER_X_BILL_TO = "X-HF-Bill-To"

  # Overridable through the `HF_INFERENCE_ENDPOINT` env variable
  INFERENCE_ENDPOINT = os.getenv("HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co")

  # See https://huggingface.co/docs/inference-endpoints/index
  INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2"
  INFERENCE_CATALOG_ENDPOINT = "https://endpoints.huggingface.co/api/catalog"
  # Image keys for Inference Endpoints.
  # See https://api.endpoints.huggingface.cloud/#post-/v2/endpoint/-namespace-
  INFERENCE_ENDPOINT_IMAGE_KEYS = ["custom", "huggingface", "huggingfaceNeuron", "llamacpp", "tei", "tgi", "tgiNeuron"]
  # Proxy for third-party providers
  INFERENCE_PROXY_TEMPLATE = "https://router.huggingface.co/{provider}"

  # This substring is not allowed in repo_ids on hf.co
  # and is the canonical one we use for serialization of repo ids elsewhere.
  REPO_ID_SEPARATOR = "--"

  REPO_TYPE_DATASET = "dataset"
  REPO_TYPE_SPACE = "space"
  REPO_TYPE_MODEL = "model"
  REPO_TYPE_KERNEL = "kernel"

  # `None` is listed as an accepted repo type; "kernel" only appears in the extended list.
  REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
  REPO_TYPES_WITH_KERNEL = [*REPO_TYPES, REPO_TYPE_KERNEL]

  SPACES_SDK_TYPES = ["gradio", "streamlit", "docker", "static"]

  # Repo type -> URL prefix (no entry for models).
  REPO_TYPES_URL_PREFIXES = {REPO_TYPE_DATASET: "datasets/", REPO_TYPE_SPACE: "spaces/", REPO_TYPE_KERNEL: "kernels/"}

  # Plural name -> repo type.
  REPO_TYPES_MAPPING = dict(
      datasets=REPO_TYPE_DATASET,
      spaces=REPO_TYPE_SPACE,
      models=REPO_TYPE_MODEL,
      kernels=REPO_TYPE_KERNEL,
  )
  # Allowed values when filtering discussions by type; the tuple is derived from the Literal
  # so the two cannot drift apart.
  DiscussionTypeFilter = Literal["all", "discussion", "pull_request"]
  DISCUSSION_TYPES: tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter)

  # Allowed values when filtering discussions by status.
  # The annotation uses the *status* Literal (it previously reused `DiscussionTypeFilter` by mistake).
  DiscussionStatusFilter = Literal["all", "open", "closed"]
  DISCUSSION_STATUS: tuple[DiscussionStatusFilter, ...] = typing.get_args(DiscussionStatusFilter)

  # Webhook subscription types
  WEBHOOK_DOMAIN_T = Literal["repo", "discussions"]
  # Default cache locations
  default_home = os.path.join(os.path.expanduser("~"), ".cache")

  # `HF_HOME` wins when set; otherwise "<XDG_CACHE_HOME or default_home>/huggingface" is used.
  # `~` and environment variables inside the resulting path are expanded.
  HF_HOME = os.path.expandvars(
      os.path.expanduser(
          os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", default_home), "huggingface"))
      )
  )

  default_cache_path = os.path.join(HF_HOME, "hub")
  default_assets_cache_path = os.path.join(HF_HOME, "assets")

  # Legacy env variables (taken as-is, without expansion)
  HUGGINGFACE_HUB_CACHE = os.environ.get("HUGGINGFACE_HUB_CACHE", default_cache_path)
  HUGGINGFACE_ASSETS_CACHE = os.environ.get("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path)

  # New env variables: they take precedence over the legacy ones, which act as fallbacks
  HF_HUB_CACHE = os.path.expandvars(
      os.path.expanduser(os.environ.get("HF_HUB_CACHE", HUGGINGFACE_HUB_CACHE))
  )
  HF_ASSETS_CACHE = os.path.expandvars(
      os.path.expanduser(os.environ.get("HF_ASSETS_CACHE", HUGGINGFACE_ASSETS_CACHE))
  )
  # `HF_HUB_OFFLINE` is read first; `TRANSFORMERS_OFFLINE` is used when it is unset or empty.
  HF_HUB_OFFLINE = _is_true(os.getenv("HF_HUB_OFFLINE") or os.getenv("TRANSFORMERS_OFFLINE"))
  def is_offline_mode() -> bool:
      """Tell whether offline mode is enabled for the Hub.

      In offline mode, every HTTP request made with `get_session` raises an `OfflineModeIsEnabled` exception.

      Returns:
          The value of the module-level `HF_HUB_OFFLINE` flag.

      Example:
          ```py
          from huggingface_hub import is_offline_mode

          def list_files(repo_id: str):
              if is_offline_mode():
                  ...  # list files from local cache (degraded experience but still functional)
              else:
                  ...  # list files from Hub (complete experience)
          ```
      """
      return HF_HUB_OFFLINE
  # Marker file recording that the version check has been done.
  # The check is performed at most once per 24 hours.
  CHECK_FOR_UPDATE_DONE_PATH = os.path.join(HF_HOME, ".check_for_update_done")

  # If set, log level will be set to DEBUG and all requests made to the Hub will be logged
  # as curl commands for reproducibility.
  HF_DEBUG = _is_true(os.getenv("HF_DEBUG"))

  # Opt-out from telemetry requests: any one of these variables being true is enough.
  HF_HUB_DISABLE_TELEMETRY = any(
      _is_true(os.getenv(env_name))
      for env_name in (
          "HF_HUB_DISABLE_TELEMETRY",  # HF-specific env variable
          "DISABLE_TELEMETRY",
          "DO_NOT_TRACK",  # https://consoledonottrack.com/
      )
  )
  # Token file location: `HF_TOKEN_PATH` when set, otherwise "<HF_HOME>/token".
  # `~` and environment variables in the path are expanded.
  HF_TOKEN_PATH = os.path.expandvars(
      os.path.expanduser(os.environ.get("HF_TOKEN_PATH", os.path.join(HF_HOME, "token")))
  )
  # Stored next to the token file, in the same directory.
  HF_STORED_TOKENS_PATH = os.path.join(os.path.dirname(HF_TOKEN_PATH), "stored_tokens")

  if _staging_mode:
      # In staging mode, we use a different cache to ensure we don't mix up production and staging data or tokens.
      # In practice in `huggingface_hub` tests, we monkeypatch these values with temporary directories. The following
      # lines are only used in third-party libraries tests (e.g. `transformers`, `diffusers`, etc.).
      # NOTE(review): `HF_HUB_CACHE` and `HF_STORED_TOKENS_PATH` were computed before this point and are not
      # recomputed here, so they keep their non-staging values -- confirm this is intended.
      _staging_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface_staging")
      HUGGINGFACE_HUB_CACHE = os.path.join(_staging_home, "hub")
      HF_TOKEN_PATH = os.path.join(_staging_home, "token")
  # Tri-state flag for progress bars:
  #   - `True`: progress bars are disabled globally and cannot be re-enabled programmatically.
  #   - `False`: progress bars are enabled and cannot be disabled programmatically.
  #   - `None` (environment variable not set): the user is free to enable/disable them programmatically.
  # TL;DR: env variable has priority over code
  __HF_HUB_DISABLE_PROGRESS_BARS = os.getenv("HF_HUB_DISABLE_PROGRESS_BARS")
  HF_HUB_DISABLE_PROGRESS_BARS: bool | None = (
      None if __HF_HUB_DISABLE_PROGRESS_BARS is None else _is_true(__HF_HUB_DISABLE_PROGRESS_BARS)
  )

  # Disable symlinks in the cache (files are copied instead of symlinked)
  HF_HUB_DISABLE_SYMLINKS: bool = _is_true(os.getenv("HF_HUB_DISABLE_SYMLINKS"))

  # Disable warning on machines that do not support symlinks (e.g. Windows non-developer)
  HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.getenv("HF_HUB_DISABLE_SYMLINKS_WARNING"))

  # Disable warning when using experimental features
  HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.getenv("HF_HUB_DISABLE_EXPERIMENTAL_WARNING"))

  # Disable sending the cached token by default in all HTTP requests to the Hub
  HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.getenv("HF_HUB_DISABLE_IMPLICIT_TOKEN"))

  HF_XET_HIGH_PERFORMANCE: bool = _is_true(os.getenv("HF_XET_HIGH_PERFORMANCE"))
  # hf_transfer is not used anymore. Warn users in case they still set the env variable,
  # unless `HF_XET_HIGH_PERFORMANCE` is already enabled.
  if not HF_XET_HIGH_PERFORMANCE and _is_true(os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")):
      import warnings

      warnings.warn(
          "The `HF_HUB_ENABLE_HF_TRANSFER` environment variable is deprecated as 'hf_transfer' is not used anymore. "
          "Please use `HF_XET_HIGH_PERFORMANCE` instead to enable high performance transfer with Xet. "
          "Visit https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfxethighperformance for more details.",
          category=DeprecationWarning,
      )
  # Used to override the etag timeout on a system level.
  # An unset variable or a value of 0 falls back to `DEFAULT_ETAG_TIMEOUT`.
  HF_HUB_ETAG_TIMEOUT: int = _as_int(os.getenv("HF_HUB_ETAG_TIMEOUT")) or DEFAULT_ETAG_TIMEOUT

  # Used to override the get request timeout on a system level.
  # Also used as a default timeout for other requests if not specified (kept the naming for legacy reasons).
  # An unset variable or a value of 0 falls back to `DEFAULT_DOWNLOAD_TIMEOUT`.
  HF_HUB_DOWNLOAD_TIMEOUT: int = _as_int(os.getenv("HF_HUB_DOWNLOAD_TIMEOUT")) or DEFAULT_DOWNLOAD_TIMEOUT

  # Allows adding information about the requester in the user-agent (e.g. partner name)
  HF_HUB_USER_AGENT_ORIGIN: str | None = os.getenv("HF_HUB_USER_AGENT_ORIGIN")
  # If OAuth didn't work after 2 redirects, there's likely a third-party cookie issue in the Space iframe view.
  # In this case, we redirect the user to the non-iframe view.
  OAUTH_MAX_REDIRECTS = 2

  # OAuth-related environment variables injected by the Space (each is `None` when not set)
  OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
  OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET")
  OAUTH_SCOPES = os.getenv("OAUTH_SCOPES")
  OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL")
  # Xet constants: HTTP header names and link key
  HUGGINGFACE_HEADER_X_XET_ENDPOINT = "X-Xet-Cas-Url"
  HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN = "X-Xet-Access-Token"
  HUGGINGFACE_HEADER_X_XET_EXPIRATION = "X-Xet-Token-Expiration"
  HUGGINGFACE_HEADER_X_XET_HASH = "X-Xet-Hash"
  HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE = "X-Xet-Refresh-Route"
  HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY = "xet-auth"

  # Xet cache location: `HF_XET_CACHE` when set, otherwise "<HF_HOME>/xet".
  # Unlike the hub/assets cache paths, this value is used as-is (no `~` or env variable expansion).
  default_xet_cache_path = os.path.join(HF_HOME, "xet")
  HF_XET_CACHE = os.environ.get("HF_XET_CACHE", default_xet_cache_path)

  HF_HUB_DISABLE_XET: bool = _is_true(os.getenv("HF_HUB_DISABLE_XET"))