| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281 |
- """Run file upload jobs on a thread pool and commit artifacts once their uploads finish."""
- from __future__ import annotations
- import concurrent.futures
- import logging
- import queue
- import sys
- import threading
- from collections.abc import MutableMapping, MutableSequence, MutableSet
- from typing import TYPE_CHECKING, Callable, NamedTuple, Union
- from wandb.errors.term import termerror
- from wandb.filesync import upload_job
- from wandb.sdk.lib.paths import LogicalPath
if TYPE_CHECKING:
    # Everything in this block exists for static type checking only; the names
    # are referenced from annotations, which are not evaluated at runtime
    # because of `from __future__ import annotations`.
    from typing import TypedDict

    from wandb.filesync import stats
    from wandb.sdk.internal import file_stream, internal_api, progress
    from wandb.sdk.internal.settings_static import SettingsStatic

    class ArtifactStatus(TypedDict):
        """Bookkeeping that StepUpload keeps for a single artifact."""

        # Whether the commit API call is made once the artifact is ready;
        # taken from the most recent RequestCommitArtifact.
        finalize: bool
        # Incremented for every upload requested for the artifact and
        # decremented only when an upload finishes without an exception.
        pending_count: int
        # Set to True once a RequestCommitArtifact has been received.
        commit_requested: bool
        # Callbacks invoked right before the artifact is committed.
        pre_commit_callbacks: MutableSet[PreCommitFn]
        # Futures completed with None when the commit succeeds, or with the
        # exception when an upload or the commit fails.
        result_futures: MutableSet[concurrent.futures.Future[None]]
# Callback invoked right before an artifact is committed.
PreCommitFn = Callable[[], None]
# Callback invoked after a finish request has been fully processed.
OnRequestFinishFn = Callable[[], None]
# Callable taking a progress callback and returning a bool; StepUpload only
# forwards it to UploadJob (presumably a custom upload routine -- confirm
# against upload_job).
SaveFn = Callable[["progress.ProgressFn"], bool]

# Module-level logger, used to record failed uploads.
logger = logging.getLogger(__name__)
class RequestUpload(NamedTuple):
    """Event asking StepUpload to upload one file."""

    # Path of the file; passed to UploadJob and shown in the failure log.
    path: str
    # The file's ID within the Run. Uploads that share a save_name are
    # serialized: a new one waits until the running one has finished.
    save_name: LogicalPath
    # Artifact this file belongs to, or None when the upload is not tracked
    # against any artifact.
    artifact_id: str | None
    # The remaining fields are not interpreted here; they are forwarded to
    # UploadJob unchanged.
    md5: str | None
    copied: bool
    save_fn: SaveFn | None
    digest: str | None
class RequestCommitArtifact(NamedTuple):
    """Event asking StepUpload to commit an artifact once its uploads are done."""

    # Artifact to commit.
    artifact_id: str
    # When False, the pre-commit callbacks still run but the commit API call
    # is skipped.
    finalize: bool
    # Invoked right before the commit.
    before_commit: PreCommitFn
    # Completed with None on success, or with the exception on failure.
    result_future: concurrent.futures.Future[None]
class RequestFinish(NamedTuple):
    """Event telling StepUpload to drain its remaining work and then stop."""

    # Called once the queue is empty and no upload jobs are running; may be
    # None when no notification is wanted.
    callback: OnRequestFinishFn | None
class EventJobDone(NamedTuple):
    """Event posted by an upload worker when its job has ended."""

    # The upload request the worker was processing.
    job: RequestUpload
    # Exception raised by the upload, or None if it completed normally.
    exc: BaseException | None
