_xet.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import time
  2. from dataclasses import dataclass
  3. from enum import Enum
  4. import httpx
  5. from .. import constants
  6. from . import hf_raise_for_status, http_backoff, validate_hf_hub_args
  7. XET_CONNECTION_INFO_SAFETY_PERIOD = 60 # seconds
  8. XET_CONNECTION_INFO_CACHE_SIZE = 1_000
  9. XET_CONNECTION_INFO_CACHE: dict[str, "XetConnectionInfo"] = {}
  10. class XetTokenType(str, Enum):
  11. READ = "read"
  12. WRITE = "write"
  13. @dataclass(frozen=True)
  14. class XetFileData:
  15. file_hash: str
  16. refresh_route: str
  17. @dataclass(frozen=True)
  18. class XetConnectionInfo:
  19. access_token: str
  20. expiration_unix_epoch: int
  21. endpoint: str
  22. def parse_xet_file_data_from_response(response: httpx.Response, endpoint: str | None = None) -> XetFileData | None:
  23. """
  24. Parse XET file metadata from an HTTP response.
  25. This function extracts XET file metadata from the HTTP headers or HTTP links
  26. of a given response object. If the required metadata is not found, it returns `None`.
  27. Args:
  28. response (`httpx.Response`):
  29. The HTTP response object containing headers dict and links dict to extract the XET metadata from.
  30. Returns:
  31. `Optional[XetFileData]`:
  32. An instance of `XetFileData` containing the file hash and refresh route if the metadata
  33. is found. Returns `None` if the required metadata is missing.
  34. """
  35. if response is None:
  36. return None
  37. try:
  38. file_hash = response.headers[constants.HUGGINGFACE_HEADER_X_XET_HASH]
  39. if constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY in response.links:
  40. refresh_route = response.links[constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY]["url"]
  41. else:
  42. refresh_route = response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
  43. except KeyError:
  44. return None
  45. endpoint = endpoint if endpoint is not None else constants.ENDPOINT
  46. if refresh_route.startswith(constants.HUGGINGFACE_CO_URL_HOME):
  47. refresh_route = refresh_route.replace(constants.HUGGINGFACE_CO_URL_HOME.rstrip("/"), endpoint.rstrip("/"))
  48. return XetFileData(
  49. file_hash=file_hash,
  50. refresh_route=refresh_route,
  51. )
  52. def parse_xet_connection_info_from_headers(headers: dict[str, str]) -> XetConnectionInfo | None:
  53. """
  54. Parse XET connection info from the HTTP headers or return None if not found.
  55. Args:
  56. headers (`dict`):
  57. HTTP headers to extract the XET metadata from.
  58. Returns:
  59. `XetConnectionInfo` or `None`:
  60. The information needed to connect to the XET storage service.
  61. Returns `None` if the headers do not contain the XET connection info.
  62. """
  63. try:
  64. endpoint = headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT]
  65. access_token = headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN]
  66. expiration_unix_epoch = int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION])
  67. except (KeyError, ValueError, TypeError):
  68. return None
  69. return XetConnectionInfo(
  70. endpoint=endpoint,
  71. access_token=access_token,
  72. expiration_unix_epoch=expiration_unix_epoch,
  73. )
  74. @validate_hf_hub_args
  75. def refresh_xet_connection_info(
  76. *,
  77. file_data: XetFileData,
  78. headers: dict[str, str],
  79. ) -> XetConnectionInfo:
  80. """
  81. Utilizes the information in the parsed metadata to request the Hub xet connection information.
  82. This includes the access token, expiration, and XET service URL.
  83. Args:
  84. file_data: (`XetFileData`):
  85. The file data needed to refresh the xet connection information.
  86. headers (`dict[str, str]`):
  87. Headers to use for the request, including authorization headers and user agent.
  88. Returns:
  89. `XetConnectionInfo`:
  90. The connection information needed to make the request to the xet storage service.
  91. Raises:
  92. [`~utils.HfHubHTTPError`]
  93. If the Hub API returned an error.
  94. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  95. If the Hub API response is improperly formatted.
  96. """
  97. if file_data.refresh_route is None:
  98. raise ValueError("The provided xet metadata does not contain a refresh endpoint.")
  99. return _fetch_xet_connection_info_with_url(file_data.refresh_route, headers)
  100. @validate_hf_hub_args
  101. def fetch_xet_connection_info_from_repo_info(
  102. *,
  103. token_type: XetTokenType,
  104. repo_id: str,
  105. repo_type: str,
  106. revision: str | None = None,
  107. headers: dict[str, str],
  108. endpoint: str | None = None,
  109. params: dict[str, str] | None = None,
  110. ) -> XetConnectionInfo:
  111. """
  112. Uses the repo info to request a xet access token from Hub.
  113. Args:
  114. token_type (`XetTokenType`):
  115. Type of the token to request: `"read"` or `"write"`.
  116. repo_id (`str`):
  117. A namespace (user or an organization) and a repo name separated by a `/`.
  118. repo_type (`str`):
  119. Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
  120. revision (`str`, `optional`):
  121. The revision of the repo to get the token for.
  122. headers (`dict[str, str]`):
  123. Headers to use for the request, including authorization headers and user agent.
  124. endpoint (`str`, `optional`):
  125. The endpoint to use for the request. Defaults to the Hub endpoint.
  126. params (`dict[str, str]`, `optional`):
  127. Additional parameters to pass with the request.
  128. Returns:
  129. `XetConnectionInfo`:
  130. The connection information needed to make the request to the xet storage service.
  131. Raises:
  132. [`~utils.HfHubHTTPError`]
  133. If the Hub API returned an error.
  134. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  135. If the Hub API response is improperly formatted.
  136. """
  137. endpoint = endpoint if endpoint is not None else constants.ENDPOINT
  138. url = f"{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type.value}-token"
  139. if repo_type != "bucket" or revision is not None:
  140. # On "bucket" repo type, the revision never needed => don't use it
  141. # Otherwise, use the revision.
  142. # Note: when creating a PR on a git-based repo, user needs write access but they don't know the revision in advance.
  143. # => pass "/None" in URL and server will return a token for PR refs.
  144. url += f"/{revision}"
  145. return _fetch_xet_connection_info_with_url(url, headers, params, cache_key_prefix=f"{repo_type}-{repo_id}")
  146. @validate_hf_hub_args
  147. def _fetch_xet_connection_info_with_url(
  148. url: str,
  149. headers: dict[str, str],
  150. params: dict[str, str] | None = None,
  151. cache_key_prefix: str | None = None,
  152. ) -> XetConnectionInfo:
  153. """
  154. Requests the xet connection info from the supplied URL. This includes the
  155. access token, expiration time, and endpoint to use for the xet storage service.
  156. Result is cached to avoid redundant requests.
  157. Args:
  158. url: (`str`):
  159. The access token endpoint URL.
  160. headers (`dict[str, str]`):
  161. Headers to use for the request, including authorization headers and user agent.
  162. params (`dict[str, str]`, `optional`):
  163. Additional parameters to pass with the request.
  164. Returns:
  165. `XetConnectionInfo`:
  166. The connection information needed to make the request to the xet storage service.
  167. Raises:
  168. [`~utils.HfHubHTTPError`]
  169. If the Hub API returned an error.
  170. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  171. If the Hub API response is improperly formatted.
  172. """
  173. # Check cache first
  174. cache_key = _cache_key(url, headers, params, prefix=cache_key_prefix)
  175. cached_info = XET_CONNECTION_INFO_CACHE.get(cache_key)
  176. if cached_info is not None:
  177. if not _is_expired(cached_info):
  178. return cached_info
  179. # Fetch from server
  180. resp = http_backoff("GET", url, headers=headers, params=params)
  181. hf_raise_for_status(resp)
  182. metadata = parse_xet_connection_info_from_headers(resp.headers) # type: ignore
  183. if metadata is None:
  184. raise ValueError("Xet headers have not been correctly set by the server.")
  185. # Delete expired cache entries
  186. for k, v in list(XET_CONNECTION_INFO_CACHE.items()):
  187. if _is_expired(v):
  188. XET_CONNECTION_INFO_CACHE.pop(k, None)
  189. # Enforce cache size limit
  190. if len(XET_CONNECTION_INFO_CACHE) >= XET_CONNECTION_INFO_CACHE_SIZE:
  191. XET_CONNECTION_INFO_CACHE.pop(next(iter(XET_CONNECTION_INFO_CACHE)))
  192. # Update cache
  193. XET_CONNECTION_INFO_CACHE[cache_key] = metadata
  194. return metadata
  195. def reset_xet_connection_info_cache_for_repo(repo_type: str | None, repo_id: str) -> None:
  196. """Reset the XET connection info cache for the given repo type and repo id.
  197. Used when a repo is deleted.
  198. """
  199. if repo_type is None:
  200. repo_type = constants.REPO_TYPE_MODEL
  201. prefix = f"{repo_type}-{repo_id}|"
  202. for k in list(XET_CONNECTION_INFO_CACHE.keys()):
  203. if k.startswith(prefix):
  204. XET_CONNECTION_INFO_CACHE.pop(k, None)
  205. def _cache_key(url: str, headers: dict[str, str], params: dict[str, str] | None, prefix: str | None = None) -> str:
  206. """Return a unique cache key for the given request parameters."""
  207. lower_headers = {k.lower(): v for k, v in headers.items()} # casing is not guaranteed here
  208. auth_header = lower_headers.get("authorization", "")
  209. params_str = "&".join(f"{k}={v}" for k, v in sorted((params or {}).items(), key=lambda x: x[0]))
  210. return f"{prefix}|{url}|{auth_header}|{params_str}"
  211. def _is_expired(connection_info: XetConnectionInfo) -> bool:
  212. """Check if the given XET connection info is expired."""
  213. return connection_info.expiration_unix_epoch <= int(time.time()) + XET_CONNECTION_INFO_SAFETY_PERIOD