artifact_manifest_entry.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. """Artifact manifest entry."""
  2. # Older-style type annotations required for Pydantic v1 / python 3.8 compatibility.
  3. # ruff: noqa: UP006, UP007, UP035, UP045
  4. from __future__ import annotations
  5. import concurrent.futures
  6. import hashlib
  7. import logging
  8. import os
  9. from contextlib import suppress
  10. from os.path import getsize
  11. from typing import TYPE_CHECKING, Annotated, Any, Dict, Final, Optional, Union
  12. from urllib.parse import urlparse
  13. from pydantic import Field, NonNegativeInt
  14. from typing_extensions import Self
  15. from wandb._pydantic import field_validator, model_validator
  16. from wandb._strutils import nameof
  17. from wandb.proto.wandb_telemetry_pb2 import Deprecated
  18. from wandb.sdk.lib.deprecation import warn_and_record_deprecation
  19. from wandb.sdk.lib.filesystem import copy_or_overwrite_changed
  20. from wandb.sdk.lib.hashutil import (
  21. B64MD5,
  22. ETag,
  23. b64_to_hex_id,
  24. hex_to_b64_id,
  25. md5_file_b64,
  26. )
  27. from wandb.sdk.lib.paths import FilePathStr, LogicalPath, URIStr
  28. from ._models.base_model import ArtifactsBase
  29. if TYPE_CHECKING:
  30. from .artifact import Artifact
# Module-level logger, named after this module; used for debug-level
# diagnostics by the checksum-cache helpers and by `download()`.
logger = logging.getLogger(__name__)

