| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347 |
- """Implements `wandb sync` using wandb-core."""
- from __future__ import annotations
- import asyncio
- import contextlib
- import pathlib
- import time
- from collections.abc import Iterable, Iterator
- from itertools import filterfalse
- import wandb
- from wandb.errors import term
- from wandb.proto.wandb_sync_pb2 import ServerSyncResponse
- from wandb.sdk import wandb_setup
- from wandb.sdk.lib import asyncio_compat, wbauth
- from wandb.sdk.lib.printer import Printer, new_printer
- from wandb.sdk.lib.progress import progress_printer
- from wandb.sdk.lib.service.service_connection import ServiceConnection
- from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
# Number of paths listed by _print_sorted_paths before truncating (unless verbose).
_MAX_LIST_LINES: int = 20

# Minimum number of seconds between sync status polls.
_POLL_WAIT_SECONDS: float = 0.1

# Sleep used by the status poll loop; kept as a module attribute so that
# tests can patch it.
_SLEEP = asyncio.sleep
def sync(
    paths: list[pathlib.Path],
    *,
    live: bool,
    entity: str,
    project: str,
    run_id: str,
    job_type: str,
    replace_tags: str,
    dry_run: bool,
    skip_synced: bool,
    verbose: bool,
    parallelism: int,
) -> None:
    """Replay one or more .wandb files.

    Args:
        paths: One or more .wandb files, run directories containing
            .wandb files, and wandb directories containing run directories.
            If empty, the configured wandb directory is used and the user
            is asked to confirm before anything is synced.
        live: Whether to enable 'live' mode, which indefinitely retries reading
            incomplete transaction logs.
        entity: The entity override for all paths, or an empty string.
        project: The project override for all paths, or an empty string.
        run_id: The run ID override for all paths, or an empty string.
        job_type: An override for the job type for all runs, or an empty string.
        replace_tags: A string in the form 'old1=new1,old2=new2' that defines
            how to rename run tags.
        dry_run: If true, just prints what it would do and exits.
        skip_synced: If true, skips files that have already been synced
            as indicated by a .wandb.synced marker file in the same directory.
        verbose: Verbose mode for printing more info.
        parallelism: Max number of runs to sync at a time.
    """
    # Parse tag overrides first so malformed input fails before any work.
    tag_replacements = _parse_replace_tags(replace_tags)
    singleton = wandb_setup.singleton()

    # If the working directory cannot be determined, proceed without one.
    try:
        cwd = pathlib.Path.cwd()
    except OSError:
        cwd = None

    # The user didn't pick the files explicitly, so ask before syncing.
    ask_for_confirmation = not paths
    if ask_for_confirmation:
        paths = [pathlib.Path(singleton.settings.wandb_dir)]

    candidates = (
        wandb_file
        for path in paths
        for wandb_file in _find_wandb_files(path, skip_synced=skip_synced)
    )
    wandb_files = _to_unique_files(candidates, verbose=verbose)

    if not wandb_files:
        term.termlog("No runs to sync.")
        return

    run_count = len(wandb_files)
    if dry_run:
        term.termlog(f"Would sync {run_count} run(s):")
        _print_sorted_paths(wandb_files, verbose=verbose, root=cwd)
        return

    term.termlog(f"Syncing {run_count} run(s):")
    _print_sorted_paths(wandb_files, verbose=verbose, root=cwd)
    if ask_for_confirmation and not term.confirm("Sync the listed runs?"):
        return

    # Authenticate the session. This updates the singleton settings credentials.
    authenticated = wbauth.authenticate_session(
        host=singleton.settings.base_url,
        source="wandb sync",
        no_offline=True,
    )
    if not authenticated:
        term.termlog("Not authenticated.")
        return

    service = singleton.ensure_service()
    printer = new_printer()

    def start_sync():
        # Settings are read when the runner invokes this, matching the
        # original lambda's late evaluation.
        return _do_sync(
            wandb_files,
            cwd=cwd,
            live=live,
            service=service,
            entity=entity,
            project=project,
            run_id=run_id,
            job_type=job_type,
            tag_replacements=tag_replacements,
            settings=singleton.settings,
            printer=printer,
            parallelism=parallelism,
        )

    singleton.asyncer.run(start_sync)
