step_prepare.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. """Batching file prepare requests to our API."""
  2. from __future__ import annotations
  3. import queue
  4. import threading
  5. import time
  6. from collections.abc import Mapping, Sequence
  7. from typing import TYPE_CHECKING, Callable, NamedTuple, Union
  8. if TYPE_CHECKING:
  9. from wandb.sdk.internal.internal_api import (
  10. Api,
  11. CreateArtifactFileSpecInput,
  12. CreateArtifactFilesResponseFile,
  13. )
  14. # Request for a file to be prepared.
  15. class RequestPrepare(NamedTuple):
  16. file_spec: CreateArtifactFileSpecInput
  17. response_channel: queue.Queue[ResponsePrepare]
  18. class RequestFinish(NamedTuple):
  19. pass
  20. class ResponsePrepare(NamedTuple):
  21. birth_artifact_id: str
  22. upload_url: str | None
  23. upload_headers: Sequence[str]
  24. upload_id: str | None
  25. storage_path: str | None
  26. multipart_upload_urls: dict[int, str] | None
  27. Request = Union[RequestPrepare, RequestFinish]
  28. def _clamp(x: float, low: float, high: float) -> float:
  29. return max(low, min(x, high))
  30. def gather_batch(
  31. request_queue: queue.Queue[Request],
  32. batch_time: float,
  33. inter_event_time: float,
  34. max_batch_size: int,
  35. clock: Callable[[], float] = time.monotonic,
  36. ) -> tuple[bool, Sequence[RequestPrepare]]:
  37. batch_start_time = clock()
  38. remaining_time = batch_time
  39. first_request = request_queue.get()
  40. if isinstance(first_request, RequestFinish):
  41. return True, []
  42. batch: list[RequestPrepare] = [first_request]
  43. while remaining_time > 0 and len(batch) < max_batch_size:
  44. try:
  45. request = request_queue.get(
  46. timeout=_clamp(
  47. x=inter_event_time,
  48. low=1e-12, # 0 = "block forever", so just use something tiny
  49. high=remaining_time,
  50. ),
  51. )
  52. if isinstance(request, RequestFinish):
  53. return True, batch
  54. batch.append(request)
  55. remaining_time = batch_time - (clock() - batch_start_time)
  56. except queue.Empty:
  57. break
  58. return False, batch
  59. def prepare_response(response: CreateArtifactFilesResponseFile) -> ResponsePrepare:
  60. multipart_resp = response.get("uploadMultipartUrls")
  61. part_list = multipart_resp["uploadUrlParts"] if multipart_resp else []
  62. multipart_parts = {u["partNumber"]: u["uploadUrl"] for u in part_list} or None
  63. return ResponsePrepare(
  64. birth_artifact_id=response["artifact"]["id"],
  65. upload_url=response["uploadUrl"],
  66. upload_headers=response["uploadHeaders"],
  67. upload_id=multipart_resp and multipart_resp.get("uploadID"),
  68. storage_path=response.get("storagePath"),
  69. multipart_upload_urls=multipart_parts,
  70. )
  71. class StepPrepare:
  72. """A thread that batches requests to our file prepare API.
  73. Any number of threads may call prepare() in parallel. The PrepareBatcher thread
  74. will batch requests up and send them all to the backend at once.
  75. """
  76. def __init__(
  77. self,
  78. api: Api,
  79. batch_time: float,
  80. inter_event_time: float,
  81. max_batch_size: int,
  82. request_queue: queue.Queue[Request] | None = None,
  83. ) -> None:
  84. self._api = api
  85. self._inter_event_time = inter_event_time
  86. self._batch_time = batch_time
  87. self._max_batch_size = max_batch_size
  88. self._request_queue: queue.Queue[Request] = request_queue or queue.Queue()
  89. self._thread = threading.Thread(target=self._thread_body)
  90. self._thread.daemon = True
  91. def _thread_body(self) -> None:
  92. while True:
  93. finish, batch = gather_batch(
  94. request_queue=self._request_queue,
  95. batch_time=self._batch_time,
  96. inter_event_time=self._inter_event_time,
  97. max_batch_size=self._max_batch_size,
  98. )
  99. if batch:
  100. batch_response = self._prepare_batch(batch)
  101. # send responses
  102. for prepare_request in batch:
  103. name = prepare_request.file_spec["name"]
  104. response_file = batch_response[name]
  105. response = prepare_response(response_file)
  106. prepare_request.response_channel.put(response)
  107. if finish:
  108. break
  109. def _prepare_batch(
  110. self, batch: Sequence[RequestPrepare]
  111. ) -> Mapping[str, CreateArtifactFilesResponseFile]:
  112. """Execute the prepareFiles API call.
  113. Args:
  114. batch: List of RequestPrepare objects
  115. Returns:
  116. dict of (save_name: ResponseFile) pairs where ResponseFile is a dict with
  117. an uploadUrl key. The value of the uploadUrl key is None if the file
  118. already exists, or a url string if the file should be uploaded.
  119. """
  120. return self._api.create_artifact_files([req.file_spec for req in batch])
  121. def prepare(
  122. self, file_spec: CreateArtifactFileSpecInput
  123. ) -> queue.Queue[ResponsePrepare]:
  124. response_queue: queue.Queue[ResponsePrepare] = queue.Queue()
  125. self._request_queue.put(RequestPrepare(file_spec, response_queue))
  126. return response_queue
  127. def start(self) -> None:
  128. self._thread.start()
  129. def finish(self) -> None:
  130. self._request_queue.put(RequestFinish())
  131. def is_alive(self) -> bool:
  132. return self._thread.is_alive()
  133. def shutdown(self) -> None:
  134. self.finish()
  135. self._thread.join()