artifact_manifest.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. """Artifact manifest."""
  2. from __future__ import annotations
  3. from abc import ABC, abstractmethod
  4. from typing import TYPE_CHECKING, Annotated, Any, Dict # noqa: UP035
  5. from pydantic import Field
  6. from wandb.sdk.lib.hashutil import HexMD5
  7. from ._models.base_model import ArtifactsBase
  8. if TYPE_CHECKING:
  9. from .artifact_manifest_entry import ArtifactManifestEntry
  10. from .storage_policy import StoragePolicy
  11. class ArtifactManifest(ArtifactsBase, ABC):
  12. # Note: we can't name this "version" since it conflicts with the prior
  13. # `version()` classmethod.
  14. manifest_version: Annotated[Any, Field(repr=False)]
  15. entries: Dict[str, ArtifactManifestEntry] = Field(default_factory=dict) # noqa: UP006
  16. storage_policy: Annotated[StoragePolicy, Field(exclude=True, repr=False)]
  17. @classmethod
  18. def version(cls) -> int:
  19. return cls.model_fields["manifest_version"].default
  20. @classmethod
  21. @abstractmethod
  22. def from_manifest_json(cls, manifest_json: dict[str, Any]) -> ArtifactManifest:
  23. if (version := manifest_json.get("version")) is None:
  24. raise ValueError("Invalid manifest format. Must contain version field.")
  25. for sub in cls.__subclasses__():
  26. if sub.version() == version:
  27. return sub.from_manifest_json(manifest_json)
  28. raise ValueError("Invalid manifest version.")
  29. def __len__(self) -> int:
  30. return len(self.entries)
  31. @abstractmethod
  32. def to_manifest_json(self) -> dict[str, Any]:
  33. raise NotImplementedError
  34. @abstractmethod
  35. def digest(self) -> HexMD5:
  36. raise NotImplementedError
  37. @abstractmethod
  38. def size(self) -> int:
  39. raise NotImplementedError
  40. def add_entry(self, entry: ArtifactManifestEntry, overwrite: bool = False) -> None:
  41. if (
  42. (not overwrite)
  43. and (old_entry := self.entries.get(entry.path))
  44. and (entry.digest != old_entry.digest)
  45. ):
  46. raise ValueError(f"Cannot add the same path twice: {entry.path!r}")
  47. self.entries[entry.path] = entry
  48. def remove_entry(self, entry: ArtifactManifestEntry) -> None:
  49. try:
  50. del self.entries[entry.path]
  51. except LookupError:
  52. raise FileNotFoundError(f"Cannot remove missing entry: {entry.path!r}")
  53. def get_entry_by_path(self, path: str) -> ArtifactManifestEntry | None:
  54. return self.entries.get(path)
  55. def get_entries_in_directory(self, directory: str) -> list[ArtifactManifestEntry]:
  56. # entry keys (paths) use forward slash even for windows
  57. dir_prefix = f"{directory}/"
  58. return [obj for key, obj in self.entries.items() if key.startswith(dir_prefix)]