storage_policy.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """Storage policy."""
  2. from __future__ import annotations
  3. import concurrent.futures
  4. from abc import ABC, abstractmethod
  5. from typing import TYPE_CHECKING, Any
  6. from wandb.sdk.internal.internal_api import Api as InternalApi
  7. from wandb.sdk.lib.paths import FilePathStr, URIStr
  8. if TYPE_CHECKING:
  9. from wandb.filesync.step_prepare import StepPrepare
  10. from wandb.sdk.artifacts._models.storage import StoragePolicyConfig
  11. from wandb.sdk.artifacts.artifact import Artifact
  12. from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
  13. from wandb.sdk.internal.progress import ProgressFn
  14. _POLICY_REGISTRY: dict[str, type[StoragePolicy]] = {}
  15. class StoragePolicy(ABC):
  16. _api: InternalApi | None = None
  17. def __init_subclass__(cls, **kwargs: Any) -> None:
  18. super().__init_subclass__(**kwargs)
  19. _POLICY_REGISTRY[cls.name()] = cls
  20. @classmethod
  21. def lookup_by_name(cls, name: str) -> type[StoragePolicy]:
  22. if policy := _POLICY_REGISTRY.get(name):
  23. return policy
  24. raise ValueError(f"Failed to find storage policy {name!r}")
  25. @classmethod
  26. @abstractmethod
  27. def name(cls) -> str:
  28. raise NotImplementedError
  29. @classmethod
  30. @abstractmethod
  31. def from_config(cls, config: StoragePolicyConfig) -> StoragePolicy:
  32. raise NotImplementedError
  33. @abstractmethod
  34. def config(self) -> dict[str, Any]:
  35. raise NotImplementedError
  36. @abstractmethod
  37. def load_file(
  38. self,
  39. artifact: Artifact,
  40. manifest_entry: ArtifactManifestEntry,
  41. dest_path: str | None = None,
  42. executor: concurrent.futures.Executor | None = None,
  43. ) -> FilePathStr:
  44. raise NotImplementedError
  45. @abstractmethod
  46. def store_file(
  47. self,
  48. artifact_id: str,
  49. artifact_manifest_id: str,
  50. entry: ArtifactManifestEntry,
  51. preparer: StepPrepare,
  52. progress_callback: ProgressFn | None = None,
  53. ) -> bool:
  54. raise NotImplementedError
  55. @abstractmethod
  56. def store_reference(
  57. self,
  58. artifact: Artifact,
  59. path: URIStr | FilePathStr,
  60. name: str | None = None,
  61. checksum: bool = True,
  62. max_objects: int | None = None,
  63. ) -> list[ArtifactManifestEntry]:
  64. raise NotImplementedError
  65. @abstractmethod
  66. def load_reference(
  67. self,
  68. manifest_entry: ArtifactManifestEntry,
  69. local: bool = False,
  70. dest_path: str | None = None,
  71. ) -> FilePathStr | URIStr:
  72. raise NotImplementedError