step_upload.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. """Batching file prepare requests to our API."""
  2. from __future__ import annotations
  3. import concurrent.futures
  4. import logging
  5. import queue
  6. import sys
  7. import threading
  8. from collections.abc import MutableMapping, MutableSequence, MutableSet
  9. from typing import TYPE_CHECKING, Callable, NamedTuple, Union
  10. from wandb.errors.term import termerror
  11. from wandb.filesync import upload_job
  12. from wandb.sdk.lib.paths import LogicalPath
  13. if TYPE_CHECKING:
  14. from typing import TypedDict
  15. from wandb.filesync import stats
  16. from wandb.sdk.internal import file_stream, internal_api, progress
  17. from wandb.sdk.internal.settings_static import SettingsStatic
  18. class ArtifactStatus(TypedDict):
  19. finalize: bool
  20. pending_count: int
  21. commit_requested: bool
  22. pre_commit_callbacks: MutableSet[PreCommitFn]
  23. result_futures: MutableSet[concurrent.futures.Future[None]]
  24. PreCommitFn = Callable[[], None]
  25. OnRequestFinishFn = Callable[[], None]
  26. SaveFn = Callable[["progress.ProgressFn"], bool]
  27. logger = logging.getLogger(__name__)
  28. class RequestUpload(NamedTuple):
  29. path: str
  30. save_name: LogicalPath
  31. artifact_id: str | None
  32. md5: str | None
  33. copied: bool
  34. save_fn: SaveFn | None
  35. digest: str | None
  36. class RequestCommitArtifact(NamedTuple):
  37. artifact_id: str
  38. finalize: bool
  39. before_commit: PreCommitFn
  40. result_future: concurrent.futures.Future[None]
  41. class RequestFinish(NamedTuple):
  42. callback: OnRequestFinishFn | None
  43. class EventJobDone(NamedTuple):
  44. job: RequestUpload
  45. exc: BaseException | None
  46. Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
  47. class StepUpload:
  48. def __init__(
  49. self,
  50. api: internal_api.Api,
  51. stats: stats.Stats,
  52. event_queue: queue.Queue[Event],
  53. max_threads: int,
  54. file_stream: file_stream.FileStreamApi,
  55. settings: SettingsStatic | None = None,
  56. ) -> None:
  57. self._api = api
  58. self._stats = stats
  59. self._event_queue = event_queue
  60. self._file_stream = file_stream
  61. self._thread = threading.Thread(target=self._thread_body)
  62. self._thread.daemon = True
  63. self._pool = concurrent.futures.ThreadPoolExecutor(
  64. thread_name_prefix="wandb-upload",
  65. max_workers=max_threads,
  66. )
  67. # Indexed by files' `save_name`'s, which are their ID's in the Run.
  68. self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
  69. self._pending_jobs: MutableSequence[RequestUpload] = []
  70. self._artifacts: MutableMapping[str, ArtifactStatus] = {}
  71. self.silent = bool(settings.silent) if settings else False
  72. def _thread_body(self) -> None:
  73. event: Event | None
  74. # Wait for event in the queue, and process one by one until a
  75. # finish event is received
  76. finish_callback = None
  77. while True:
  78. event = self._event_queue.get()
  79. if isinstance(event, RequestFinish):
  80. finish_callback = event.callback
  81. break
  82. self._handle_event(event)
  83. # We've received a finish event. At this point, further Upload requests
  84. # are invalid.
  85. # After a finish event is received, iterate through the event queue
  86. # one by one and process all remaining events.
  87. while True:
  88. try:
  89. event = self._event_queue.get(True, 0.2)
  90. except queue.Empty:
  91. event = None
  92. if event:
  93. self._handle_event(event)
  94. elif not self._running_jobs:
  95. # Queue was empty and no jobs left.
  96. self._pool.shutdown(wait=False)
  97. if finish_callback:
  98. finish_callback()
  99. break
  100. def _handle_event(self, event: Event) -> None:
  101. if isinstance(event, EventJobDone):
  102. job = event.job
  103. if event.exc is not None:
  104. logger.exception(
  105. "Failed to upload file: %s", job.path, exc_info=event.exc
  106. )
  107. if job.artifact_id:
  108. if event.exc is None:
  109. self._artifacts[job.artifact_id]["pending_count"] -= 1
  110. self._maybe_commit_artifact(job.artifact_id)
  111. else:
  112. if not self.silent:
  113. termerror(
  114. "Uploading artifact file failed. Artifact won't be committed."
  115. )
  116. self._fail_artifact_futures(job.artifact_id, event.exc)
  117. self._running_jobs.pop(job.save_name)
  118. # If we have any pending jobs, start one now
  119. if self._pending_jobs:
  120. event = self._pending_jobs.pop(0)
  121. self._start_upload_job(event)
  122. elif isinstance(event, RequestCommitArtifact):
  123. if event.artifact_id not in self._artifacts:
  124. self._init_artifact(event.artifact_id)
  125. self._artifacts[event.artifact_id]["commit_requested"] = True
  126. self._artifacts[event.artifact_id]["finalize"] = event.finalize
  127. self._artifacts[event.artifact_id]["pre_commit_callbacks"].add(
  128. event.before_commit
  129. )
  130. self._artifacts[event.artifact_id]["result_futures"].add(
  131. event.result_future
  132. )
  133. self._maybe_commit_artifact(event.artifact_id)
  134. elif isinstance(event, RequestUpload):
  135. if event.artifact_id is not None:
  136. if event.artifact_id not in self._artifacts:
  137. self._init_artifact(event.artifact_id)
  138. self._artifacts[event.artifact_id]["pending_count"] += 1
  139. self._start_upload_job(event)
  140. else:
  141. raise TypeError(f"Event has unexpected type: {event!s}")
  142. def _start_upload_job(self, event: RequestUpload) -> None:
  143. # Operations on a single backend file must be serialized. if
  144. # we're already uploading this file, put the event on the
  145. # end of the queue
  146. if event.save_name in self._running_jobs:
  147. self._pending_jobs.append(event)
  148. return
  149. self._spawn_upload(event)
  150. def _spawn_upload(self, event: RequestUpload) -> None:
  151. """Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.
  152. Context: it's important that, whenever we add an entry to `self._running_jobs`,
  153. we ensure that a corresponding `EventJobDone` message will eventually get handled;
  154. otherwise, the `_running_jobs` entry will never get removed, and the StepUpload
  155. will never shut down.
  156. The sole purpose of this function is to make sure that the code that adds an entry
  157. to `self._running_jobs` is textually right next to the code that eventually enqueues
  158. the `EventJobDone` message. This should help keep them in sync.
  159. """
  160. # Adding the entry to `self._running_jobs` MUST happen in the main thread,
  161. # NOT in the job that gets submitted to the thread-pool, to guard against
  162. # this sequence of events:
  163. # - StepUpload receives a RequestUpload
  164. # ...and therefore spawns a thread to do the upload
  165. # - StepUpload receives a RequestFinish
  166. # ...and checks `self._running_jobs` to see if there are any tasks to wait for...
  167. # ...and there are none, because the addition to `self._running_jobs` happens in
  168. # the background thread, which the scheduler hasn't yet run...
  169. # ...so the StepUpload shuts down. Even though we haven't uploaded the file!
  170. #
  171. # This would be very bad!
  172. # So, this line has to happen _outside_ the `pool.submit()`.
  173. self._running_jobs[event.save_name] = event
  174. def run_and_notify() -> None:
  175. try:
  176. self._do_upload(event)
  177. finally:
  178. self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
  179. self._pool.submit(run_and_notify)
  180. def _do_upload(self, event: RequestUpload) -> None:
  181. job = upload_job.UploadJob(
  182. self._stats,
  183. self._api,
  184. self._file_stream,
  185. self.silent,
  186. event.save_name,
  187. event.path,
  188. event.artifact_id,
  189. event.md5,
  190. event.copied,
  191. event.save_fn,
  192. event.digest,
  193. )
  194. job.run()
  195. def _init_artifact(self, artifact_id: str) -> None:
  196. self._artifacts[artifact_id] = {
  197. "finalize": False,
  198. "pending_count": 0,
  199. "commit_requested": False,
  200. "pre_commit_callbacks": set(),
  201. "result_futures": set(),
  202. }
  203. def _maybe_commit_artifact(self, artifact_id: str) -> None:
  204. artifact_status = self._artifacts[artifact_id]
  205. if (
  206. artifact_status["pending_count"] == 0
  207. and artifact_status["commit_requested"]
  208. ):
  209. try:
  210. for pre_callback in artifact_status["pre_commit_callbacks"]:
  211. pre_callback()
  212. if artifact_status["finalize"]:
  213. self._api.commit_artifact(artifact_id)
  214. except Exception as exc:
  215. termerror(
  216. f"Committing artifact failed. Artifact {artifact_id} won't be finalized."
  217. )
  218. termerror(str(exc))
  219. self._fail_artifact_futures(artifact_id, exc)
  220. else:
  221. self._resolve_artifact_futures(artifact_id)
  222. def _fail_artifact_futures(self, artifact_id: str, exc: BaseException) -> None:
  223. futures = self._artifacts[artifact_id]["result_futures"]
  224. for result_future in futures:
  225. result_future.set_exception(exc)
  226. futures.clear()
  227. def _resolve_artifact_futures(self, artifact_id: str) -> None:
  228. futures = self._artifacts[artifact_id]["result_futures"]
  229. for result_future in futures:
  230. result_future.set_result(None)
  231. futures.clear()
  232. def start(self) -> None:
  233. self._thread.start()
  234. def is_alive(self) -> bool:
  235. return self._thread.is_alive()