def _parse_replace_tags(replace_tags: str) -> dict[str, str]:
    """Parse the --replace-tags argument to wandb sync.

    Args:
        replace_tags: Comma-separated 'old=new' pairs, e.g. 'a=b,c=d'.
            Whitespace around each tag is stripped. Blank segments, such as
            those produced by a trailing or doubled comma, are ignored.
            Only the first '=' in a pair separates old from new, so
            'x=y=z' maps 'x' to 'y=z'.

    Returns:
        A mapping from each old tag to its replacement. It is empty when
        replace_tags is empty. A later pair overrides an earlier pair with
        the same old tag.

    Raises:
        ValueError: If a non-blank segment contains no '='.
    """
    if not replace_tags:
        return {}

    tag_replacements: dict[str, str] = {}
    for pair in replace_tags.split(","):
        if not pair.strip():
            # Tolerate inputs like 'a=b,' or 'a=b,,c=d' instead of failing
            # with a confusing "Invalid --replace-tags format: ." error.
            continue

        old_tag, separator, new_tag = pair.partition("=")
        if not separator:
            raise ValueError(
                f"Invalid --replace-tags format: {pair}. Expected 'old=new'."
            )
        tag_replacements[old_tag.strip()] = new_tag.strip()

    return tag_replacements
def _to_unique_files(
    paths: Iterator[pathlib.Path],
    *,
    verbose: bool,
) -> set[pathlib.Path]:
    """Drop paths that refer to the same underlying file.

    Two paths name the same file when their (inode, device) pairs match,
    which is the same test os.path.samefile() uses. Among duplicates, the
    alphabetically earliest path is kept. Paths that cannot be stat'ed are
    reported as errors and skipped.

    Args:
        paths: The candidate paths; consumed fully.
        verbose: If true, log each path that duplicates another.

    Returns:
        One path per distinct file.
    """
    chosen_by_file_id: dict[tuple[int, int], pathlib.Path] = {}

    # Visit paths from alphabetically last to first. Later assignments
    # overwrite earlier ones, so the earliest path wins for each file.
    for candidate in sorted(paths, reverse=True):
        try:
            info = candidate.stat()
        except OSError as e:
            term.termerror(f"Failed to stat {candidate}: {e}")
            continue

        file_id = (info.st_ino, info.st_dev)
        previous = chosen_by_file_id.get(file_id)
        if verbose and previous:
            term.termlog(f"{candidate} is the same as {previous}")
        chosen_by_file_id[file_id] = candidate

    return set(chosen_by_file_id.values())
async def _do_sync(
    wandb_files: set[pathlib.Path],
    *,
    cwd: pathlib.Path | None,
    live: bool,
    service: ServiceConnection,
    entity: str,
    project: str,
    run_id: str,
    job_type: str,
    tag_replacements: dict[str, str],
    settings: wandb.Settings,
    printer: Printer,
    parallelism: int,
) -> None:
    """Sync the given files through the service and show progress.

    The steps are:
    1. Register the files with the service via init_sync.
    2. Wait up to 5 seconds for the service to return an operation ID.
    3. Start the sync.
    4. Display status until the sync finishes.

    Kept separate from sync() so the progress animation can be tested.
    """
    init_handle = await service.init_sync(
        wandb_files,
        settings,
        cwd=cwd,
        live=live,
        entity=entity,
        project=project,
        run_id=run_id,
        job_type=job_type,
        tag_replacements=tag_replacements,
    )
    init_result = await init_handle.wait_async(timeout=5)
    operation_id = init_result.id

    sync_handle = await service.sync(operation_id, parallelism=parallelism)
    status_loop = _SyncStatusLoop(operation_id, service, printer)
    await status_loop.wait_with_progress(sync_handle)
