_checkpoint.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import contextlib
  2. import glob
  3. import json
  4. import logging
  5. import os
  6. import platform
  7. import shutil
  8. import tempfile
  9. import traceback
  10. import uuid
  11. from pathlib import Path
  12. from typing import Any, Dict, Iterator, List, Optional, Union
  13. import pyarrow.fs
  14. from ray.air._internal.filelock import TempFileLock
  15. from ray.train._internal.storage import _download_from_fs_path, _exists_at_fs_path
  16. from ray.util.annotations import PublicAPI
  17. logger = logging.getLogger(__name__)
  18. # The filename of the file that stores user metadata set on the checkpoint.
  19. _METADATA_FILE_NAME = ".metadata.json"
  20. # The prefix of the temp checkpoint directory that `to_directory` downloads to
  21. # on the local filesystem.
  22. _CHECKPOINT_TEMP_DIR_PREFIX = "checkpoint_tmp_"
  23. class _CheckpointMetaClass(type):
  24. def __getattr__(self, item):
  25. try:
  26. return super().__getattribute__(item)
  27. except AttributeError as exc:
  28. if item in {
  29. "from_dict",
  30. "to_dict",
  31. "from_bytes",
  32. "to_bytes",
  33. "get_internal_representation",
  34. }:
  35. raise _get_migration_error(item) from exc
  36. elif item in {
  37. "from_uri",
  38. "to_uri",
  39. "uri",
  40. }:
  41. raise _get_uri_error(item) from exc
  42. elif item in {"get_preprocessor", "set_preprocessor"}:
  43. raise _get_preprocessor_error(item) from exc
  44. raise exc
  45. @PublicAPI(stability="beta")
  46. class Checkpoint(metaclass=_CheckpointMetaClass):
  47. """A reference to data persisted as a directory in local or remote storage.
  48. Access the checkpoint contents locally using ``checkpoint.to_directory()``
  49. or ``checkpoint.as_directory``.
  50. Attributes
  51. ----------
  52. path: A path on the filesystem containing the checkpoint contents.
  53. filesystem: PyArrow FileSystem that can be used to access data at the `path`.
  54. See Also
  55. --------
  56. ray.train.report : Report a checkpoint during training (with Ray Train/Tune).
  57. ray.train.get_checkpoint : Get the latest checkpoint during training
  58. (for restoration).
  59. :ref:`train-checkpointing`
  60. :ref:`persistent-storage-guide`
  61. Examples
  62. --------
  63. Creating a checkpoint using ``Checkpoint.from_directory``:
  64. >>> from ray.train import Checkpoint
  65. >>> checkpoint = Checkpoint.from_directory("/tmp/example_checkpoint_dir")
  66. >>> checkpoint.filesystem # doctest: +ELLIPSIS
  67. <pyarrow._fs.LocalFileSystem object...
  68. >>> checkpoint.path
  69. '/tmp/example_checkpoint_dir'
  70. Creating a checkpoint from a remote URI:
  71. >>> checkpoint = Checkpoint("s3://bucket/path/to/checkpoint")
  72. >>> checkpoint.filesystem # doctest: +ELLIPSIS
  73. <pyarrow._s3fs.S3FileSystem object...
  74. >>> checkpoint.path
  75. 'bucket/path/to/checkpoint'
  76. Creating a checkpoint with a custom filesystem:
  77. >>> checkpoint = Checkpoint(
  78. ... path="bucket/path/to/checkpoint",
  79. ... filesystem=pyarrow.fs.S3FileSystem(),
  80. ... )
  81. >>> checkpoint.filesystem # doctest: +ELLIPSIS
  82. <pyarrow._s3fs.S3FileSystem object...
  83. >>> checkpoint.path
  84. 'bucket/path/to/checkpoint'
  85. Accessing a checkpoint's contents:
  86. >>> import os # doctest: +SKIP
  87. >>> with checkpoint.as_directory() as local_checkpoint_dir: # doctest: +SKIP
  88. ... print(os.listdir(local_checkpoint_dir)) # doctest: +SKIP
  89. ['model.pt', 'optimizer.pt', 'misc.pt']
  90. """
  91. def __init__(
  92. self,
  93. path: Union[str, os.PathLike],
  94. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  95. ):
  96. """Construct a Checkpoint.
  97. Args:
  98. path: A local path or remote URI containing the checkpoint data.
  99. If a filesystem is provided, then this path must NOT be a URI.
  100. It should be a path on the filesystem with the prefix already stripped.
  101. filesystem: PyArrow FileSystem to use to access data at the path.
  102. If not specified, this is inferred from the URI scheme.
  103. """
  104. self.path = str(path)
  105. self.filesystem = filesystem
  106. if path and not filesystem:
  107. self.filesystem, self.path = pyarrow.fs.FileSystem.from_uri(path)
  108. # This random UUID is used to create a temporary directory name on the
  109. # local filesystem, which will be used for downloading checkpoint data.
  110. # This ensures that if multiple processes download the same checkpoint object
  111. # only one process performs the actual download while the others wait.
  112. # This prevents duplicated download efforts and data.
  113. # NOTE: Calling `to_directory` from multiple `Checkpoint` objects
  114. # that point to the same (fs, path) will still download the data multiple times.
  115. # This only ensures a canonical temp directory name for a single `Checkpoint`.
  116. self._uuid = uuid.uuid4()
  117. def __repr__(self):
  118. return f"Checkpoint(filesystem={self.filesystem.type_name}, path={self.path})"
  119. def get_metadata(self) -> Dict[str, Any]:
  120. """Return the metadata dict stored with the checkpoint.
  121. If no metadata is stored, an empty dict is returned.
  122. """
  123. metadata_path = Path(self.path, _METADATA_FILE_NAME).as_posix()
  124. if not _exists_at_fs_path(self.filesystem, metadata_path):
  125. return {}
  126. with self.filesystem.open_input_file(metadata_path) as f:
  127. return json.loads(f.readall().decode("utf-8"))
  128. def set_metadata(self, metadata: Dict[str, Any]) -> None:
  129. """Set the metadata stored with this checkpoint.
  130. This will overwrite any existing metadata stored with this checkpoint.
  131. """
  132. metadata_path = Path(self.path, _METADATA_FILE_NAME).as_posix()
  133. with self.filesystem.open_output_stream(metadata_path) as f:
  134. f.write(json.dumps(metadata).encode("utf-8"))
  135. def update_metadata(self, metadata: Dict[str, Any]) -> None:
  136. """Update the metadata stored with this checkpoint.
  137. This will update any existing metadata stored with this checkpoint.
  138. """
  139. existing_metadata = self.get_metadata()
  140. existing_metadata.update(metadata)
  141. self.set_metadata(existing_metadata)
  142. @classmethod
  143. def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
  144. """Create checkpoint object from a local directory.
  145. Args:
  146. path: Local directory containing checkpoint data.
  147. Returns:
  148. A ray.train.Checkpoint object.
  149. """
  150. return cls(path, filesystem=pyarrow.fs.LocalFileSystem())
  151. def to_directory(self, path: Optional[Union[str, os.PathLike]] = None) -> str:
  152. """Write checkpoint data to a local directory.
  153. *If multiple processes on the same node call this method simultaneously,*
  154. only a single process will perform the download, while the others
  155. wait for the download to finish. Once the download finishes, all processes
  156. receive the same local directory to read from.
  157. Args:
  158. path: Target directory to download data to. If not specified,
  159. this method will use a temporary directory.
  160. Returns:
  161. str: Directory containing checkpoint data.
  162. """
  163. user_provided_path = path is not None
  164. local_path = (
  165. path if user_provided_path else self._get_temporary_checkpoint_dir()
  166. )
  167. local_path = os.path.normpath(os.path.expanduser(str(local_path)))
  168. os.makedirs(local_path, exist_ok=True)
  169. try:
  170. # Timeout 0 means there will be only one attempt to acquire
  171. # the file lock. If it cannot be acquired, throw a TimeoutError
  172. with TempFileLock(local_path, timeout=0):
  173. _download_from_fs_path(
  174. fs=self.filesystem, fs_path=self.path, local_path=local_path
  175. )
  176. except TimeoutError:
  177. # if the directory is already locked, then wait but do not do anything.
  178. with TempFileLock(local_path, timeout=-1):
  179. pass
  180. if not os.path.exists(local_path):
  181. raise RuntimeError(
  182. f"Checkpoint directory {local_path} does not exist, "
  183. "even though it should have been created by "
  184. "another process. Please raise an issue on GitHub: "
  185. "https://github.com/ray-project/ray/issues"
  186. )
  187. return local_path
  188. @contextlib.contextmanager
  189. def as_directory(self) -> Iterator[str]:
  190. """Returns checkpoint contents in a local directory as a context.
  191. This function makes checkpoint data available as a directory while avoiding
  192. unnecessary copies and left-over temporary data.
  193. *If the checkpoint points to a local directory*, this method just returns the
  194. local directory path without making a copy, and nothing will be cleaned up
  195. after exiting the context.
  196. *If the checkpoint points to a remote directory*, this method will download the
  197. checkpoint to a local temporary directory and return the path
  198. to the temporary directory.
  199. *If multiple processes on the same node call this method simultaneously,*
  200. only a single process will perform the download, while the others
  201. wait for the download to finish. Once the download finishes, all processes
  202. receive the same local (temporary) directory to read from.
  203. Once all processes have finished working with the checkpoint,
  204. the temporary directory is cleaned up.
  205. Users should treat the returned checkpoint directory as read-only and avoid
  206. changing any data within it, as it may be deleted when exiting the context.
  207. Example:
  208. .. testcode::
  209. :hide:
  210. from pathlib import Path
  211. import tempfile
  212. from ray.train import Checkpoint
  213. temp_dir = tempfile.mkdtemp()
  214. (Path(temp_dir) / "example.txt").write_text("example checkpoint data")
  215. checkpoint = Checkpoint.from_directory(temp_dir)
  216. .. testcode::
  217. with checkpoint.as_directory() as checkpoint_dir:
  218. # Do some read-only processing of files within checkpoint_dir
  219. pass
  220. # At this point, if a temporary directory was created, it will have
  221. # been deleted.
  222. """
  223. if isinstance(self.filesystem, pyarrow.fs.LocalFileSystem):
  224. yield self.path
  225. else:
  226. del_lock_path = _get_del_lock_path(self._get_temporary_checkpoint_dir())
  227. open(del_lock_path, "a").close()
  228. temp_dir = self.to_directory()
  229. try:
  230. yield temp_dir
  231. finally:
  232. # Always cleanup the del lock after we're done with the directory.
  233. # This avoids leaving a lock file behind in the case of an exception
  234. # in the user code.
  235. try:
  236. os.remove(del_lock_path)
  237. except Exception:
  238. logger.warning(
  239. f"Could not remove {del_lock_path} deletion file lock. "
  240. f"Traceback:\n{traceback.format_exc()}"
  241. )
  242. # If there are no more lock files, that means there are no more
  243. # readers of this directory, and we can safely delete it.
  244. # In the edge case (process crash before del lock file is removed),
  245. # we do not remove the directory at all.
  246. # Since it's in /tmp, this is not that big of a deal.
  247. # check if any lock files are remaining
  248. remaining_locks = _list_existing_del_locks(temp_dir)
  249. if not remaining_locks:
  250. try:
  251. # Timeout 0 means there will be only one attempt to acquire
  252. # the file lock. If it cannot be acquired, a TimeoutError
  253. # will be thrown.
  254. with TempFileLock(temp_dir, timeout=0):
  255. shutil.rmtree(temp_dir, ignore_errors=True)
  256. except TimeoutError:
  257. pass
  258. def _get_temporary_checkpoint_dir(self) -> str:
  259. """Return the name for the temporary checkpoint dir that this checkpoint
  260. will get downloaded to, if accessing via `to_directory` or `as_directory`.
  261. """
  262. tmp_dir_path = tempfile.gettempdir()
  263. checkpoint_dir_name = _CHECKPOINT_TEMP_DIR_PREFIX + self._uuid.hex
  264. if platform.system() == "Windows":
  265. # Max path on Windows is 260 chars, -1 for joining \
  266. # Also leave a little for the del lock
  267. del_lock_name = _get_del_lock_path("")
  268. checkpoint_dir_name = (
  269. _CHECKPOINT_TEMP_DIR_PREFIX
  270. + self._uuid.hex[
  271. -259
  272. + len(_CHECKPOINT_TEMP_DIR_PREFIX)
  273. + len(tmp_dir_path)
  274. + len(del_lock_name) :
  275. ]
  276. )
  277. if not checkpoint_dir_name.startswith(_CHECKPOINT_TEMP_DIR_PREFIX):
  278. raise RuntimeError(
  279. "Couldn't create checkpoint directory due to length "
  280. "constraints. Try specifying a shorter checkpoint path."
  281. )
  282. return Path(tmp_dir_path, checkpoint_dir_name).as_posix()
  283. def __fspath__(self):
  284. raise TypeError(
  285. "You cannot use `Checkpoint` objects directly as paths. "
  286. "Use `Checkpoint.to_directory()` or `Checkpoint.as_directory()` instead."
  287. )
  288. def _get_del_lock_path(path: str, suffix: str = None) -> str:
  289. """Get the path to the deletion lock file for a file/directory at `path`.
  290. Example:
  291. >>> _get_del_lock_path("/tmp/checkpoint_tmp") # doctest: +ELLIPSIS
  292. '/tmp/checkpoint_tmp.del_lock_...
  293. >>> _get_del_lock_path("/tmp/checkpoint_tmp/") # doctest: +ELLIPSIS
  294. '/tmp/checkpoint_tmp.del_lock_...
  295. >>> _get_del_lock_path("/tmp/checkpoint_tmp.txt") # doctest: +ELLIPSIS
  296. '/tmp/checkpoint_tmp.txt.del_lock_...
  297. """
  298. suffix = suffix if suffix is not None else str(os.getpid())
  299. return f"{path.rstrip('/')}.del_lock_{suffix}"
  300. def _list_existing_del_locks(path: str) -> List[str]:
  301. """List all the deletion lock files for a file/directory at `path`.
  302. For example, if 2 checkpoints are being read via `as_directory`,
  303. then this should return a list of 2 deletion lock files.
  304. """
  305. return list(glob.glob(f"{_get_del_lock_path(path, suffix='*')}"))
  306. def _get_migration_error(name: str):
  307. return AttributeError(
  308. f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
  309. f"Instead, only directories are supported.\n\n"
  310. f"Example to store a dictionary in a checkpoint:\n\n"
  311. f"import os, tempfile\n"
  312. f"import ray.cloudpickle as pickle\n"
  313. f"from ray import train\n"
  314. f"from ray.train import Checkpoint\n\n"
  315. f"with tempfile.TemporaryDirectory() as checkpoint_dir:\n"
  316. f" with open(os.path.join(checkpoint_dir, 'data.pkl'), 'wb') as fp:\n"
  317. f" pickle.dump({{'data': 'value'}}, fp)\n\n"
  318. f" checkpoint = Checkpoint.from_directory(checkpoint_dir)\n"
  319. f" train.report(..., checkpoint=checkpoint)\n\n"
  320. f"Example to load a dictionary from a checkpoint:\n\n"
  321. f"if train.get_checkpoint():\n"
  322. f" with train.get_checkpoint().as_directory() as checkpoint_dir:\n"
  323. f" with open(os.path.join(checkpoint_dir, 'data.pkl'), 'rb') as fp:\n"
  324. f" data = pickle.load(fp)"
  325. )
  326. def _get_uri_error(name: str):
  327. return AttributeError(
  328. f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
  329. f"To create a checkpoint from remote storage, create a `Checkpoint` using its "
  330. f"constructor instead of `from_directory`.\n"
  331. f'Example: `Checkpoint(path="s3://a/b/c")`.\n'
  332. f"Then, access the contents of the checkpoint with "
  333. f"`checkpoint.as_directory()` / `checkpoint.to_directory()`.\n"
  334. f"To upload data to remote storage, use e.g. `pyarrow.fs.FileSystem` "
  335. f"or your client of choice."
  336. )
  337. def _get_preprocessor_error(name: str):
  338. return AttributeError(
  339. f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
  340. f"To include preprocessor information in checkpoints, "
  341. f"pass it as metadata in the <Framework>Trainer constructor.\n"
  342. f"Example: `TorchTrainer(..., metadata={{...}})`.\n"
  343. f"After training, access it in the checkpoint via `checkpoint.get_metadata()`. "
  344. f"See here: https://docs.ray.io/en/master/train/user-guides/"
  345. f"data-loading-preprocessing.html#preprocessing-structured-data"
  346. )