beta_sync.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. """Implements `wandb sync` using wandb-core."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import pathlib
  6. import time
  7. from collections.abc import Iterable, Iterator
  8. from itertools import filterfalse
  9. import wandb
  10. from wandb.errors import term
  11. from wandb.proto.wandb_sync_pb2 import ServerSyncResponse
  12. from wandb.sdk import wandb_setup
  13. from wandb.sdk.lib import asyncio_compat, wbauth
  14. from wandb.sdk.lib.printer import Printer, new_printer
  15. from wandb.sdk.lib.progress import progress_printer
  16. from wandb.sdk.lib.service.service_connection import ServiceConnection
  17. from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
  18. _MAX_LIST_LINES = 20
  19. _POLL_WAIT_SECONDS = 0.1
  20. _SLEEP = asyncio.sleep # patched in tests
  21. def sync(
  22. paths: list[pathlib.Path],
  23. *,
  24. live: bool,
  25. entity: str,
  26. project: str,
  27. run_id: str,
  28. job_type: str,
  29. replace_tags: str,
  30. dry_run: bool,
  31. skip_synced: bool,
  32. verbose: bool,
  33. parallelism: int,
  34. ) -> None:
  35. """Replay one or more .wandb files.
  36. Args:
  37. live: Whether to enable 'live' mode, which indefinitely retries reading
  38. incomplete transaction logs.
  39. entity: The entity override for all paths, or an empty string.
  40. project: The project override for all paths, or an empty string.
  41. run_id: The run ID override for all paths, or an empty string.
  42. job_type: An override for the job type for all runs, or an empty string.
  43. replace_tags: A string in the form 'old1=new1,old2=new2' that defines
  44. how to rename run tags.
  45. paths: One or more .wandb files, run directories containing
  46. .wandb files, and wandb directories containing run directories.
  47. dry_run: If true, just prints what it would do and exits.
  48. skip_synced: If true, skips files that have already been synced
  49. as indicated by a .wandb.synced marker file in the same directory.
  50. verbose: Verbose mode for printing more info.
  51. parallelism: Max number of runs to sync at a time.
  52. """
  53. tag_replacements = _parse_replace_tags(replace_tags)
  54. singleton = wandb_setup.singleton()
  55. try:
  56. cwd = pathlib.Path.cwd()
  57. except OSError:
  58. cwd = None
  59. ask_for_confirmation = False
  60. if not paths:
  61. paths = [pathlib.Path(singleton.settings.wandb_dir)]
  62. ask_for_confirmation = True
  63. wandb_files = _to_unique_files(
  64. (
  65. wandb_file
  66. for path in paths
  67. for wandb_file in _find_wandb_files(path, skip_synced=skip_synced)
  68. ),
  69. verbose=verbose,
  70. )
  71. if not wandb_files:
  72. term.termlog("No runs to sync.")
  73. return
  74. if dry_run:
  75. term.termlog(f"Would sync {len(wandb_files)} run(s):")
  76. _print_sorted_paths(wandb_files, verbose=verbose, root=cwd)
  77. return
  78. term.termlog(f"Syncing {len(wandb_files)} run(s):")
  79. _print_sorted_paths(wandb_files, verbose=verbose, root=cwd)
  80. if ask_for_confirmation and not term.confirm("Sync the listed runs?"):
  81. return
  82. # Authenticate the session. This updates the singleton settings credentials.
  83. if not wbauth.authenticate_session(
  84. host=singleton.settings.base_url,
  85. source="wandb sync",
  86. no_offline=True,
  87. ):
  88. term.termlog("Not authenticated.")
  89. return
  90. service = singleton.ensure_service()
  91. printer = new_printer()
  92. singleton.asyncer.run(
  93. lambda: _do_sync(
  94. wandb_files,
  95. cwd=cwd,
  96. live=live,
  97. service=service,
  98. entity=entity,
  99. project=project,
  100. run_id=run_id,
  101. job_type=job_type,
  102. tag_replacements=tag_replacements,
  103. settings=singleton.settings,
  104. printer=printer,
  105. parallelism=parallelism,
  106. )
  107. )
  108. def _parse_replace_tags(replace_tags: str) -> dict[str, str]:
  109. """Parse the --replace-tags argument to wandb sync."""
  110. if not replace_tags:
  111. return {}
  112. tag_replacements: dict[str, str] = {}
  113. for pair in replace_tags.split(","):
  114. if "=" not in pair:
  115. raise ValueError(
  116. f"Invalid --replace-tags format: {pair}. Expected 'old=new'."
  117. )
  118. old_tag, new_tag = pair.split("=", 1)
  119. tag_replacements[old_tag.strip()] = new_tag.strip()
  120. return tag_replacements
  121. def _to_unique_files(
  122. paths: Iterator[pathlib.Path],
  123. *,
  124. verbose: bool,
  125. ) -> set[pathlib.Path]:
  126. """Returns paths with duplicates removed.
  127. Determines file equality the same way as os.path.samefile().
  128. """
  129. id_to_path: dict[tuple[int, int], pathlib.Path] = dict()
  130. # Sort in reverse so that the last path written to the map is
  131. # alphabetically earliest.
  132. for path in sorted(paths, reverse=True):
  133. try:
  134. stat = path.stat()
  135. except OSError as e:
  136. term.termerror(f"Failed to stat {path}: {e}")
  137. continue
  138. id = (stat.st_ino, stat.st_dev)
  139. if verbose and (other_path := id_to_path.get(id)):
  140. term.termlog(f"{path} is the same as {other_path}")
  141. id_to_path[id] = path
  142. return set(id_to_path.values())
  143. async def _do_sync(
  144. wandb_files: set[pathlib.Path],
  145. *,
  146. cwd: pathlib.Path | None,
  147. live: bool,
  148. service: ServiceConnection,
  149. entity: str,
  150. project: str,
  151. run_id: str,
  152. job_type: str,
  153. tag_replacements: dict[str, str],
  154. settings: wandb.Settings,
  155. printer: Printer,
  156. parallelism: int,
  157. ) -> None:
  158. """Sync the specified files.
  159. This is factored out to make the progress animation testable.
  160. """
  161. init_handle = await service.init_sync(
  162. wandb_files,
  163. settings,
  164. cwd=cwd,
  165. live=live,
  166. entity=entity,
  167. project=project,
  168. run_id=run_id,
  169. job_type=job_type,
  170. tag_replacements=tag_replacements,
  171. )
  172. init_result = await init_handle.wait_async(timeout=5)
  173. sync_handle = await service.sync(init_result.id, parallelism=parallelism)
  174. await _SyncStatusLoop(
  175. init_result.id,
  176. service,
  177. printer,
  178. ).wait_with_progress(sync_handle)
  179. class _SyncStatusLoop:
  180. """Displays a sync operation's status until it completes."""
  181. def __init__(
  182. self,
  183. id: str,
  184. service: ServiceConnection,
  185. printer: Printer,
  186. ) -> None:
  187. self._id = id
  188. self._service = service
  189. self._printer = printer
  190. self._rate_limit_last_time: float | None = None
  191. self._done = asyncio.Event()
  192. async def wait_with_progress(
  193. self,
  194. handle: MailboxHandle[ServerSyncResponse],
  195. ) -> None:
  196. """Display status updates until the handle completes."""
  197. async with asyncio_compat.open_task_group() as group:
  198. group.start_soon(self._wait_then_mark_done(handle))
  199. group.start_soon(self._show_progress_until_done())
  200. async def _wait_then_mark_done(
  201. self,
  202. handle: MailboxHandle[ServerSyncResponse],
  203. ) -> None:
  204. response = await handle.wait_async(timeout=None)
  205. for msg in response.messages:
  206. self._printer.display(msg.content, level=msg.severity)
  207. self._done.set()
  208. async def _show_progress_until_done(self) -> None:
  209. """Show rate-limited status updates until _done is set."""
  210. with progress_printer(self._printer, "Syncing...") as progress:
  211. while not await self._rate_limit_check_done():
  212. handle = await self._service.sync_status(self._id)
  213. response = await handle.wait_async(timeout=None)
  214. for msg in response.new_messages:
  215. self._printer.display(msg.content, level=msg.severity)
  216. progress.update(list(response.stats))
  217. async def _rate_limit_check_done(self) -> bool:
  218. """Wait for rate limit and return whether _done is set."""
  219. now = time.monotonic()
  220. last_time = self._rate_limit_last_time
  221. self._rate_limit_last_time = now
  222. if last_time and (time_since_last := now - last_time) < _POLL_WAIT_SECONDS:
  223. await asyncio_compat.race(
  224. _SLEEP(_POLL_WAIT_SECONDS - time_since_last),
  225. self._done.wait(),
  226. )
  227. return self._done.is_set()
  228. def _find_wandb_files(
  229. path: pathlib.Path,
  230. *,
  231. skip_synced: bool,
  232. ) -> Iterator[pathlib.Path]:
  233. """Returns paths to the .wandb files to sync."""
  234. if skip_synced:
  235. yield from filterfalse(_is_synced, _expand_wandb_files(path))
  236. else:
  237. yield from _expand_wandb_files(path)
  238. def _expand_wandb_files(
  239. path: pathlib.Path,
  240. ) -> Iterator[pathlib.Path]:
  241. """Iterate over .wandb files selected by the path."""
  242. if path.suffix == ".wandb":
  243. yield path
  244. return
  245. files_in_run_directory = path.glob("*.wandb")
  246. try:
  247. first_file = next(files_in_run_directory)
  248. except StopIteration:
  249. pass
  250. else:
  251. yield first_file
  252. yield from files_in_run_directory
  253. return
  254. yield from path.glob("*/*.wandb")
  255. def _is_synced(path: pathlib.Path) -> bool:
  256. """Returns whether the .wandb file is synced."""
  257. return path.with_suffix(".wandb.synced").exists()
  258. def _print_sorted_paths(
  259. paths: Iterable[pathlib.Path],
  260. verbose: bool,
  261. *,
  262. root: pathlib.Path | None,
  263. ) -> None:
  264. """Print file paths, sorting them and truncating the list if needed.
  265. Args:
  266. paths: Paths to print. Must be absolute with symlinks resolved.
  267. verbose: If true, doesn't truncate paths.
  268. root: A root directory for making paths relative.
  269. """
  270. # Prefer to print paths relative to the current working directory.
  271. formatted_paths: list[str] = []
  272. for path in paths:
  273. formatted_path = str(path)
  274. if root:
  275. with contextlib.suppress(ValueError):
  276. formatted_path = str(path.relative_to(root))
  277. formatted_paths.append(formatted_path)
  278. sorted_paths = sorted(formatted_paths)
  279. max_lines = len(sorted_paths) if verbose else _MAX_LIST_LINES
  280. for i in range(min(len(sorted_paths), max_lines)):
  281. term.termlog(f" {sorted_paths[i]}")
  282. if len(sorted_paths) > max_lines:
  283. remaining = len(sorted_paths) - max_lines
  284. term.termlog(f" +{remaining:,d} more (pass --verbose to see all)")