# Every kind of message that StepUpload's event loop knows how to handle.
Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
class StepUpload:
    """Event loop that uploads files and commits artifacts.

    One daemon thread consumes events from the event queue. Upload requests
    are executed on a thread pool; every worker reports back by putting an
    `EventJobDone` on the same queue, so all bookkeeping (`_running_jobs`,
    `_pending_jobs`, `_artifacts`) is only touched from the event-loop thread.
    """

    def __init__(
        self,
        api: internal_api.Api,
        stats: stats.Stats,
        event_queue: queue.Queue[Event],
        max_threads: int,
        file_stream: file_stream.FileStreamApi,
        settings: SettingsStatic | None = None,
    ) -> None:
        """Set up the event-loop thread and the upload pool (nothing is started).

        Args:
            api: Used to commit artifacts and handed to every upload job.
            stats: Handed to every upload job.
            event_queue: Queue the step reads its events from; workers also
                post their `EventJobDone` notifications to it.
            max_threads: Maximum number of concurrent upload workers.
            file_stream: Handed to every upload job.
            settings: Optional settings; only `silent` is read from it.
        """
        self._api = api
        self._stats = stats
        self._event_queue = event_queue
        self._file_stream = file_stream

        self._thread = threading.Thread(target=self._thread_body)
        self._thread.daemon = True

        self._pool = concurrent.futures.ThreadPoolExecutor(
            thread_name_prefix="wandb-upload",
            max_workers=max_threads,
        )

        # Indexed by files' `save_name`'s, which are their ID's in the Run.
        self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
        self._pending_jobs: MutableSequence[RequestUpload] = []

        self._artifacts: MutableMapping[str, ArtifactStatus] = {}
        # First upload failure seen for each artifact. Such an artifact can
        # never be committed, so commit futures registered later must be
        # failed with this exception rather than left pending.
        self._failed_artifacts: MutableMapping[str, BaseException] = {}

        self.silent = bool(settings.silent) if settings else False

    def _thread_body(self) -> None:
        """Process events until a finish request arrives and all work is drained."""
        event: Event | None

        # Wait for event in the queue, and process one by one until a
        # finish event is received
        finish_callback = None
        while True:
            event = self._event_queue.get()
            if isinstance(event, RequestFinish):
                finish_callback = event.callback
                break
            self._handle_event(event)

        # We've received a finish event. At this point, further Upload requests
        # are invalid.

        # After a finish event is received, iterate through the event queue
        # one by one and process all remaining events.
        while True:
            try:
                event = self._event_queue.get(True, 0.2)
            except queue.Empty:
                event = None

            if event is not None:
                self._handle_event(event)
            elif not self._running_jobs:
                # Queue was empty and no jobs left.
                self._pool.shutdown(wait=False)
                if finish_callback:
                    finish_callback()
                break

    def _handle_event(self, event: Event) -> None:
        """Dispatch a single event.

        Raises:
            TypeError: If the event is of a type this loop does not handle.
        """
        if isinstance(event, EventJobDone):
            job = event.job

            if event.exc is not None:
                logger.exception(
                    "Failed to upload file: %s", job.path, exc_info=event.exc
                )

            if job.artifact_id:
                if event.exc is None:
                    self._artifacts[job.artifact_id]["pending_count"] -= 1
                    self._maybe_commit_artifact(job.artifact_id)
                else:
                    if not self.silent:
                        termerror(
                            "Uploading artifact file failed. Artifact won't be committed."
                        )
                    self._mark_artifact_failed(job.artifact_id, event.exc)

            self._running_jobs.pop(job.save_name)

            # If we have any pending jobs, start one now
            if self._pending_jobs:
                next_job = self._pending_jobs.pop(0)
                self._start_upload_job(next_job)
        elif isinstance(event, RequestCommitArtifact):
            if event.artifact_id not in self._artifacts:
                self._init_artifact(event.artifact_id)

            status = self._artifacts[event.artifact_id]
            status["commit_requested"] = True
            status["finalize"] = event.finalize
            status["pre_commit_callbacks"].add(event.before_commit)
            status["result_futures"].add(event.result_future)

            self._maybe_commit_artifact(event.artifact_id)
        elif isinstance(event, RequestUpload):
            if event.artifact_id is not None:
                if event.artifact_id not in self._artifacts:
                    self._init_artifact(event.artifact_id)
                self._artifacts[event.artifact_id]["pending_count"] += 1
            self._start_upload_job(event)
        else:
            raise TypeError(f"Event has unexpected type: {event!s}")

    def _start_upload_job(self, event: RequestUpload) -> None:
        """Start the upload now, or queue it if the same file is being uploaded."""
        # Operations on a single backend file must be serialized. if
        # we're already uploading this file, put the event on the
        # end of the queue
        if event.save_name in self._running_jobs:
            self._pending_jobs.append(event)
            return

        self._spawn_upload(event)

    def _spawn_upload(self, event: RequestUpload) -> None:
        """Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.

        Context: it's important that, whenever we add an entry to `self._running_jobs`,
        we ensure that a corresponding `EventJobDone` message will eventually get handled;
        otherwise, the `_running_jobs` entry will never get removed, and the StepUpload
        will never shut down.

        The sole purpose of this function is to make sure that the code that adds an entry
        to `self._running_jobs` is textually right next to the code that eventually enqueues
        the `EventJobDone` message. This should help keep them in sync.
        """
        # Adding the entry to `self._running_jobs` MUST happen in the main thread,
        # NOT in the job that gets submitted to the thread-pool, to guard against
        # this sequence of events:
        # - StepUpload receives a RequestUpload
        #     ...and therefore spawns a thread to do the upload
        # - StepUpload receives a RequestFinish
        #     ...and checks `self._running_jobs` to see if there are any tasks to wait for...
        #     ...and there are none, because the addition to `self._running_jobs` happens in
        #        the background thread, which the scheduler hasn't yet run...
        #     ...so the StepUpload shuts down. Even though we haven't uploaded the file!
        #
        # This would be very bad!
        # So, this line has to happen _outside_ the `pool.submit()`.
        self._running_jobs[event.save_name] = event

        def run_and_notify() -> None:
            try:
                self._do_upload(event)
            finally:
                # Inside `finally`, `sys.exc_info()[1]` is the exception being
                # propagated, or None when the upload returned normally.
                self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))

        self._pool.submit(run_and_notify)

    def _do_upload(self, event: RequestUpload) -> None:
        """Build the `UploadJob` for the request and run it in the calling thread."""
        job = upload_job.UploadJob(
            self._stats,
            self._api,
            self._file_stream,
            self.silent,
            event.save_name,
            event.path,
            event.artifact_id,
            event.md5,
            event.copied,
            event.save_fn,
            event.digest,
        )
        job.run()

    def _init_artifact(self, artifact_id: str) -> None:
        """Create (or reset) the bookkeeping entry for the artifact."""
        self._artifacts[artifact_id] = {
            "finalize": False,
            "pending_count": 0,
            "commit_requested": False,
            "pre_commit_callbacks": set(),
            "result_futures": set(),
        }

    def _mark_artifact_failed(self, artifact_id: str, exc: BaseException) -> None:
        """Record an upload failure for the artifact and fail its waiting futures.

        A failed upload never decrements `pending_count`, so the artifact can
        no longer reach the commit condition. Remembering the failure lets
        `_maybe_commit_artifact` fail futures that are registered afterwards
        instead of leaving them pending forever.
        """
        # Keep the first failure as the one reported to late waiters.
        self._failed_artifacts.setdefault(artifact_id, exc)
        self._fail_artifact_futures(artifact_id, exc)

    def _maybe_commit_artifact(self, artifact_id: str) -> None:
        """Commit the artifact if it is ready, and settle its result futures.

        The artifact is ready when a commit has been requested and no uploads
        are pending. If one of its uploads has already failed, the registered
        futures are failed with that exception and nothing is committed.
        """
        failure = self._failed_artifacts.get(artifact_id)
        if failure is not None:
            self._fail_artifact_futures(artifact_id, failure)
            return

        artifact_status = self._artifacts[artifact_id]
        if (
            artifact_status["pending_count"] == 0
            and artifact_status["commit_requested"]
        ):
            try:
                for pre_callback in artifact_status["pre_commit_callbacks"]:
                    pre_callback()
                if artifact_status["finalize"]:
                    self._api.commit_artifact(artifact_id)
            except Exception as exc:
                termerror(
                    f"Committing artifact failed. Artifact {artifact_id} won't be finalized."
                )
                termerror(str(exc))
                self._fail_artifact_futures(artifact_id, exc)
            else:
                self._resolve_artifact_futures(artifact_id)

    def _fail_artifact_futures(self, artifact_id: str, exc: BaseException) -> None:
        """Complete every registered future of the artifact with `exc`."""
        futures = self._artifacts[artifact_id]["result_futures"]
        for result_future in futures:
            result_future.set_exception(exc)
        # Each future may be completed only once, so forget them afterwards.
        futures.clear()

    def _resolve_artifact_futures(self, artifact_id: str) -> None:
        """Complete every registered future of the artifact successfully."""
        futures = self._artifacts[artifact_id]["result_futures"]
        for result_future in futures:
            result_future.set_result(None)
        # Each future may be completed only once, so forget them afterwards.
        futures.clear()

    def start(self) -> None:
        """Start the event-loop thread."""
        self._thread.start()

    def is_alive(self) -> bool:
        """Return whether the event-loop thread is still running."""
        return self._thread.is_alive()
|