files.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. """W&B Public API for File objects.
  2. This module provides classes for interacting with files stored in W&B.
  3. Example:
  4. ```python
  5. from wandb.apis.public import Api
  6. # Get files from a specific run
  7. run = Api().run("entity/project/run_id")
  8. files = run.files()
  9. # Work with files
  10. for file in files:
  11. print(f"File: {file.name}")
  12. print(f"Size: {file.size} bytes")
  13. print(f"Type: {file.mimetype}")
  14. # Download file
  15. if file.size < 1000000: # Less than 1MB
  16. file.download(root="./downloads")
  17. # Get S3 URI for large files
  18. if file.size >= 1000000:
  19. print(f"S3 URI: {file.path_uri}")
  20. ```
  21. Note:
  22. This module is part of the W&B Public API and provides methods to access,
  23. download, and manage files stored in W&B. Files are typically associated
  24. with specific runs and can include model weights, datasets, visualizations,
  25. and other artifacts.
  26. """
  27. from __future__ import annotations
  28. import io
  29. import os
  30. from typing import TYPE_CHECKING, Any, Callable
  31. from wandb_gql import gql
  32. from wandb_gql.client import RetryError
  33. import wandb
  34. from wandb._strutils import nameof
  35. from wandb.apis.attrs import Attrs
  36. from wandb.apis.normalize import normalize_exceptions
  37. from wandb.apis.paginator import SizedPaginator
  38. from wandb.apis.public import utils
  39. from wandb.apis.public.const import RETRY_TIMEDELTA
  40. from wandb.apis.public.runs import Run
  41. from wandb.sdk.lib import retry
  42. from wandb.util import POW_2_BYTES, download_file_from_url, no_retry_auth, to_human_size
  43. if TYPE_CHECKING:
  44. from wandb_graphql.language.ast import Document
  45. from wandb.apis.public import Api, RetryingClient
