| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- """Artifact manifest."""
- from __future__ import annotations
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Annotated, Any, Dict # noqa: UP035
- from pydantic import Field
- from wandb.sdk.lib.hashutil import HexMD5
- from ._models.base_model import ArtifactsBase
- if TYPE_CHECKING:
- from .artifact_manifest_entry import ArtifactManifestEntry
- from .storage_policy import StoragePolicy
class ArtifactManifest(ArtifactsBase, ABC):
    """Abstract base for a versioned artifact manifest.

    A manifest maps each entry's path to its ``ArtifactManifestEntry``.
    ``version()`` reads the default declared for the ``manifest_version``
    field, so concrete subclasses are expected to give that field a default.
    They also implement serialization, digest, and size computation.
    """

    # Note: we can't name this "version" since it conflicts with the prior
    # `version()` classmethod.
    manifest_version: Annotated[Any, Field(repr=False)]
    # Entries keyed by path; keys use forward slashes even on Windows.
    entries: Dict[str, ArtifactManifestEntry] = Field(default_factory=dict)  # noqa: UP006
    # Excluded from model dumps and from repr.
    storage_policy: Annotated[StoragePolicy, Field(exclude=True, repr=False)]

    @classmethod
    def version(cls) -> int:
        """Return the manifest format version handled by this class.

        This is the default declared for the ``manifest_version`` field. On a
        class whose field has no default, it is pydantic's "undefined"
        sentinel rather than an int.
        """
        return cls.model_fields["manifest_version"].default

    @classmethod
    @abstractmethod
    def from_manifest_json(cls, manifest_json: dict[str, Any]) -> ArtifactManifest:
        """Build a manifest from its JSON (dict) representation.

        Subclasses override this to parse their own format. The base
        implementation dispatches on ``manifest_json["version"]`` to the
        direct subclass whose ``version()`` equals it.

        Args:
            manifest_json: The decoded manifest; must contain a ``version`` key.

        Raises:
            ValueError: if ``version`` is missing (or None), or if no direct
                subclass handles that version.
        """
        if (version := manifest_json.get("version")) is None:
            raise ValueError("Invalid manifest format. Must contain version field.")
        # Only *direct* subclasses are consulted; each is expected to
        # override this method with its own parser.
        for sub in cls.__subclasses__():
            if sub.version() == version:
                return sub.from_manifest_json(manifest_json)
        raise ValueError("Invalid manifest version.")

    def __len__(self) -> int:
        """Return the number of entries in the manifest."""
        return len(self.entries)

    @abstractmethod
    def to_manifest_json(self) -> dict[str, Any]:
        """Serialize the manifest to its JSON (dict) representation."""
        raise NotImplementedError

    @abstractmethod
    def digest(self) -> HexMD5:
        """Return the manifest's digest as an MD5 hex string."""
        raise NotImplementedError

    @abstractmethod
    def size(self) -> int:
        """Return the manifest's size as defined by the concrete subclass."""
        raise NotImplementedError

    def add_entry(self, entry: ArtifactManifestEntry, overwrite: bool = False) -> None:
        """Add ``entry`` to the manifest, keyed by ``entry.path``.

        Args:
            entry: The entry to add.
            overwrite: If true, unconditionally replace any entry already
                stored at the same path.

        Raises:
            ValueError: if ``overwrite`` is false and an entry with the same
                path but a different digest already exists. Re-adding an
                entry whose digest matches the stored one is allowed and
                replaces it.
        """
        if not overwrite:
            old_entry = self.entries.get(entry.path)
            # Test for presence explicitly instead of relying on the stored
            # entry object's truthiness.
            if old_entry is not None and entry.digest != old_entry.digest:
                raise ValueError(f"Cannot add the same path twice: {entry.path!r}")
        self.entries[entry.path] = entry

    def remove_entry(self, entry: ArtifactManifestEntry) -> None:
        """Remove the entry stored at ``entry.path``.

        Only ``entry.path`` is used. Whatever entry is stored at that path is
        removed.

        Raises:
            FileNotFoundError: if no entry is stored at that path.
        """
        try:
            del self.entries[entry.path]
        except LookupError:
            # Suppress the internal KeyError context: callers only need the
            # FileNotFoundError, not a chained "during handling of" traceback.
            raise FileNotFoundError(
                f"Cannot remove missing entry: {entry.path!r}"
            ) from None

    def get_entry_by_path(self, path: str) -> ArtifactManifestEntry | None:
        """Return the entry stored at ``path``, or None if there is none."""
        return self.entries.get(path)

    def get_entries_in_directory(self, directory: str) -> list[ArtifactManifestEntry]:
        """Return every entry located (at any depth) under ``directory``.

        Args:
            directory: A directory path using forward slashes. A trailing
                slash is tolerated, so ``"a/b/"`` behaves like ``"a/b"``.

        Returns:
            Entries whose path starts with ``"<directory>/"``, in ``entries``
            iteration order.
        """
        # Entry keys (paths) use forward slashes even on Windows. Strip any
        # trailing slash so the prefix is never built as "dir//".
        dir_prefix = f"{directory.rstrip('/')}/"
        return [obj for key, obj in self.entries.items() if key.startswith(dir_prefix)]
|