artifact_download_logger.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. """Artifact download logger."""
  2. from __future__ import annotations
  3. import multiprocessing.dummy
  4. import time
  5. from typing import Callable
  6. from wandb.errors.term import termlog
  7. class ArtifactDownloadLogger:
  8. def __init__(
  9. self,
  10. nfiles: int,
  11. clock_for_testing: Callable[[], float] = time.monotonic,
  12. termlog_for_testing: Callable[..., None] = termlog,
  13. ) -> None:
  14. self._nfiles = nfiles
  15. self._clock = clock_for_testing
  16. self._termlog = termlog_for_testing
  17. self._n_files_downloaded = 0
  18. self._spinner_index = 0
  19. self._last_log_time = self._clock()
  20. self._lock = multiprocessing.dummy.Lock()
  21. def notify_downloaded(self) -> None:
  22. with self._lock:
  23. self._n_files_downloaded += 1
  24. if self._n_files_downloaded == self._nfiles:
  25. self._termlog(
  26. f" {self._nfiles} of {self._nfiles} files downloaded. ",
  27. # ^ trailing spaces to wipe out ellipsis from previous logs
  28. newline=True,
  29. )
  30. self._last_log_time = self._clock()
  31. elif self._clock() - self._last_log_time > 0.1:
  32. self._spinner_index += 1
  33. spinner = r"-\|/"[self._spinner_index % 4]
  34. self._termlog(
  35. f"{spinner} {self._n_files_downloaded} of {self._nfiles} files downloaded...\r",
  36. newline=False,
  37. )
  38. self._last_log_time = self._clock()