artifact_file_cache.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. """Artifact cache."""
  2. from __future__ import annotations
  3. import contextlib
  4. import errno
  5. import hashlib
  6. import os
  7. import shutil
  8. import subprocess
  9. import sys
  10. from collections.abc import Iterator
  11. from contextlib import AbstractContextManager
  12. from functools import lru_cache
  13. from pathlib import Path
  14. from tempfile import NamedTemporaryFile
  15. from typing import IO, Protocol
  16. import wandb
  17. from wandb import env, util
  18. from wandb.sdk.lib.filesystem import files_in
  19. from wandb.sdk.lib.hashutil import B64MD5, ETag, b64_to_hex_id
  20. from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
  21. class Opener(Protocol):
  22. def __call__(self, mode: str = ...) -> AbstractContextManager[IO]: ...
  23. def artifacts_cache_dir() -> Path:
  24. """Get the artifacts cache directory."""
  25. return env.get_cache_dir() / "artifacts"
  26. def _get_sys_umask_threadsafe() -> int:
  27. # Workaround to get the current system umask, since
  28. # - `os.umask()` isn't thread-safe
  29. # - we don't want to inadvertently change the umask of the current process
  30. # See: https://stackoverflow.com/questions/53227072/reading-umask-thread-safe
  31. umask_cmd = (sys.executable, "-c", "import os; print(os.umask(22))")
  32. return int(subprocess.check_output(umask_cmd))
  33. class ArtifactFileCache:
  34. def __init__(self, cache_dir: StrPath) -> None:
  35. self._cache_dir = Path(cache_dir)
  36. self._obj_dir = self._cache_dir / "obj"
  37. self._temp_dir = self._cache_dir / "tmp"
  38. self._ensure_write_permissions()
  39. # NamedTemporaryFile sets the file mode to 600 [1], we reset to the default.
  40. # [1] https://stackoverflow.com/questions/10541760/can-i-set-the-umask-for-tempfile-namedtemporaryfile-in-python
  41. self._sys_umask = _get_sys_umask_threadsafe()
  42. self._override_cache_path: StrPath | None = None
  43. def check_md5_obj_path(
  44. self, b64_md5: B64MD5, size: int
  45. ) -> tuple[FilePathStr, bool, Opener]:
  46. # Check if we're using vs skipping the cache
  47. if self._override_cache_path is not None:
  48. skip_cache = True
  49. path = Path(self._override_cache_path)
  50. else:
  51. skip_cache = False
  52. hex_md5 = b64_to_hex_id(b64_md5)
  53. path = self._obj_dir / "md5" / hex_md5[:2] / hex_md5[2:]
  54. return self._check_or_create(path, size, skip_cache=skip_cache)
  55. # TODO(spencerpearson): this method at least needs its signature changed.
  56. # An ETag is not (necessarily) a checksum.
  57. def check_etag_obj_path(
  58. self,
  59. url: URIStr,
  60. etag: ETag,
  61. size: int,
  62. ) -> tuple[FilePathStr, bool, Opener]:
  63. # Check if we're using vs skipping the cache
  64. if self._override_cache_path is not None:
  65. skip_cache = True
  66. path = Path(self._override_cache_path)
  67. else:
  68. skip_cache = False
  69. hexhash = hashlib.sha256(
  70. hashlib.sha256(url.encode("utf-8")).digest()
  71. + hashlib.sha256(etag.encode("utf-8")).digest()
  72. ).hexdigest()
  73. path = self._obj_dir / "etag" / hexhash[:2] / hexhash[2:]
  74. return self._check_or_create(path, size, skip_cache=skip_cache)
  75. def _check_or_create(
  76. self, path: Path, size: int, skip_cache: bool = False
  77. ) -> tuple[FilePathStr, bool, Opener]:
  78. opener = self._opener(path, size, skip_cache=skip_cache)
  79. hit = path.is_file() and path.stat().st_size == size
  80. return FilePathStr(path), hit, opener
  81. def cleanup(
  82. self,
  83. target_size: int | None = None,
  84. remove_temp: bool = False,
  85. target_fraction: float | None = None,
  86. ) -> int:
  87. """Clean up the cache, removing the least recently used files first.
  88. Args:
  89. target_size: The target size of the cache in bytes. If the cache is larger
  90. than this, we will remove the least recently used files until the cache
  91. is smaller than this size.
  92. remove_temp: Whether to remove temporary files. Temporary files are files
  93. that are currently being written to the cache. If remove_temp is True,
  94. all temp files will be removed, regardless of the target_size or
  95. target_fraction.
  96. target_fraction: The target fraction of the cache to reclaim. If the cache
  97. is larger than this, we will remove the least recently used files until
  98. the cache is smaller than this fraction of its current size. It is an
  99. error to specify both target_size and target_fraction.
  100. Returns:
  101. The number of bytes reclaimed.
  102. """
  103. if target_size is None and target_fraction is None:
  104. # Default to clearing the entire cache.
  105. target_size = 0
  106. if target_size is not None and target_fraction is not None:
  107. raise ValueError("Cannot specify both target_size and target_fraction")
  108. if target_size is not None and target_size < 0:
  109. raise ValueError("target_size must be non-negative")
  110. if target_fraction is not None and (target_fraction < 0 or target_fraction > 1):
  111. raise ValueError("target_fraction must be between 0 and 1")
  112. bytes_reclaimed = 0
  113. total_size = 0
  114. temp_size = 0
  115. # Remove all temporary files if requested. Otherwise sum their size.
  116. for entry in files_in(self._temp_dir):
  117. size = entry.stat().st_size
  118. total_size += size
  119. if remove_temp:
  120. try:
  121. os.remove(entry.path)
  122. bytes_reclaimed += size
  123. except OSError:
  124. pass
  125. else:
  126. temp_size += size
  127. if temp_size:
  128. wandb.termwarn(
  129. f"Cache contains {util.to_human_size(temp_size)} of temporary files. "
  130. "Run `wandb artifact cache cleanup --remove-temp` to remove them."
  131. )
  132. entries = []
  133. for file_entry in files_in(self._obj_dir):
  134. total_size += file_entry.stat().st_size
  135. entries.append(file_entry)
  136. if target_fraction is not None:
  137. target_size = int(total_size * target_fraction)
  138. assert target_size is not None
  139. for entry in sorted(entries, key=lambda x: x.stat().st_atime):
  140. if total_size <= target_size:
  141. return bytes_reclaimed
  142. try:
  143. os.remove(entry.path)
  144. except OSError:
  145. pass
  146. total_size -= entry.stat().st_size
  147. bytes_reclaimed += entry.stat().st_size
  148. if total_size > target_size:
  149. wandb.termerror(
  150. f"Failed to reclaim enough space in {self._cache_dir}. Try running"
  151. " `wandb artifact cache cleanup --remove-temp` to remove temporary files."
  152. )
  153. return bytes_reclaimed
  154. def _free_space(self) -> int:
  155. """Return the number of bytes of free space in the cache directory."""
  156. return shutil.disk_usage(self._cache_dir)[2]
  157. def _reserve_space(self, size: int) -> None:
  158. """If a `size` write would exceed disk space, remove cached items to make space.
  159. Raises:
  160. OSError: If there is not enough space to write `size` bytes, even after
  161. removing cached items.
  162. """
  163. if size <= self._free_space():
  164. return
  165. wandb.termwarn("Cache size exceeded. Attempting to reclaim space...")
  166. self.cleanup(target_fraction=0.5)
  167. if size <= self._free_space():
  168. return
  169. self.cleanup(target_size=0)
  170. if size > self._free_space():
  171. raise OSError(errno.ENOSPC, f"Insufficient free space in {self._cache_dir}")
  172. def _opener(self, path: Path, size: int, skip_cache: bool = False) -> Opener:
  173. @contextlib.contextmanager
  174. def atomic_open(mode: str = "w") -> Iterator[IO]:
  175. if "a" in mode:
  176. raise ValueError("Appending to cache files is not supported")
  177. if skip_cache:
  178. # Skip the cache but still use an intermediate temporary file to
  179. # ensure atomicity. Place the temp file in the same root as the
  180. # destination file to avoid cross-filesystem move/copy operations.
  181. temp_dir = path.parent
  182. else:
  183. self._reserve_space(size)
  184. temp_dir = self._temp_dir
  185. temp_dir.mkdir(parents=True, exist_ok=True)
  186. temp_file = NamedTemporaryFile(dir=temp_dir, mode=mode, delete=False)
  187. try:
  188. yield temp_file
  189. temp_file.close()
  190. os.chmod(temp_file.name, 0o666 & ~self._sys_umask)
  191. path.parent.mkdir(parents=True, exist_ok=True)
  192. os.replace(temp_file.name, path)
  193. except Exception:
  194. os.remove(temp_file.name)
  195. raise
  196. return atomic_open
  197. def _ensure_write_permissions(self) -> None:
  198. """Raise an error if we cannot write to the cache directory."""
  199. try:
  200. self._temp_dir.mkdir(parents=True, exist_ok=True)
  201. with NamedTemporaryFile(dir=self._temp_dir) as f:
  202. f.write(b"wandb")
  203. except PermissionError as e:
  204. raise PermissionError(
  205. f"Unable to write to {self._cache_dir}. "
  206. "Ensure that the current user has write permissions."
  207. ) from e
  208. # Memo `ArtifactFileCache` instances while avoiding reliance on global
  209. # variable(s). Notes:
  210. # - @lru_cache should be thread-safe.
  211. # - We don't memoize `get_artifact_file_cache` directly, as the cache_dir
  212. # may change at runtime. This is likely rare in practice, though.
  213. @lru_cache(maxsize=1)
  214. def _build_artifact_file_cache(cache_dir: StrPath) -> ArtifactFileCache:
  215. return ArtifactFileCache(cache_dir)
  216. def get_artifact_file_cache() -> ArtifactFileCache:
  217. return _build_artifact_file_cache(artifacts_cache_dir())