# URI scheme compared against in `_is_artifact_reference()` and used as the
# prefix of the URLs built by `ArtifactManifestEntry.ref_url()`.
_WB_ARTIFACT_SCHEME: Final[str] = "wandb-artifact"
  33. def _checksum_cache_path(file_path: str) -> str:
  34. """Get path for checksum in central cache directory."""
  35. from wandb.sdk.artifacts.artifact_file_cache import artifacts_cache_dir
  36. # Create a unique cache key based on the file's absolute path
  37. abs_path = os.path.abspath(file_path)
  38. path_hash = hashlib.sha256(abs_path.encode()).hexdigest()
  39. # Store in wandb cache directory under checksums subdirectory
  40. cache_dir = artifacts_cache_dir() / "checksums"
  41. cache_dir.mkdir(parents=True, exist_ok=True)
  42. return str(cache_dir / f"{path_hash}.checksum")
  43. def _read_cached_checksum(file_path: str) -> str | None:
  44. """Read checksum from cache if it exists and is valid."""
  45. checksum_path = _checksum_cache_path(file_path)
  46. try:
  47. with open(file_path) as f, open(checksum_path) as f_checksum:
  48. if os.path.getmtime(f_checksum.name) < os.path.getmtime(f.name):
  49. # File was modified after checksum was written
  50. return None
  51. # Read and return the cached checksum
  52. return f_checksum.read().strip()
  53. except OSError:
  54. # File doesn't exist or couldn't be opened
  55. return None
  56. def _write_cached_checksum(file_path: str, checksum: str) -> None:
  57. """Write checksum to cache directory."""
  58. checksum_path = _checksum_cache_path(file_path)
  59. try:
  60. with open(checksum_path, "w") as f:
  61. f.write(checksum)
  62. except OSError:
  63. # Non-critical failure, just log it
  64. logger.debug(f"Failed to write checksum cache for {file_path!r}")
  65. class ArtifactManifestEntry(ArtifactsBase):
  66. """A single entry in an artifact manifest.
  67. External code should avoid instantiating this class directly.
  68. """
  69. path: LogicalPath
  70. digest: Union[B64MD5, ETag, URIStr, FilePathStr]
  71. ref: Union[URIStr, FilePathStr, None] = None
  72. birth_artifact_id: Annotated[Optional[str], Field(alias="birthArtifactID")] = None
  73. size: Optional[NonNegativeInt] = None
  74. extra: Dict[str, Any] = Field(default_factory=dict)
  75. local_path: Optional[str] = None
  76. skip_cache: bool = False
  77. # Note: Pydantic treats these as private attributes, omitting them from
  78. # validation and comparison logic.
  79. _parent_artifact: Optional[Artifact] = None
  80. _download_url: Optional[str] = None
  81. @field_validator("path", mode="before")
  82. def _validate_path(cls, v: Any) -> LogicalPath:
  83. """Coerce `path` to a LogicalPath.
  84. LogicalPath does not implement its own pydantic validator, and adding
  85. one for both pydantic V1 and V2 would add excessive boilerplate. Until
  86. we drop V1 support, coerce to LogicalPath in this field validator.
  87. """
  88. return LogicalPath(v)
  89. @field_validator("local_path", mode="before")
  90. def _validate_local_path(cls, v: Any) -> str | None:
  91. """Coerce `local_path` to a str. Necessary if the input is a `PosixPath`."""
  92. return str(v) if v else None
  93. @model_validator(mode="after")
  94. def _infer_size_from_local_path(self) -> Self:
  95. """If `size` isn't set, try to infer it from `local_path`."""
  96. if (self.size is None) and self.local_path:
  97. self.size = getsize(self.local_path)
  98. return self
  99. def __repr__(self) -> str:
  100. # For compatibility with prior behavior, don't display `extra` if it's empty
  101. exclude = None if self.extra else {"extra"}
  102. repr_dict = self.model_dump(by_alias=False, exclude_none=True, exclude=exclude)
  103. return f"{nameof(type(self))}({', '.join(f'{k}={v!r}' for k, v in repr_dict.items())})"
  104. @property
  105. def name(self) -> LogicalPath:
  106. """Deprecated; use `path` instead."""
  107. warn_and_record_deprecation(
  108. feature=Deprecated(artifactmanifestentry__name=True),
  109. message="ArtifactManifestEntry.name is deprecated, use .path instead.",
  110. )
  111. return self.path
  112. def parent_artifact(self) -> Artifact:
  113. """Get the artifact to which this artifact entry belongs.
  114. Returns:
  115. (PublicArtifact): The parent artifact
  116. """
  117. if self._parent_artifact is None:
  118. raise NotImplementedError
  119. return self._parent_artifact
  120. def download(
  121. self,
  122. root: str | None = None,
  123. skip_cache: bool | None = None,
  124. executor: concurrent.futures.Executor | None = None,
  125. ) -> FilePathStr:
  126. """Download this artifact entry to the specified root path.
  127. Args:
  128. root: (str, optional) The root path in which to download this
  129. artifact entry. Defaults to the artifact's root.
  130. Returns:
  131. (str): The path of the downloaded artifact entry.
  132. """
  133. artifact = self.parent_artifact()
  134. rootdir = artifact._add_download_root(root)
  135. dest_path = os.path.join(rootdir, self.path)
  136. # Skip checking the cache (and possibly downloading) if the file already exists
  137. # and has the digest we're expecting.
  138. # Fast integrity check using cached checksum from persistent cache
  139. with suppress(OSError):
  140. if self.digest == _read_cached_checksum(dest_path):
  141. return FilePathStr(dest_path)
  142. # Fallback to computing/caching the checksum hash
  143. try:
  144. md5_hash = md5_file_b64(dest_path)
  145. except (FileNotFoundError, IsADirectoryError):
  146. logger.debug(f"unable to find {dest_path!r}, skip searching for file")
  147. else:
  148. _write_cached_checksum(dest_path, md5_hash)
  149. if self.digest == md5_hash:
  150. return FilePathStr(dest_path)
  151. # Override the target cache path IF we're skipping the cache.
  152. # Note that `override_cache_path is None` <=> `skip_cache is False`.
  153. override_cache_path = FilePathStr(dest_path) if skip_cache else None
  154. storage_policy = artifact.manifest.storage_policy
  155. if self.ref is not None:
  156. cache_path = storage_policy.load_reference(
  157. self, local=True, dest_path=override_cache_path
  158. )
  159. else:
  160. cache_path = storage_policy.load_file(
  161. artifact, self, dest_path=override_cache_path, executor=executor
  162. )
  163. # Determine the final path
  164. final_path = FilePathStr(
  165. override_cache_path or copy_or_overwrite_changed(cache_path, dest_path)
  166. )
  167. # Cache the checksum for future downloads
  168. _write_cached_checksum(final_path, self.digest)
  169. return final_path
  170. def ref_target(self) -> FilePathStr | URIStr:
  171. """Get the reference URL that is targeted by this artifact entry.
  172. Returns:
  173. (str): The reference URL of this artifact entry.
  174. Raises:
  175. ValueError: If this artifact entry was not a reference.
  176. """
  177. if self.ref is None:
  178. raise ValueError("Only reference entries support ref_target().")
  179. if (parent_artifact := self._parent_artifact) is None:
  180. return self.ref
  181. return parent_artifact.manifest.storage_policy.load_reference(self, local=False)
  182. def ref_url(self) -> str:
  183. """Get a URL to this artifact entry.
  184. These URLs can be referenced by another artifact.
  185. Returns:
  186. (str): A URL representing this artifact entry.
  187. Examples:
  188. Basic usage
  189. ```
  190. ref_url = source_artifact.get_entry("file.txt").ref_url()
  191. derived_artifact.add_reference(ref_url)
  192. ```
  193. """
  194. if (parent_artifact := self.parent_artifact()) is None:
  195. raise ValueError("Parent artifact is not set")
  196. elif (parent_id := parent_artifact.id) is None:
  197. raise ValueError("Parent artifact ID is not set")
  198. return f"{_WB_ARTIFACT_SCHEME}://{b64_to_hex_id(parent_id)}/{self.path}"
  199. def to_json(self) -> dict[str, Any]:
  200. # NOTE: The method name `to_json` is a bit misleading, as this returns a
  201. # python dict, NOT a JSON string. The historical name is kept for continuity,
  202. # but consider deprecating this in favor of `BaseModel.model_dump()`.
  203. return self.model_dump(exclude_none=True) # type: ignore[return-value]
  204. def _is_artifact_reference(self) -> bool:
  205. return self.ref is not None and urlparse(self.ref).scheme == _WB_ARTIFACT_SCHEME
  206. def _referenced_artifact_id(self) -> str | None:
  207. if not self._is_artifact_reference():
  208. return None
  209. return hex_to_b64_id(urlparse(self.ref).netloc)