step_checksum.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. """Batching file prepare requests to our API."""
  2. from __future__ import annotations
  3. import concurrent.futures
  4. import functools
  5. import os
  6. import queue
  7. import shutil
  8. import threading
  9. from typing import TYPE_CHECKING, NamedTuple, Union, cast
  10. from wandb.filesync import step_upload
  11. from wandb.sdk.lib import filesystem, runid
  12. from wandb.sdk.lib.paths import LogicalPath
  13. if TYPE_CHECKING:
  14. import tempfile
  15. from wandb.filesync import stats
  16. from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
  17. from wandb.sdk.artifacts.artifact_saver import SaveFn
  18. from wandb.sdk.internal import internal_api
  class RequestUpload(NamedTuple):
      """Request asking StepChecksum to prepare one file for upload."""

      # Location of the file on disk; its size is read and, when ``copy`` is
      # set, it is the source of the copy.
      path: str
      # Name passed to the stats tracker and forwarded with the upload request.
      save_name: LogicalPath
      # When true, the file is first copied into the temporary directory and
      # the copy is what gets forwarded.
      copy: bool
  class RequestStoreManifestFiles(NamedTuple):
      """Request asking StepChecksum to queue uploads for a manifest's entries."""

      # Manifest whose ``entries`` are walked; only entries with a local path
      # produce an upload request.
      manifest: ArtifactManifest
      # Forwarded unchanged with every upload request created from the manifest.
      artifact_id: str
      # Bound to each entry with ``functools.partial`` before being forwarded.
      save_fn: SaveFn
  class RequestCommitArtifact(NamedTuple):
      """Request to commit an artifact.

      StepChecksum does no work of its own for this request: all four fields
      are passed, in order, to ``step_upload.RequestCommitArtifact``.
      """

      artifact_id: str
      finalize: bool
      before_commit: step_upload.PreCommitFn
      result_future: concurrent.futures.Future[None]
  class RequestFinish(NamedTuple):
      """Sentinel request that makes the StepChecksum worker loop exit."""

      # Forwarded to ``step_upload.RequestFinish`` once the loop has stopped;
      # ``StepChecksum.finish`` sends ``None`` here.
      callback: step_upload.OnRequestFinishFn | None
  # Every request type the StepChecksum worker accepts on its request queue.
  Event = Union[
      RequestUpload,
      RequestStoreManifestFiles,
      RequestCommitArtifact,
      RequestFinish,
  ]
  class StepChecksum:
      """Worker that prepares queued file requests and forwards them for upload.

      A daemon thread reads requests from ``request_queue``, performs the local
      preparation each one needs (optionally copying the file into the
      temporary directory and registering its size with ``stats``), and puts
      the matching ``step_upload`` request on ``output_queue``.
      """

      def __init__(
          self,
          api: internal_api.Api,
          tempdir: tempfile.TemporaryDirectory,
          request_queue: queue.Queue[Event],
          output_queue: queue.Queue[step_upload.Event],
          stats: stats.Stats,
      ) -> None:
          """Store the collaborators and create (but do not start) the worker.

          Args:
              api: Stored on the instance.
              tempdir: Directory that receives copies of files whose request
                  has ``copy`` set.
              request_queue: Queue the worker reads requests from.
              output_queue: Queue the worker writes ``step_upload`` requests to.
              stats: Tracker that is told the size of every file queued.
          """
          self._api = api
          self._tempdir = tempdir
          self._request_queue = request_queue
          self._output_queue = output_queue
          self._stats = stats

          self._thread = threading.Thread(target=self._thread_body, daemon=True)

      def _thread_body(self) -> None:
          """Dispatch requests until a ``RequestFinish`` arrives.

          Raises:
              TypeError: If an object of an unknown type is read from the queue.
          """
          while True:
              req = self._request_queue.get()
              if isinstance(req, RequestUpload):
                  self._handle_upload(req)
              elif isinstance(req, RequestStoreManifestFiles):
                  self._handle_store_manifest_files(req)
              elif isinstance(req, RequestCommitArtifact):
                  self._handle_commit_artifact(req)
              elif isinstance(req, RequestFinish):
                  break
              else:
                  raise TypeError(f"Unexpected request type: {type(req).__name__}")

          # Only reached through the ``break`` above, so ``req`` is the
          # RequestFinish that ended the loop.
          self._output_queue.put(step_upload.RequestFinish(req.callback))

      def _copy_to_tempdir(self, req: RequestUpload) -> str:
          """Copy ``req.path`` into the temporary directory and return the new path."""
          # A freshly generated id is prefixed to the save name to form the
          # destination file name.
          path = os.path.join(
              self._tempdir.name,
              f"{runid.generate_id()}-{req.save_name}",
          )
          filesystem.mkdir_exists_ok(os.path.dirname(path))
          try:
              # certain linux distros throw an exception when copying
              # large files: https://bugs.python.org/issue43743
              shutil.copy2(req.path, path)
          except OSError:
              # Retry once with the sendfile fast path switched off; a second
              # failure propagates to the caller.
              shutil._USE_CP_SENDFILE = False  # type: ignore[attr-defined]
              shutil.copy2(req.path, path)
          return path

      def _handle_upload(self, req: RequestUpload) -> None:
          """Copy the file if requested, record its size, and queue the upload."""
          path = req.path
          if req.copy:
              path = self._copy_to_tempdir(req)
          self._stats.init_file(req.save_name, os.path.getsize(path))
          self._output_queue.put(
              step_upload.RequestUpload(
                  path,
                  req.save_name,
                  None,
                  None,
                  req.copy,
                  None,
                  None,
              )
          )

      def _handle_store_manifest_files(self, req: RequestStoreManifestFiles) -> None:
          """Queue an upload for every manifest entry that has a local path."""
          for entry in req.manifest.entries.values():
              if not entry.local_path:
                  continue
              self._stats.init_file(
                  entry.local_path,
                  cast(int, entry.size),
                  is_artifact_file=True,
              )
              # NOTE(review): entry.digest is passed in two positions; confirm
              # against the step_upload.RequestUpload field order.
              self._output_queue.put(
                  step_upload.RequestUpload(
                      entry.local_path,
                      entry.path,
                      req.artifact_id,
                      entry.digest,
                      False,
                      functools.partial(req.save_fn, entry),
                      entry.digest,
                  )
              )

      def _handle_commit_artifact(self, req: RequestCommitArtifact) -> None:
          """Forward a commit request to the output queue unchanged."""
          self._output_queue.put(
              step_upload.RequestCommitArtifact(
                  req.artifact_id,
                  req.finalize,
                  req.before_commit,
                  req.result_future,
              )
          )

      def start(self) -> None:
          """Start the worker thread."""
          self._thread.start()

      def is_alive(self) -> bool:
          """Return whether the worker thread is currently running."""
          return self._thread.is_alive()

      def finish(self) -> None:
          """Ask the worker to stop by queueing a ``RequestFinish`` with no callback."""
          self._request_queue.put(RequestFinish(None))
  121. self._request_queue.put(RequestFinish(None))