file_pusher.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. from __future__ import annotations
  2. import concurrent.futures
  3. import logging
  4. import os
  5. import queue
  6. import tempfile
  7. import threading
  8. import time
  9. from typing import TYPE_CHECKING
  10. import wandb
  11. import wandb.util
  12. from wandb.filesync import stats, step_checksum, step_upload
  13. from wandb.sdk.lib.paths import LogicalPath
  14. if TYPE_CHECKING:
  15. from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
  16. from wandb.sdk.artifacts.artifact_saver import SaveFn
  17. from wandb.sdk.internal import file_stream, internal_api
  18. from wandb.sdk.internal.settings_static import SettingsStatic
  19. logger = logging.getLogger(__name__)
  20. class FilePusher:
  21. """Parallel file upload class.
  22. This manages uploading multiple files in parallel. It will restart a given file's
  23. upload job if it receives a notification that that file has been modified. The
  24. finish() method will block until all events have been processed and all uploads are
  25. complete.
  26. """
  27. MAX_UPLOAD_JOBS = 64
  28. def __init__(
  29. self,
  30. api: internal_api.Api,
  31. file_stream: file_stream.FileStreamApi,
  32. settings: SettingsStatic | None = None,
  33. ) -> None:
  34. self._api = api
  35. # Temporary directory for copies we make of some file types to
  36. # reduce the probability that the file gets changed while we're
  37. # uploading it.
  38. self._tempdir = tempfile.TemporaryDirectory("wandb")
  39. self._stats = stats.Stats()
  40. self._incoming_queue: queue.Queue[step_checksum.Event] = queue.Queue()
  41. self._event_queue: queue.Queue[step_upload.Event] = queue.Queue()
  42. self._step_checksum = step_checksum.StepChecksum(
  43. self._api,
  44. self._tempdir,
  45. self._incoming_queue,
  46. self._event_queue,
  47. self._stats,
  48. )
  49. self._step_checksum.start()
  50. self._step_upload = step_upload.StepUpload(
  51. self._api,
  52. self._stats,
  53. self._event_queue,
  54. self.MAX_UPLOAD_JOBS,
  55. file_stream=file_stream,
  56. settings=settings,
  57. )
  58. self._step_upload.start()
  59. self._stats_thread_stop = threading.Event()
  60. if os.environ.get("WANDB_DEBUG"):
  61. # debug thread to monitor and report file pusher stats
  62. self._stats_thread = threading.Thread(
  63. target=self._file_pusher_stats,
  64. daemon=True,
  65. name="FPStatsThread",
  66. )
  67. self._stats_thread.start()
  68. def _file_pusher_stats(self) -> None:
  69. while not self._stats_thread_stop.is_set():
  70. logger.info(f"FilePusher stats: {self._stats._stats}")
  71. time.sleep(1)
  72. def get_status(self) -> tuple[bool, stats.Summary]:
  73. running = self.is_alive()
  74. summary = self._stats.summary()
  75. return running, summary
  76. def print_status(self, prefix: bool = True) -> None:
  77. step = 0
  78. spinner_states = ["-", "\\", "|", "/"]
  79. stop = False
  80. while True:
  81. if not self.is_alive():
  82. stop = True
  83. summary = self._stats.summary()
  84. line = f" {summary.uploaded_bytes / 1048576.0:.2f}MB of {summary.total_bytes / 1048576.0:.2f}MB uploaded ({summary.deduped_bytes / 1048576.0:.2f}MB deduped)\r"
  85. line = spinner_states[step % 4] + line
  86. step += 1
  87. wandb.termlog(line, newline=False, prefix=prefix)
  88. if stop:
  89. break
  90. time.sleep(0.25)
  91. dedupe_fraction = (
  92. summary.deduped_bytes / float(summary.total_bytes)
  93. if summary.total_bytes > 0
  94. else 0
  95. )
  96. if dedupe_fraction > 0.01:
  97. wandb.termlog(
  98. "W&B sync reduced upload amount by %.1f%% "
  99. % (dedupe_fraction * 100),
  100. prefix=prefix,
  101. )
  102. # clear progress line.
  103. wandb.termlog(" " * 79, prefix=prefix)
  104. def file_counts_by_category(self) -> stats.FileCountsByCategory:
  105. return self._stats.file_counts_by_category()
  106. def file_changed(self, save_name: LogicalPath, path: str, copy: bool = True):
  107. """Tell the file pusher that a file's changed and should be uploaded.
  108. Args:
  109. save_name: string logical location of the file relative to the run
  110. directory.
  111. path: actual string path of the file to upload on the filesystem.
  112. """
  113. # Tests in linux were failing because wandb-events.jsonl didn't exist
  114. if not os.path.exists(path) or not os.path.isfile(path):
  115. return
  116. if os.path.getsize(path) == 0:
  117. return
  118. event = step_checksum.RequestUpload(path, save_name, copy)
  119. self._incoming_queue.put(event)
  120. def store_manifest_files(
  121. self,
  122. manifest: ArtifactManifest,
  123. artifact_id: str,
  124. save_fn: SaveFn,
  125. ) -> None:
  126. event = step_checksum.RequestStoreManifestFiles(manifest, artifact_id, save_fn)
  127. self._incoming_queue.put(event)
  128. def commit_artifact(
  129. self,
  130. artifact_id: str,
  131. *,
  132. finalize: bool = True,
  133. before_commit: step_upload.PreCommitFn,
  134. result_future: concurrent.futures.Future[None],
  135. ):
  136. event = step_checksum.RequestCommitArtifact(
  137. artifact_id, finalize, before_commit, result_future
  138. )
  139. self._incoming_queue.put(event)
  140. def finish(self, callback: step_upload.OnRequestFinishFn | None = None):
  141. logger.info("shutting down file pusher")
  142. self._incoming_queue.put(step_checksum.RequestFinish(callback))
  143. self._stats_thread_stop.set()
  144. def join(self) -> None:
  145. # NOTE: must have called finish before join
  146. logger.info("waiting for file pusher")
  147. while self.is_alive():
  148. time.sleep(0.5)
  149. self._tempdir.cleanup()
  150. def is_alive(self) -> bool:
  151. return self._step_checksum.is_alive() or self._step_upload.is_alive()