# artifact_saver.py (9.4 KB)
  1. """Artifact saver."""
  2. from __future__ import annotations
  3. import concurrent.futures
  4. import json
  5. import os
  6. import tempfile
  7. from collections.abc import Awaitable, Sequence
  8. from typing import TYPE_CHECKING
  9. import wandb
  10. import wandb.filesync.step_prepare
  11. from wandb import util
  12. from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
  13. from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
  14. from wandb.sdk.lib.paths import URIStr
  15. if TYPE_CHECKING:
  16. from typing import Protocol
  17. from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
  18. from wandb.sdk.internal.file_pusher import FilePusher
  19. from wandb.sdk.internal.internal_api import Api as InternalApi
  20. from wandb.sdk.internal.progress import ProgressFn
  21. class SaveFn(Protocol):
  22. def __call__(
  23. self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
  24. ) -> bool:
  25. pass
  26. class SaveFnAsync(Protocol):
  27. def __call__(
  28. self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
  29. ) -> Awaitable[bool]:
  30. pass
  31. class ArtifactSaver:
  32. _server_artifact: dict | None # TODO better define this dict
  33. def __init__(
  34. self,
  35. api: InternalApi,
  36. digest: str,
  37. manifest_json: dict,
  38. file_pusher: FilePusher,
  39. is_user_created: bool = False,
  40. ) -> None:
  41. self._api = api
  42. self._file_pusher = file_pusher
  43. self._digest = digest
  44. self._manifest = ArtifactManifest.from_manifest_json(manifest_json)
  45. self._manifest.storage_policy._api = self._api
  46. self._is_user_created = is_user_created
  47. self._server_artifact = None
  48. def save(
  49. self,
  50. entity: str,
  51. project: str,
  52. type: str,
  53. name: str,
  54. client_id: str,
  55. sequence_client_id: str,
  56. distributed_id: str | None = None,
  57. finalize: bool = True,
  58. metadata: dict | None = None,
  59. ttl_duration_seconds: int | None = None,
  60. description: str | None = None,
  61. aliases: Sequence[str] | None = None,
  62. tags: Sequence[str] | None = None,
  63. use_after_commit: bool = False,
  64. incremental: bool = False,
  65. history_step: int | None = None,
  66. base_id: str | None = None,
  67. ) -> dict | None:
  68. return self._save_internal(
  69. entity,
  70. project,
  71. type,
  72. name,
  73. client_id,
  74. sequence_client_id,
  75. distributed_id,
  76. finalize,
  77. metadata,
  78. ttl_duration_seconds,
  79. description,
  80. aliases,
  81. tags,
  82. use_after_commit,
  83. incremental,
  84. history_step,
  85. base_id,
  86. )
  87. def _save_internal(
  88. self,
  89. entity: str,
  90. project: str,
  91. type: str,
  92. name: str,
  93. client_id: str,
  94. sequence_client_id: str,
  95. distributed_id: str | None = None,
  96. finalize: bool = True,
  97. metadata: dict | None = None,
  98. ttl_duration_seconds: int | None = None,
  99. description: str | None = None,
  100. aliases: Sequence[str] | None = None,
  101. tags: Sequence[str] | None = None,
  102. use_after_commit: bool = False,
  103. incremental: bool = False,
  104. history_step: int | None = None,
  105. base_id: str | None = None,
  106. ) -> dict | None:
  107. alias_specs = []
  108. for alias in aliases or []:
  109. alias_specs.append({"artifactCollectionName": name, "alias": alias})
  110. tag_specs = [{"tagName": tag} for tag in tags or []]
  111. """Returns the server artifact."""
  112. self._server_artifact, latest = self._api.create_artifact(
  113. type,
  114. name,
  115. self._digest,
  116. metadata=metadata,
  117. ttl_duration_seconds=ttl_duration_seconds,
  118. aliases=alias_specs,
  119. tags=tag_specs,
  120. description=description,
  121. is_user_created=self._is_user_created,
  122. distributed_id=distributed_id,
  123. client_id=client_id,
  124. sequence_client_id=sequence_client_id,
  125. history_step=history_step,
  126. )
  127. assert self._server_artifact is not None # mypy optionality unwrapper
  128. artifact_id = self._server_artifact["id"]
  129. if base_id is None and latest:
  130. base_id = latest["id"]
  131. if self._server_artifact["state"] == "COMMITTED":
  132. if use_after_commit:
  133. self._api.use_artifact(
  134. artifact_id,
  135. artifact_entity_name=entity,
  136. artifact_project_name=project,
  137. )
  138. return self._server_artifact
  139. if (
  140. self._server_artifact["state"] != "PENDING"
  141. # For old servers, see https://github.com/wandb/wandb/pull/6190
  142. and self._server_artifact["state"] != "DELETED"
  143. ):
  144. raise Exception(
  145. 'Unknown artifact state "{}"'.format(self._server_artifact["state"])
  146. )
  147. manifest_type = "FULL"
  148. manifest_filename = "wandb_manifest.json"
  149. if incremental:
  150. manifest_type = "INCREMENTAL"
  151. manifest_filename = "wandb_manifest.incremental.json"
  152. elif distributed_id:
  153. manifest_type = "PATCH"
  154. manifest_filename = "wandb_manifest.patch.json"
  155. artifact_manifest_id, _ = self._api.create_artifact_manifest(
  156. manifest_filename,
  157. "",
  158. artifact_id,
  159. base_artifact_id=base_id,
  160. include_upload=False,
  161. type=manifest_type,
  162. )
  163. step_prepare = wandb.filesync.step_prepare.StepPrepare(
  164. self._api, 0.1, 0.01, 1000
  165. ) # TODO: params
  166. step_prepare.start()
  167. # Upload Artifact "L1" files, the actual artifact contents
  168. self._file_pusher.store_manifest_files(
  169. self._manifest,
  170. artifact_id,
  171. lambda entry, progress_callback: self._manifest.storage_policy.store_file(
  172. artifact_id,
  173. artifact_manifest_id,
  174. entry,
  175. step_prepare,
  176. progress_callback=progress_callback,
  177. ),
  178. )
  179. def before_commit() -> None:
  180. self._resolve_client_id_manifest_references()
  181. with tempfile.NamedTemporaryFile("w+", suffix=".json", delete=False) as fp:
  182. path = os.path.abspath(fp.name)
  183. json.dump(self._manifest.to_manifest_json(), fp, indent=4)
  184. digest = md5_file_b64(path)
  185. if distributed_id or incremental:
  186. # If we're in the distributed flow, we want to update the
  187. # patch manifest we created with our finalized digest.
  188. _, resp = self._api.update_artifact_manifest(
  189. artifact_manifest_id,
  190. digest=digest,
  191. )
  192. else:
  193. # In the regular flow, we can recreate the full manifest with the
  194. # updated digest.
  195. #
  196. # NOTE: We do this for backwards compatibility with older backends
  197. # that don't support the 'updateArtifactManifest' API.
  198. _, resp = self._api.create_artifact_manifest(
  199. manifest_filename,
  200. digest,
  201. artifact_id,
  202. base_artifact_id=base_id,
  203. )
  204. # We're duplicating the file upload logic a little, which isn't great.
  205. upload_url = resp["uploadUrl"]
  206. upload_headers = resp["uploadHeaders"]
  207. extra_headers = {}
  208. for upload_header in upload_headers:
  209. key, val = upload_header.split(":", 1)
  210. extra_headers[key] = val
  211. with open(path, "rb") as fp2:
  212. self._api.upload_file_retry(
  213. upload_url,
  214. fp2,
  215. extra_headers=extra_headers,
  216. )
  217. commit_result: concurrent.futures.Future[None] = concurrent.futures.Future()
  218. # Queue the commit. It will only happen after all file uploads finish.
  219. self._file_pusher.commit_artifact(
  220. artifact_id,
  221. finalize=finalize,
  222. before_commit=before_commit,
  223. result_future=commit_result,
  224. )
  225. # Block until all artifact files are uploaded and the
  226. # artifact is committed.
  227. try:
  228. commit_result.result()
  229. finally:
  230. step_prepare.shutdown()
  231. if finalize and use_after_commit:
  232. self._api.use_artifact(
  233. artifact_id,
  234. artifact_entity_name=entity,
  235. artifact_project_name=project,
  236. )
  237. return self._server_artifact
  238. def _resolve_client_id_manifest_references(self) -> None:
  239. for entry_path in self._manifest.entries:
  240. entry = self._manifest.entries[entry_path]
  241. if entry.ref is not None and entry.ref.startswith("wandb-client-artifact:"):
  242. client_id = util.host_from_path(entry.ref)
  243. artifact_file_path = util.uri_from_path(entry.ref)
  244. artifact_id = self._api._resolve_client_id(client_id)
  245. if artifact_id is None:
  246. raise RuntimeError(f"Could not resolve client id {client_id}")
  247. entry.ref = URIStr(
  248. f"wandb-artifact://{b64_to_hex_id(B64MD5(artifact_id))}/{artifact_file_path}"
  249. )