| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275 |
- """Artifact saver."""
- from __future__ import annotations
- import concurrent.futures
- import json
- import os
- import tempfile
- from collections.abc import Awaitable, Sequence
- from typing import TYPE_CHECKING
- import wandb
- import wandb.filesync.step_prepare
- from wandb import util
- from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
- from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
- from wandb.sdk.lib.paths import URIStr
if TYPE_CHECKING:
    # Everything in this branch exists for static type checkers only; none of
    # these names are bound when the module is actually executed.
    from typing import Protocol

    from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
    from wandb.sdk.internal.file_pusher import FilePusher
    from wandb.sdk.internal.internal_api import Api as InternalApi
    from wandb.sdk.internal.progress import ProgressFn

    class SaveFn(Protocol):
        """Callable taking a manifest entry and a progress callback, returning a bool."""

        def __call__(
            self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
        ) -> bool: ...

    class SaveFnAsync(Protocol):
        """Like ``SaveFn``, except that the call returns an awaitable bool."""

        def __call__(
            self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
        ) -> Awaitable[bool]: ...
class ArtifactSaver:
    """Create an artifact on the server, upload its contents, and commit it.

    The saver is built from a serialized manifest. ``save`` registers the
    artifact through the internal API, queues the manifest's files on the file
    pusher, uploads the manifest itself right before the commit, and blocks
    until the commit has finished.
    """

    _server_artifact: dict | None  # TODO better define this dict

    def __init__(
        self,
        api: InternalApi,
        digest: str,
        manifest_json: dict,
        file_pusher: FilePusher,
        is_user_created: bool = False,
    ) -> None:
        """Initialize the saver.

        Args:
            api: Internal API client used for every server call.
            digest: Artifact digest passed to ``create_artifact``.
            manifest_json: Serialized manifest, parsed with
                ``ArtifactManifest.from_manifest_json``.
            file_pusher: Pusher that uploads the files and runs the commit.
            is_user_created: Forwarded unchanged to ``create_artifact``.
        """
        self._api = api
        self._file_pusher = file_pusher
        self._digest = digest
        self._manifest = ArtifactManifest.from_manifest_json(manifest_json)
        # The storage policy performs its own API calls; point it at our client.
        self._manifest.storage_policy._api = self._api
        self._is_user_created = is_user_created
        self._server_artifact = None

    def save(
        self,
        entity: str,
        project: str,
        type: str,
        name: str,
        client_id: str,
        sequence_client_id: str,
        distributed_id: str | None = None,
        finalize: bool = True,
        metadata: dict | None = None,
        ttl_duration_seconds: int | None = None,
        description: str | None = None,
        aliases: Sequence[str] | None = None,
        tags: Sequence[str] | None = None,
        use_after_commit: bool = False,
        incremental: bool = False,
        history_step: int | None = None,
        base_id: str | None = None,
    ) -> dict | None:
        """Save the artifact and return the server artifact.

        Public entry point; every argument is handed to ``_save_internal``,
        which documents them.
        """
        return self._save_internal(
            entity,
            project,
            type,
            name,
            client_id,
            sequence_client_id,
            distributed_id=distributed_id,
            finalize=finalize,
            metadata=metadata,
            ttl_duration_seconds=ttl_duration_seconds,
            description=description,
            aliases=aliases,
            tags=tags,
            use_after_commit=use_after_commit,
            incremental=incremental,
            history_step=history_step,
            base_id=base_id,
        )

    def _save_internal(
        self,
        entity: str,
        project: str,
        type: str,
        name: str,
        client_id: str,
        sequence_client_id: str,
        distributed_id: str | None = None,
        finalize: bool = True,
        metadata: dict | None = None,
        ttl_duration_seconds: int | None = None,
        description: str | None = None,
        aliases: Sequence[str] | None = None,
        tags: Sequence[str] | None = None,
        use_after_commit: bool = False,
        incremental: bool = False,
        history_step: int | None = None,
        base_id: str | None = None,
    ) -> dict | None:
        """Returns the server artifact.

        Args:
            entity: Entity name passed to ``use_artifact``.
            project: Project name passed to ``use_artifact``.
            type: Artifact type passed to ``create_artifact``.
            name: Artifact name; also used as the collection name of every
                alias.
            client_id: Forwarded to ``create_artifact``.
            sequence_client_id: Forwarded to ``create_artifact``.
            distributed_id: Forwarded to ``create_artifact``. When set (and
                ``incremental`` is false) a ``PATCH`` manifest is created.
            finalize: Forwarded to ``commit_artifact``. ``use_artifact`` is
                only called after the commit when this is true.
            metadata: Forwarded to ``create_artifact``.
            ttl_duration_seconds: Forwarded to ``create_artifact``.
            description: Forwarded to ``create_artifact``.
            aliases: Aliases to attach to the artifact.
            tags: Tags to attach to the artifact.
            use_after_commit: Whether to call ``use_artifact`` once the
                artifact is committed.
            incremental: When true an ``INCREMENTAL`` manifest is created.
            history_step: Forwarded to ``create_artifact``.
            base_id: Base artifact id for the manifest. Defaults to the id of
                the latest artifact reported by ``create_artifact``, if any.

        Raises:
            Exception: If the server reports a state other than ``COMMITTED``,
                ``PENDING`` or ``DELETED``.
        """
        alias_specs = [
            {"artifactCollectionName": name, "alias": alias} for alias in aliases or []
        ]
        tag_specs = [{"tagName": tag} for tag in tags or []]

        self._server_artifact, latest = self._api.create_artifact(
            type,
            name,
            self._digest,
            metadata=metadata,
            ttl_duration_seconds=ttl_duration_seconds,
            aliases=alias_specs,
            tags=tag_specs,
            description=description,
            is_user_created=self._is_user_created,
            distributed_id=distributed_id,
            client_id=client_id,
            sequence_client_id=sequence_client_id,
            history_step=history_step,
        )

        assert self._server_artifact is not None  # mypy optionality unwrapper
        artifact_id = self._server_artifact["id"]
        if base_id is None and latest:
            base_id = latest["id"]

        state = self._server_artifact["state"]
        if state == "COMMITTED":
            # Already committed on the server: there is nothing to upload.
            if use_after_commit:
                self._use_artifact(artifact_id, entity, project)
            return self._server_artifact
        # "DELETED" is accepted for old servers, see
        # https://github.com/wandb/wandb/pull/6190
        if state not in ("PENDING", "DELETED"):
            raise Exception(f'Unknown artifact state "{state}"')

        manifest_type = "FULL"
        manifest_filename = "wandb_manifest.json"
        if incremental:
            manifest_type = "INCREMENTAL"
            manifest_filename = "wandb_manifest.incremental.json"
        elif distributed_id:
            manifest_type = "PATCH"
            manifest_filename = "wandb_manifest.patch.json"
        # Created with an empty digest and no upload; the real digest is only
        # known in `before_commit`, after the manifest has been finalized.
        artifact_manifest_id, _ = self._api.create_artifact_manifest(
            manifest_filename,
            "",
            artifact_id,
            base_artifact_id=base_id,
            include_upload=False,
            type=manifest_type,
        )

        step_prepare = wandb.filesync.step_prepare.StepPrepare(
            self._api, 0.1, 0.01, 1000
        )  # TODO: params
        step_prepare.start()

        # Upload Artifact "L1" files, the actual artifact contents
        self._file_pusher.store_manifest_files(
            self._manifest,
            artifact_id,
            lambda entry, progress_callback: self._manifest.storage_policy.store_file(
                artifact_id,
                artifact_manifest_id,
                entry,
                step_prepare,
                progress_callback=progress_callback,
            ),
        )

        def before_commit() -> None:
            """Serialize the final manifest, register its digest, and upload it."""
            self._resolve_client_id_manifest_references()
            path = None
            try:
                with tempfile.NamedTemporaryFile(
                    "w+", suffix=".json", delete=False
                ) as fp:
                    path = os.path.abspath(fp.name)
                    json.dump(self._manifest.to_manifest_json(), fp, indent=4)

                digest = md5_file_b64(path)
                if distributed_id or incremental:
                    # If we're in the distributed flow, we want to update the
                    # patch manifest we created with our finalized digest.
                    _, resp = self._api.update_artifact_manifest(
                        artifact_manifest_id,
                        digest=digest,
                    )
                else:
                    # In the regular flow, we can recreate the full manifest with
                    # the updated digest.
                    #
                    # NOTE: We do this for backwards compatibility with older
                    # backends that don't support the 'updateArtifactManifest' API.
                    _, resp = self._api.create_artifact_manifest(
                        manifest_filename,
                        digest,
                        artifact_id,
                        base_artifact_id=base_id,
                    )

                # We're duplicating the file upload logic a little, which isn't
                # great.
                upload_url = resp["uploadUrl"]
                extra_headers = {}
                for upload_header in resp["uploadHeaders"]:
                    key, val = upload_header.split(":", 1)
                    extra_headers[key] = val
                with open(path, "rb") as fp2:
                    self._api.upload_file_retry(
                        upload_url,
                        fp2,
                        extra_headers=extra_headers,
                    )
            finally:
                # The file is created with delete=False so it can be reopened
                # by path; remove it ourselves so it does not pile up on disk.
                if path is not None:
                    try:
                        os.remove(path)
                    except OSError:
                        # Best-effort cleanup; never mask the real outcome.
                        pass

        commit_result: concurrent.futures.Future[None] = concurrent.futures.Future()

        # Queue the commit. It will only happen after all file uploads finish.
        self._file_pusher.commit_artifact(
            artifact_id,
            finalize=finalize,
            before_commit=before_commit,
            result_future=commit_result,
        )

        # Block until all artifact files are uploaded and the
        # artifact is committed.
        try:
            commit_result.result()
        finally:
            step_prepare.shutdown()

        if finalize and use_after_commit:
            self._use_artifact(artifact_id, entity, project)

        return self._server_artifact

    def _use_artifact(self, artifact_id: str, entity: str, project: str) -> None:
        """Call ``use_artifact`` on the API for the given artifact."""
        self._api.use_artifact(
            artifact_id,
            artifact_entity_name=entity,
            artifact_project_name=project,
        )

    def _resolve_client_id_manifest_references(self) -> None:
        """Rewrite ``wandb-client-artifact:`` entry refs to ``wandb-artifact://`` refs.

        Entries without a ref, or whose ref uses another scheme, are left
        untouched.

        Raises:
            RuntimeError: If the API cannot resolve a referenced client id.
        """
        for entry in self._manifest.entries.values():
            if entry.ref is None or not entry.ref.startswith("wandb-client-artifact:"):
                continue
            client_id = util.host_from_path(entry.ref)
            artifact_file_path = util.uri_from_path(entry.ref)
            artifact_id = self._api._resolve_client_id(client_id)
            if artifact_id is None:
                raise RuntimeError(f"Could not resolve client id {client_id}")
            entry.ref = URIStr(
                f"wandb-artifact://{b64_to_hex_id(B64MD5(artifact_id))}/{artifact_file_path}"
            )
|