# GraphQL fragment selecting one page of a run's files plus paging metadata.
# It is interpolated into the `RunFiles` query built by `Files._get_query`,
# which must declare the variables referenced here: $fileNames, $fileCursor,
# $fileLimit, $pattern and $upload.
FILE_FRAGMENT = """fragment RunFilesFragment on Run {
    files(names: $fileNames, after: $fileCursor, first: $fileLimit, pattern: $pattern) {
        edges {
            node {
                id
                name
                url(upload: $upload)
                directUrl
                sizeBytes
                mimetype
                updatedAt
                md5
            }
            cursor
        }
        pageInfo {
            endCursor
            hasNextPage
        }
    }
}"""
  67. class Files(SizedPaginator["File"]):
  68. """A lazy iterator over a collection of `File` objects.
  69. Access and manage files uploaded to W&B during a run. Handles pagination
  70. automatically when iterating through large collections of files.
  71. Example:
  72. ```python
  73. from wandb.apis.public.files import Files
  74. from wandb.apis.public.api import Api
  75. # Example run object
  76. run = Api().run("entity/project/run-id")
  77. # Create a Files object to iterate over files in the run
  78. files = Files(api.client, run)
  79. # Iterate over files
  80. for file in files:
  81. print(file.name)
  82. print(file.url)
  83. print(file.size)
  84. # Download the file
  85. file.download(root="download_directory", replace=True)
  86. ```
  87. """
  88. def _get_query(self) -> Document:
  89. """Generate query dynamically based on server capabilities."""
  90. return gql(
  91. f"""
  92. query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String,
  93. $fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false, $pattern: String) {{
  94. project(name: $project, entityName: $entity) {{
  95. internalId
  96. run(name: $name) {{
  97. fileCount
  98. ...RunFilesFragment
  99. }}
  100. }}
  101. }}
  102. {FILE_FRAGMENT}
  103. """
  104. )
  105. def __init__(
  106. self,
  107. client: RetryingClient,
  108. run: Run,
  109. names: list[str] | None = None,
  110. per_page: int = 50,
  111. upload: bool = False,
  112. pattern: str | None = None,
  113. ):
  114. """Initialize a lazy iterator over a collection of `File` objects.
  115. Files are retrieved in pages from the W&B server as needed.
  116. Args:
  117. client: The run object that contains the files
  118. run: The run object that contains the files
  119. names (list, optional): A list of file names to filter the files
  120. per_page (int, optional): The number of files to fetch per page
  121. upload (bool, optional): If `True`, fetch the upload URL for each file
  122. pattern (str, optional): Pattern to match when returning files from W&B
  123. This pattern uses mySQL's LIKE syntax,
  124. so matching all files that end with .json would be "%.json".
  125. If both names and pattern are provided, a ValueError will be raised.
  126. """
  127. if names and pattern:
  128. raise ValueError(
  129. "Querying for files by both names and pattern is not supported."
  130. " Please provide either a list of names or a pattern to match.",
  131. )
  132. self.run = run
  133. variables = {
  134. "project": run.project,
  135. "entity": run.entity,
  136. "name": run.id,
  137. "fileNames": names or [],
  138. "upload": upload,
  139. "pattern": pattern,
  140. }
  141. super().__init__(client, variables, per_page)
  142. def _update_response(self) -> None:
  143. """Fetch and store the response data for the next page using dynamic query."""
  144. self.last_response = self.client.execute(
  145. self._get_query(), variable_values=self.variables
  146. )
  147. @property
  148. def _length(self) -> int:
  149. """
  150. Returns total number of files.
  151. <!-- lazydoc-ignore: internal -->
  152. """
  153. if not self.last_response:
  154. self._load_page()
  155. return self.last_response["project"]["run"]["fileCount"]
  156. @property
  157. def more(self) -> bool:
  158. """Returns whether there are more files to fetch.
  159. <!-- lazydoc-ignore: internal -->
  160. """
  161. if self.last_response:
  162. return self.last_response["project"]["run"]["files"]["pageInfo"][
  163. "hasNextPage"
  164. ]
  165. else:
  166. return True
  167. @property
  168. def cursor(self) -> str | None:
  169. """Returns the cursor position for pagination of file results.
  170. <!-- lazydoc-ignore: internal -->
  171. """
  172. if self.last_response:
  173. return self.last_response["project"]["run"]["files"]["edges"][-1]["cursor"]
  174. else:
  175. return None
  176. def update_variables(self) -> None:
  177. """Updates the GraphQL query variables for pagination.
  178. <!-- lazydoc-ignore: internal -->
  179. """
  180. self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})
  181. def convert_objects(self) -> list[File]:
  182. """Converts GraphQL edges to File objects.
  183. <!-- lazydoc-ignore: internal -->
  184. """
  185. return [
  186. File(self.client, r["node"], self.run)
  187. for r in self.last_response["project"]["run"]["files"]["edges"]
  188. ]
  189. def __repr__(self) -> str:
  190. return f"<{nameof(type(self))} {'/'.join(self.run.path)} ({len(self)})>"
  191. class File(Attrs):
  192. """File saved to W&B.
  193. Represents a single file stored in W&B. Includes access to file metadata.
  194. Files are associated with a specific run and
  195. can include text files, model weights, datasets, visualizations, and other
  196. artifacts. You can download the file, delete the file, and access file
  197. properties.
  198. Specify one or more attributes in a dictionary to fine a specific
  199. file logged to a specific run. You can search using the following keys:
  200. - id (str): The ID of the run that contains the file
  201. - name (str): Name of the file
  202. - url (str): path to file
  203. - direct_url (str): path to file in the bucket
  204. - sizeBytes (int): size of file in bytes
  205. - md5 (str): md5 of file
  206. - mimetype (str): mimetype of file
  207. - updated_at (str): timestamp of last update
  208. - path_uri (str): path to file in the bucket, currently only available for S3 objects and reference files
  209. Args:
  210. client: The run object that contains the file
  211. attrs (dict): A dictionary of attributes that define the file
  212. run: The run object that contains the file
  213. <!-- lazydoc-ignore-init: internal -->
  214. """
  215. def __init__(
  216. self,
  217. client: RetryingClient,
  218. attrs: dict[str, Any],
  219. run: Run | None = None,
  220. ):
  221. self.client = client
  222. self._attrs = attrs
  223. self.run = run
  224. self._download_decorated: Callable[..., Any] | None = None
  225. super().__init__(dict(attrs))
  226. @property
  227. def size(self) -> int:
  228. """Returns the size of the file in bytes."""
  229. size_bytes = self._attrs["sizeBytes"]
  230. if size_bytes is not None:
  231. return int(size_bytes)
  232. return 0
  233. @property
  234. def path_uri(self) -> str:
  235. """Returns the URI path to the file in the storage bucket.
  236. Returns:
  237. str: The S3 URI (e.g., 's3://bucket/path/to/file') if the file is stored in S3,
  238. the direct URL if it's a reference file, or an empty string if unavailable.
  239. """
  240. if not (direct_url := self._attrs.get("directUrl")):
  241. wandb.termwarn("Unable to find direct_url of file")
  242. return ""
  243. # For reference files, both the directUrl and the url are just the path to the file in the bucket
  244. if direct_url == self._attrs.get("url"):
  245. return direct_url
  246. try:
  247. return utils.parse_s3_url_to_s3_uri(direct_url)
  248. except ValueError:
  249. wandb.termwarn("path_uri is only available for files stored in S3")
  250. return ""
  251. def _build_download_wrapper(self) -> Callable[..., io.TextIOWrapper]:
  252. import requests
  253. @retry.retriable(
  254. retry_timedelta=RETRY_TIMEDELTA,
  255. check_retry_fn=no_retry_auth,
  256. retryable_exceptions=(RetryError, requests.RequestException),
  257. )
  258. def _impl(
  259. root: str = ".",
  260. replace: bool = False,
  261. exist_ok: bool = False,
  262. api: Api | None = None,
  263. ) -> io.TextIOWrapper:
  264. if api is None:
  265. api = wandb.Api()
  266. path = os.path.join(root, self.name)
  267. if os.path.exists(path) and not replace:
  268. if exist_ok:
  269. return open(path)
  270. raise ValueError(
  271. "File already exists, pass replace=True to overwrite "
  272. "or exist_ok=True to leave it as is and don't error."
  273. )
  274. download_file_from_url(path, self.url, api.api_key)
  275. return open(path)
  276. return _impl
  277. @normalize_exceptions
  278. def download(
  279. self,
  280. root: str = ".",
  281. replace: bool = False,
  282. exist_ok: bool = False,
  283. api: Api | None = None,
  284. ) -> io.TextIOWrapper:
  285. """Downloads a file previously saved by a run from the wandb server.
  286. Args:
  287. root: Local directory to save the file. Defaults to the
  288. current working directory (".").
  289. replace: If `True`, download will overwrite a local file
  290. if it exists. Defaults to `False`.
  291. exist_ok: If `True`, will not raise ValueError if file already
  292. exists and will not re-download unless replace=True.
  293. Defaults to `False`.
  294. api: If specified, the `Api` instance used to download the file.
  295. Raises:
  296. `ValueError` if file already exists, `replace=False` and
  297. `exist_ok=False`.
  298. """
  299. if self._download_decorated is None:
  300. self._download_decorated = self._build_download_wrapper()
  301. return self._download_decorated(root, replace, exist_ok, api)
  302. @normalize_exceptions
  303. def delete(self) -> None:
  304. """Delete the file from the W&B server."""
  305. variable_values = {
  306. "files": [self.id],
  307. "projectId": self.run._project_internal_id,
  308. }
  309. mutation = gql("""
  310. mutation deleteFiles($files: [ID!]!, $projectId: Int) {
  311. deleteFiles(input: {
  312. files: $files
  313. projectId: $projectId
  314. }) {
  315. success
  316. }
  317. }
  318. """)
  319. self.client.execute(
  320. mutation,
  321. variable_values=variable_values,
  322. )
  323. def __repr__(self) -> str:
  324. classname = nameof(type(self))
  325. size = to_human_size(self.size, units=POW_2_BYTES)
  326. return f"<{classname} {self.name} ({self.mimetype}) {size}>"