experiment_state.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import fnmatch
  2. import logging
  3. import os
  4. import time
  5. from collections import Counter
  6. from pathlib import Path
  7. from typing import Callable, Dict, Optional, Union
  8. import pyarrow.fs
  9. from ray.train._internal.storage import (
  10. StorageContext,
  11. _download_from_fs_path,
  12. _list_at_fs_path,
  13. get_fs_and_path,
  14. )
  15. from ray.tune.experiment.trial import Trial
  16. from ray.tune.impl.out_of_band_serialize_dataset import out_of_band_serialize_dataset
  17. logger = logging.getLogger(__name__)
  18. _SLOW_SYNC_WARNING = (
  19. "This could be due to a large number of trials, "
  20. "large logfiles from lots of reported metrics, or throttling from the "
  21. "remote storage if uploading too frequently.\n"
  22. "You may want to consider switching the `RunConfig(storage_filesystem)`"
  23. " to a more performant storage backend such as s3fs for a "
  24. "S3 storage path.\n"
  25. "You can suppress this error by setting the environment variable "
  26. "TUNE_WARN_SLOW_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a higher "
  27. "value than the current threshold ({threshold})."
  28. )
  29. def _find_newest_experiment_checkpoint(
  30. experiment_path: str, fs: Optional[pyarrow.fs.FileSystem] = None
  31. ) -> Optional[str]:
  32. """Returns file name of most recently created experiment checkpoint.
  33. Args:
  34. experiment_path: Local or remote path to the experiment directory
  35. containing at least one experiment checkpoint file.
  36. Returns:
  37. str: The local or remote path to the latest experiment checkpoint file
  38. based on timestamp. None if no experiment checkpoints were found.
  39. """
  40. from ray.tune.execution.tune_controller import TuneController
  41. fs, experiment_fs_path = get_fs_and_path(experiment_path, storage_filesystem=fs)
  42. filenames = _list_at_fs_path(fs=fs, fs_path=experiment_fs_path)
  43. pattern = TuneController.CKPT_FILE_TMPL.format("*")
  44. matching = fnmatch.filter(filenames, pattern)
  45. if not matching:
  46. return None
  47. filename = max(matching)
  48. return Path(experiment_fs_path, filename).as_posix()
  49. class _ExperimentCheckpointManager:
  50. """Helper class for managing experiment-level checkpoints.
  51. This class implements the ``checkpoint()`` method used to checkpoint
  52. experiment state. When called, this will serialize and write to disk
  53. the state of the trial runner, trial executor, and search algorithm, to
  54. a specified checkpoint file.
  55. The checkpoint period is automatically adjusted to
  56. ``max(10, time_per_checkpoint * 19)``. This means that at most 5% of the
  57. time (1/20) will be used for writing checkpoints, while 95% of the time
  58. (19/20) will be used to handle the rest of the training loop.
  59. """
  60. def __init__(
  61. self,
  62. *,
  63. storage: Optional[StorageContext],
  64. checkpoint_period: Union[int, float, str],
  65. sync_every_n_trial_checkpoints: Optional[int] = None,
  66. ):
  67. self._storage = storage
  68. self._last_save_time = float("-inf")
  69. self._last_sync_time = None
  70. # Dynamic checkpointing period
  71. self._auto_checkpoint_enabled = checkpoint_period == "auto"
  72. if self._auto_checkpoint_enabled:
  73. self._checkpoint_period = 10.0 # Initial value
  74. else:
  75. self._checkpoint_period = float(checkpoint_period)
  76. # TODO(justinvyu): This is a non-performant workaround to force sync
  77. # every num_to_keep checkpoints in order to maintain consistency
  78. # between the experiment state's view of the latest checkpoint,
  79. # and the actual latest checkpoint that was uploaded.
  80. self._sync_every_n_trial_checkpoints = sync_every_n_trial_checkpoints
  81. self._trial_num_checkpoints_since_last_sync: Dict[Trial, int] = Counter()
  82. self._should_force_sync_up: bool = False
  83. self._excessive_sync_threshold = float(
  84. os.environ.get(
  85. "TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S", "5"
  86. )
  87. )
  88. self._slow_sync_threshold = float(
  89. os.environ.get(
  90. "TUNE_WARN_SLOW_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S", "30"
  91. )
  92. )
  93. @property
  94. def auto_checkpoint_enabled(self):
  95. return self._auto_checkpoint_enabled
  96. def _update_auto_checkpoint_time(self, time_taken: float):
  97. if self._auto_checkpoint_enabled:
  98. # Multiplying this time by 19 means we spend ~5% of the time
  99. # writing global checkpoints and 95% of the time processing trials
  100. self._checkpoint_period = max(10.0, time_taken * 19)
  101. logger.debug(
  102. f"Experiment state snapshotting took "
  103. f"{time_taken:.2f} seconds. "
  104. f"Adjusting snapshotting period to "
  105. f"{self._checkpoint_period:.2f} seconds."
  106. )
  107. def sync_up_experiment_state(
  108. self,
  109. save_fn: Callable[[], None],
  110. force: bool = False,
  111. wait: bool = False,
  112. ):
  113. """Saves execution state to the experiment directory on the storage path.
  114. This includes an experiment checkpoint file that contains trial statuses
  115. and the searcher state.
  116. Overwrites the current session checkpoint, which starts when self
  117. is instantiated. Throttle depends on self._checkpoint_period.
  118. Args:
  119. save_fn: Function to call to actually save data to the driver
  120. staging path. The files in the driver staging path will be
  121. uploaded to the storage path.
  122. force: Forces an experiment checkpoint and launches a sync to storage.
  123. This happens regardless of checkpoint_period
  124. wait: Waits for the sync up to complete before returning.
  125. """
  126. driver_staging_path = self._storage.experiment_driver_staging_path
  127. force = force or self._should_force_sync_up
  128. now = time.monotonic()
  129. if now - self._last_save_time < self._checkpoint_period and not force:
  130. return
  131. # Checkpoint
  132. checkpoint_time_start = time.monotonic()
  133. # NOTE: This context manager is for Datasets captured in a trial config.
  134. # This is the case when *tuning over datasets*.
  135. # If the datasets have already been full executed, then serializing
  136. # block refs means that this checkpoint is not usable in a new Ray cluster.
  137. # This context will serialize the dataset execution plan instead, if available.
  138. with out_of_band_serialize_dataset():
  139. save_fn()
  140. def wait_for_sync():
  141. try:
  142. self._storage.syncer.wait()
  143. except Exception:
  144. logger.error(
  145. "Saving experiment state to storage at "
  146. f"'{self._storage.experiment_fs_path}' failed with exception: ",
  147. exc_info=True,
  148. )
  149. if force:
  150. start_time = time.monotonic()
  151. wait_for_sync()
  152. wait_time = time.monotonic() - start_time
  153. if wait_time > self._slow_sync_threshold:
  154. logger.warning(
  155. "Saving the experiment state (which holds a global view "
  156. "of trial statuses and is used to restore the experiment) "
  157. f"took ~{wait_time:.2f} seconds, which may be a performance "
  158. "bottleneck.\n"
  159. f"{_SLOW_SYNC_WARNING.format(threshold=self._slow_sync_threshold)}"
  160. )
  161. time_since_last_sync = (
  162. time.monotonic() - self._last_sync_time
  163. if self._last_sync_time is not None
  164. else None
  165. )
  166. launched_sync = self._storage.syncer.sync_up(
  167. driver_staging_path, self._storage.experiment_fs_path
  168. )
  169. if launched_sync:
  170. if (
  171. time_since_last_sync is not None
  172. and time_since_last_sync < self._excessive_sync_threshold
  173. and self._should_force_sync_up
  174. ):
  175. logger.warning(
  176. "Experiment state snapshotting has been triggered multiple "
  177. f"times in the last {self._excessive_sync_threshold} seconds "
  178. "and may become a bottleneck. "
  179. "A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, "
  180. "and a trial has checkpointed >= `num_to_keep` times "
  181. "since the last snapshot.\n"
  182. "You may want to consider increasing the "
  183. "`CheckpointConfig(num_to_keep)` or decreasing the frequency of "
  184. "saving checkpoints.\n"
  185. "You can suppress this warning by setting the environment variable "
  186. "TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S "
  187. "to a smaller value than the current threshold "
  188. f"({self._excessive_sync_threshold}). "
  189. "Set it to 0 to completely suppress this warning."
  190. )
  191. self._last_sync_time = time.monotonic()
  192. # We just synced, so reset the force flag
  193. self._trial_num_checkpoints_since_last_sync.clear()
  194. self._should_force_sync_up = False
  195. else:
  196. if (
  197. time_since_last_sync is not None
  198. and time_since_last_sync > self._slow_sync_threshold
  199. ):
  200. logger.warning(
  201. "Saving the experiment state (which holds a global view "
  202. "of trial statuses and is used to restore the experiment) "
  203. f"has already taken {time_since_last_sync:.2f} seconds, "
  204. "which may cause consistency issues upon restoration if your "
  205. "driver script ungracefully exits.\n"
  206. f"{_SLOW_SYNC_WARNING.format(threshold=self._slow_sync_threshold)}"
  207. )
  208. if wait:
  209. wait_for_sync()
  210. checkpoint_time_taken = time.monotonic() - checkpoint_time_start
  211. # Adjust dynamic checkpointing
  212. self._update_auto_checkpoint_time(time_taken=checkpoint_time_taken)
  213. # Finish
  214. self._last_save_time = time.monotonic()
  215. def sync_down_experiment_state(self) -> None:
  216. fs = self._storage.storage_filesystem
  217. filepaths = _list_at_fs_path(fs=fs, fs_path=self._storage.experiment_fs_path)
  218. # TODO(ekl) we should refactor our restore code to read the necessary data
  219. # directly from the storage context. As a temporary hack, restore all the
  220. # serialized files from the root dir where other modules expect them to be.
  221. matches = [
  222. path
  223. for path in filepaths
  224. if path.endswith(".json") or path.endswith(".pkl")
  225. ]
  226. for relpath in matches:
  227. fs_path = Path(self._storage.experiment_fs_path, relpath).as_posix()
  228. local_path = Path(
  229. self._storage.experiment_driver_staging_path, relpath
  230. ).as_posix()
  231. _download_from_fs_path(fs=fs, fs_path=fs_path, local_path=local_path)
  232. logger.debug(
  233. f"Copied {matches} from:\n(fs, path) = "
  234. f"({self._storage.storage_filesystem.type_name}, "
  235. f"{self._storage.experiment_fs_path})\n"
  236. f"-> {self._storage.experiment_driver_staging_path}"
  237. )
  238. def on_trial_checkpoint(self, trial: Trial):
  239. if not self._sync_every_n_trial_checkpoints:
  240. return
  241. self._trial_num_checkpoints_since_last_sync[trial] += 1
  242. if (
  243. self._trial_num_checkpoints_since_last_sync[trial]
  244. >= self._sync_every_n_trial_checkpoints
  245. ):
  246. self._should_force_sync_up = True