class _SyncStatusLoop:
    """Displays a sync operation's status until it completes."""

    def __init__(
        self,
        id: str,
        service: ServiceConnection,
        printer: Printer,
    ) -> None:
        """Prepare to track one sync operation.

        Args:
            id: The sync operation's ID, passed to service.sync_status.
            service: The connection used to poll for status.
            printer: Where status messages and progress are displayed.
        """
        self._id = id
        self._service = service
        self._printer = printer
        # time.monotonic() value recorded when the last poll was allowed to
        # proceed; None until the first check.
        self._rate_limit_last_time: float | None = None
        # Set once the operation's final response has been handled.
        self._done = asyncio.Event()

    async def wait_with_progress(
        self,
        handle: MailboxHandle[ServerSyncResponse],
    ) -> None:
        """Display status updates until the handle completes."""
        async with asyncio_compat.open_task_group() as group:
            group.start_soon(self._wait_then_mark_done(handle))
            group.start_soon(self._show_progress_until_done())

    async def _wait_then_mark_done(
        self,
        handle: MailboxHandle[ServerSyncResponse],
    ) -> None:
        """Wait for the final response, print its messages, then set _done.

        _done is set even if waiting or printing raises. This lets the
        progress loop stop at its next check instead of polling forever.
        """
        try:
            response = await handle.wait_async(timeout=None)
            for msg in response.messages:
                self._printer.display(msg.content, level=msg.severity)
        finally:
            self._done.set()

    async def _show_progress_until_done(self) -> None:
        """Show rate-limited status updates until _done is set."""
        with progress_printer(self._printer, "Syncing...") as progress:
            while not await self._rate_limit_check_done():
                handle = await self._service.sync_status(self._id)
                response = await handle.wait_async(timeout=None)
                for msg in response.new_messages:
                    self._printer.display(msg.content, level=msg.severity)
                progress.update(list(response.stats))

    async def _rate_limit_check_done(self) -> bool:
        """Wait out the rest of the poll interval and return whether _done is set.

        No wait happens on the first call. The wait ends early if _done is
        set meanwhile.
        """
        last_time = self._rate_limit_last_time
        if last_time is not None:
            elapsed = time.monotonic() - last_time
            if elapsed < _POLL_WAIT_SECONDS:
                await asyncio_compat.race(
                    _SLEEP(_POLL_WAIT_SECONDS - elapsed),
                    self._done.wait(),
                )

        # Record the time *after* waiting. If it were recorded before, the
        # next interval would include this wait, and every other poll would
        # skip the rate limit entirely.
        self._rate_limit_last_time = time.monotonic()
        return self._done.is_set()
def _find_wandb_files(
    path: pathlib.Path,
    *,
    skip_synced: bool,
) -> Iterator[pathlib.Path]:
    """Yield the .wandb files selected by path that should be synced.

    When skip_synced is true, files for which _is_synced() returns true
    are left out.
    """
    candidates = _expand_wandb_files(path)
    if skip_synced:
        candidates = filterfalse(_is_synced, candidates)
    yield from candidates
def _expand_wandb_files(
    path: pathlib.Path,
) -> Iterator[pathlib.Path]:
    """Yield the .wandb files that path refers to.

    The first matching rule applies:
    - a path ending in .wandb is yielded as-is, without checking that it exists;
    - otherwise, the .wandb files directly inside path (a run directory);
    - otherwise, the .wandb files one level down (a wandb directory whose
      subdirectories are runs).
    """
    if path.suffix == ".wandb":
        yield path
        return

    direct_matches = path.glob("*.wandb")
    first_match = next(direct_matches, None)
    if first_match is not None:
        yield first_match
        yield from direct_matches
    else:
        yield from path.glob("*/*.wandb")
def _is_synced(path: pathlib.Path) -> bool:
    """Report whether a .wandb file has already been synced.

    A file counts as synced when a marker exists at the same path with its
    suffix swapped for .wandb.synced (e.g. run.wandb -> run.wandb.synced).
    """
    marker = path.with_suffix(".wandb.synced")
    return marker.exists()
def _print_sorted_paths(
    paths: Iterable[pathlib.Path],
    verbose: bool,
    *,
    root: pathlib.Path | None,
) -> None:
    """Print file paths in sorted order, truncating the list unless verbose.

    Args:
        paths: Paths to print. Each is shown relative to root when it lies
            under root. The comparison is lexical (PurePath.relative_to), so
            symlinks are not resolved. Other paths are shown unchanged.
        verbose: If true, print every path. Otherwise print at most
            _MAX_LIST_LINES paths followed by a count of the rest.
        root: A directory to make paths relative to, or None to print
            paths as given.
    """
    formatted_paths: list[str] = []
    for path in paths:
        formatted = str(path)
        if root is not None:
            # relative_to raises ValueError for paths outside root; those
            # keep their original form.
            with contextlib.suppress(ValueError):
                formatted = str(path.relative_to(root))
        formatted_paths.append(formatted)
    formatted_paths.sort()

    shown = formatted_paths if verbose else formatted_paths[:_MAX_LIST_LINES]
    for formatted in shown:
        term.termlog(f" {formatted}")

    hidden = len(formatted_paths) - len(shown)
    if hidden > 0:
        term.termlog(f" +{hidden:,d} more (pass --verbose to see all)")
|