syncer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. import abc
  2. import logging
  3. import threading
  4. import time
  5. import traceback
  6. from dataclasses import dataclass
  7. from typing import Any, Callable, Dict, List, Optional, Tuple
  8. from ray._private.thirdparty.tabulate.tabulate import tabulate
  9. from ray.util.annotations import Deprecated, DeveloperAPI
  10. from ray.widgets import Template
  11. logger = logging.getLogger(__name__)
  12. # Syncing period for syncing checkpoints between nodes or to cloud.
  13. DEFAULT_SYNC_PERIOD = 300
  14. # Default sync timeout after which syncing processes are aborted
  15. DEFAULT_SYNC_TIMEOUT = 1800
  16. @Deprecated
  17. @dataclass
  18. class SyncConfig:
  19. sync_period: int = DEFAULT_SYNC_PERIOD
  20. sync_timeout: int = DEFAULT_SYNC_TIMEOUT
  21. sync_artifacts: bool = False
  22. sync_artifacts_on_checkpoint: bool = True
  23. def _repr_html_(self) -> str:
  24. """Generate an HTML representation of the SyncConfig."""
  25. return Template("scrollableTable.html.j2").render(
  26. table=tabulate(
  27. {
  28. "Setting": ["Sync period", "Sync timeout"],
  29. "Value": [self.sync_period, self.sync_timeout],
  30. },
  31. tablefmt="html",
  32. showindex=False,
  33. headers="keys",
  34. ),
  35. max_height="none",
  36. )
  37. class _BackgroundProcess:
  38. def __init__(self, fn: Callable):
  39. self._fn = fn
  40. self._process = None
  41. self._result = {}
  42. self._start_time = float("-inf")
  43. @property
  44. def is_running(self):
  45. return self._process and self._process.is_alive()
  46. @property
  47. def start_time(self):
  48. return self._start_time
  49. def start(self, *args, **kwargs):
  50. if self.is_running:
  51. return False
  52. self._result = {}
  53. def entrypoint():
  54. try:
  55. result = self._fn(*args, **kwargs)
  56. except Exception as e:
  57. self._result["exception"] = e
  58. return
  59. self._result["result"] = result
  60. self._process = threading.Thread(target=entrypoint)
  61. self._process.daemon = True
  62. self._process.start()
  63. self._start_time = time.time()
  64. def wait(self, timeout: Optional[float] = None) -> Any:
  65. """Waits for the background process to finish running. Waits until the
  66. background process has run for at least `timeout` seconds, counting from
  67. the time when the process was started."""
  68. if not self._process:
  69. return None
  70. time_remaining = None
  71. if timeout:
  72. elapsed = time.time() - self.start_time
  73. time_remaining = max(timeout - elapsed, 0)
  74. self._process.join(timeout=time_remaining)
  75. if self._process.is_alive():
  76. self._process = None
  77. raise TimeoutError(
  78. f"{getattr(self._fn, '__name__', str(self._fn))} did not finish "
  79. f"running within the timeout of {timeout} seconds."
  80. )
  81. self._process = None
  82. exception = self._result.get("exception")
  83. if exception:
  84. raise exception
  85. result = self._result.get("result")
  86. self._result = {}
  87. return result
  88. @DeveloperAPI
  89. class Syncer(abc.ABC):
  90. """Syncer class for synchronizing data between Ray nodes and remote (cloud) storage.
  91. This class handles data transfer for two cases:
  92. 1. Synchronizing data such as experiment state snapshots from the driver to
  93. cloud storage.
  94. 2. Synchronizing data such as trial checkpoints from remote trainables to
  95. cloud storage.
  96. Synchronizing tasks are usually asynchronous and can be awaited using ``wait()``.
  97. The base class implements a ``wait_or_retry()`` API that will retry a failed
  98. sync command.
  99. The base class also exposes an API to only kick off syncs every ``sync_period``
  100. seconds.
  101. Args:
  102. sync_period: The minimum time in seconds between sync operations, as
  103. used by ``sync_up/down_if_needed``.
  104. sync_timeout: The maximum time to wait for a sync process to finish before
  105. issuing a new sync operation. Ex: should be used by ``wait`` if launching
  106. asynchronous sync tasks.
  107. """
  108. def __init__(
  109. self,
  110. sync_period: float = DEFAULT_SYNC_PERIOD,
  111. sync_timeout: float = DEFAULT_SYNC_TIMEOUT,
  112. ):
  113. self.sync_period = sync_period
  114. self.sync_timeout = sync_timeout
  115. self.last_sync_up_time = float("-inf")
  116. self.last_sync_down_time = float("-inf")
  117. @abc.abstractmethod
  118. def sync_up(
  119. self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
  120. ) -> bool:
  121. """Synchronize local directory to remote directory.
  122. This function can spawn an asynchronous process that can be awaited in
  123. ``wait()``.
  124. Args:
  125. local_dir: Local directory to sync from.
  126. remote_dir: Remote directory to sync up to. This is an URI
  127. (``protocol://remote/path``).
  128. exclude: Pattern of files to exclude, e.g.
  129. ``["*/checkpoint_*]`` to exclude trial checkpoints.
  130. Returns:
  131. True if sync process has been spawned, False otherwise.
  132. """
  133. raise NotImplementedError
  134. @abc.abstractmethod
  135. def sync_down(
  136. self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
  137. ) -> bool:
  138. """Synchronize remote directory to local directory.
  139. This function can spawn an asynchronous process that can be awaited in
  140. ``wait()``.
  141. Args:
  142. remote_dir: Remote directory to sync down from. This is an URI
  143. (``protocol://remote/path``).
  144. local_dir: Local directory to sync to.
  145. exclude: Pattern of files to exclude, e.g.
  146. ``["*/checkpoint_*]`` to exclude trial checkpoints.
  147. Returns:
  148. True if sync process has been spawned, False otherwise.
  149. """
  150. raise NotImplementedError
  151. @abc.abstractmethod
  152. def delete(self, remote_dir: str) -> bool:
  153. """Delete directory on remote storage.
  154. This function can spawn an asynchronous process that can be awaited in
  155. ``wait()``.
  156. Args:
  157. remote_dir: Remote directory to delete. This is an URI
  158. (``protocol://remote/path``).
  159. Returns:
  160. True if sync process has been spawned, False otherwise.
  161. """
  162. raise NotImplementedError
  163. def retry(self):
  164. """Retry the last sync up, sync down, or delete command.
  165. You should implement this method if you spawn asynchronous syncing
  166. processes.
  167. """
  168. pass
  169. def wait(self, timeout: Optional[float] = None):
  170. """Wait for asynchronous sync command to finish.
  171. You should implement this method if you spawn asynchronous syncing
  172. processes. This method should timeout after the asynchronous command
  173. has run for `sync_timeout` seconds and raise a `TimeoutError`.
  174. """
  175. pass
  176. def sync_up_if_needed(
  177. self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
  178. ) -> bool:
  179. """Syncs up if time since last sync up is greater than sync_period.
  180. Args:
  181. local_dir: Local directory to sync from.
  182. remote_dir: Remote directory to sync up to. This is an URI
  183. (``protocol://remote/path``).
  184. exclude: Pattern of files to exclude, e.g.
  185. ``["*/checkpoint_*]`` to exclude trial checkpoints.
  186. """
  187. now = time.time()
  188. if now - self.last_sync_up_time >= self.sync_period:
  189. result = self.sync_up(
  190. local_dir=local_dir, remote_dir=remote_dir, exclude=exclude
  191. )
  192. self.last_sync_up_time = now
  193. return result
  194. def sync_down_if_needed(
  195. self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
  196. ):
  197. """Syncs down if time since last sync down is greater than sync_period.
  198. Args:
  199. remote_dir: Remote directory to sync down from. This is an URI
  200. (``protocol://remote/path``).
  201. local_dir: Local directory to sync to.
  202. exclude: Pattern of files to exclude, e.g.
  203. ``["*/checkpoint_*]`` to exclude trial checkpoints.
  204. """
  205. now = time.time()
  206. if now - self.last_sync_down_time >= self.sync_period:
  207. result = self.sync_down(
  208. remote_dir=remote_dir, local_dir=local_dir, exclude=exclude
  209. )
  210. self.last_sync_down_time = now
  211. return result
  212. def wait_or_retry(self, max_retries: int = 2, backoff_s: int = 5):
  213. assert max_retries > 0
  214. last_error_traceback = None
  215. for i in range(max_retries + 1):
  216. try:
  217. self.wait()
  218. except Exception as e:
  219. attempts_remaining = max_retries - i
  220. # If we're out of retries, then save the full traceback of the last
  221. # error and show it when raising an exception.
  222. if attempts_remaining == 0:
  223. last_error_traceback = traceback.format_exc()
  224. break
  225. logger.error(
  226. f"The latest sync operation failed with the following error: "
  227. f"{repr(e)}\n"
  228. f"Retrying {attempts_remaining} more time(s) after sleeping "
  229. f"for {backoff_s} seconds..."
  230. )
  231. time.sleep(backoff_s)
  232. self.retry()
  233. continue
  234. # Succeeded!
  235. return
  236. raise RuntimeError(
  237. f"Failed sync even after {max_retries} retries. "
  238. f"The latest sync failed with the following error:\n{last_error_traceback}"
  239. )
  240. def reset(self):
  241. self.last_sync_up_time = float("-inf")
  242. self.last_sync_down_time = float("-inf")
  243. def close(self):
  244. pass
  245. def _repr_html_(self) -> str:
  246. return
  247. class _BackgroundSyncer(Syncer):
  248. """Syncer using a background process for asynchronous file transfer."""
  249. def __init__(
  250. self,
  251. sync_period: float = DEFAULT_SYNC_PERIOD,
  252. sync_timeout: float = DEFAULT_SYNC_TIMEOUT,
  253. ):
  254. super(_BackgroundSyncer, self).__init__(
  255. sync_period=sync_period, sync_timeout=sync_timeout
  256. )
  257. self._sync_process = None
  258. self._current_cmd = None
  259. def _should_continue_existing_sync(self):
  260. """Returns whether a previous sync is still running within the timeout."""
  261. return (
  262. self._sync_process
  263. and self._sync_process.is_running
  264. and time.time() - self._sync_process.start_time < self.sync_timeout
  265. )
  266. def _launch_sync_process(self, sync_command: Tuple[Callable, Dict]):
  267. """Waits for the previous sync process to finish,
  268. then launches a new process that runs the given command."""
  269. if self._sync_process:
  270. try:
  271. self.wait()
  272. except Exception:
  273. logger.warning(
  274. f"Last sync command failed with the following error:\n"
  275. f"{traceback.format_exc()}"
  276. )
  277. self._current_cmd = sync_command
  278. self.retry()
  279. def sync_up(
  280. self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
  281. ) -> bool:
  282. if self._should_continue_existing_sync():
  283. logger.debug(
  284. f"Last sync still in progress, "
  285. f"skipping sync up of {local_dir} to {remote_dir}"
  286. )
  287. return False
  288. sync_up_cmd = self._sync_up_command(
  289. local_path=local_dir, uri=remote_dir, exclude=exclude
  290. )
  291. self._launch_sync_process(sync_up_cmd)
  292. return True
  293. def _sync_up_command(
  294. self, local_path: str, uri: str, exclude: Optional[List] = None
  295. ) -> Tuple[Callable, Dict]:
  296. raise NotImplementedError
  297. def sync_down(
  298. self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
  299. ) -> bool:
  300. if self._should_continue_existing_sync():
  301. logger.warning(
  302. f"Last sync still in progress, "
  303. f"skipping sync down of {remote_dir} to {local_dir}"
  304. )
  305. return False
  306. sync_down_cmd = self._sync_down_command(uri=remote_dir, local_path=local_dir)
  307. self._launch_sync_process(sync_down_cmd)
  308. return True
  309. def _sync_down_command(self, uri: str, local_path: str) -> Tuple[Callable, Dict]:
  310. raise NotImplementedError
  311. def delete(self, remote_dir: str) -> bool:
  312. if self._should_continue_existing_sync():
  313. logger.warning(
  314. f"Last sync still in progress, skipping deletion of {remote_dir}"
  315. )
  316. return False
  317. delete_cmd = self._delete_command(uri=remote_dir)
  318. self._launch_sync_process(delete_cmd)
  319. return True
  320. def _delete_command(self, uri: str) -> Tuple[Callable, Dict]:
  321. raise NotImplementedError
  322. def wait(self, timeout: Optional[float] = None):
  323. if self._sync_process:
  324. try:
  325. self._sync_process.wait(timeout=timeout or self.sync_timeout)
  326. except Exception as e:
  327. raise e
  328. finally:
  329. # Regardless of whether the sync process succeeded within the timeout,
  330. # clear the sync process so a new one can be created.
  331. self._sync_process = None
  332. def retry(self):
  333. if not self._current_cmd:
  334. raise RuntimeError("No sync command set, cannot retry.")
  335. cmd, kwargs = self._current_cmd
  336. self._sync_process = _BackgroundProcess(cmd)
  337. self._sync_process.start(**kwargs)
  338. def __getstate__(self):
  339. state = self.__dict__.copy()
  340. state["_sync_process"] = None
  341. return state