upload_job.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. from __future__ import annotations
  2. import logging
  3. import os
  4. from typing import TYPE_CHECKING
  5. import wandb
  6. from wandb.analytics import get_sentry
  7. from wandb.sdk.lib.paths import LogicalPath
  8. if TYPE_CHECKING:
  9. from wandb.filesync import dir_watcher, stats, step_upload
  10. from wandb.sdk.internal import file_stream, internal_api
  11. logger = logging.getLogger(__name__)
  12. class UploadJob:
  13. def __init__(
  14. self,
  15. stats: stats.Stats,
  16. api: internal_api.Api,
  17. file_stream: file_stream.FileStreamApi,
  18. silent: bool,
  19. save_name: LogicalPath,
  20. path: dir_watcher.PathStr,
  21. artifact_id: str | None,
  22. md5: str | None,
  23. copied: bool,
  24. save_fn: step_upload.SaveFn | None,
  25. digest: str | None,
  26. ) -> None:
  27. """A file uploader.
  28. Args:
  29. push_function: function(save_name, actual_path) which actually uploads
  30. the file.
  31. save_name: string logical location of the file relative to the run
  32. directory.
  33. path: actual string path of the file to upload on the filesystem.
  34. """
  35. self._stats = stats
  36. self._api = api
  37. self._file_stream = file_stream
  38. self.silent = silent
  39. self.save_name = save_name
  40. self.save_path = path
  41. self.artifact_id = artifact_id
  42. self.md5 = md5
  43. self.copied = copied
  44. self.save_fn = save_fn
  45. self.digest = digest
  46. super().__init__()
  47. def run(self) -> None:
  48. success = False
  49. try:
  50. self.push()
  51. success = True
  52. finally:
  53. if self.copied and os.path.isfile(self.save_path):
  54. os.remove(self.save_path)
  55. if success:
  56. self._file_stream.push_success(self.artifact_id, self.save_name) # type: ignore
  57. def push(self) -> None:
  58. if self.save_fn:
  59. # Retry logic must happen in save_fn currently
  60. try:
  61. deduped = self.save_fn(
  62. lambda _, t: self._stats.update_uploaded_file(self.save_path, t)
  63. )
  64. except Exception as e:
  65. self._stats.update_failed_file(self.save_path)
  66. logger.exception("Failed to upload file: %s", self.save_path)
  67. get_sentry().exception(e)
  68. message = str(e)
  69. # TODO: this is usually XML, but could be JSON
  70. if hasattr(e, "response"):
  71. message = e.response.content
  72. wandb.termerror(
  73. f'Error uploading "{self.save_path}": {type(e).__name__}, {message}'
  74. )
  75. raise
  76. if deduped:
  77. logger.info("Skipped uploading %s", self.save_path)
  78. self._stats.set_file_deduped(self.save_path)
  79. else:
  80. logger.info("Uploaded file %s", self.save_path)
  81. return
  82. if self.md5:
  83. # This is the new artifact manifest upload flow, in which we create the
  84. # database entry for the manifest file before creating it. This is used for
  85. # artifact L0 files. Which now is only artifact_manifest.json
  86. _, response = self._api.create_artifact_manifest(
  87. self.save_name, self.md5, self.artifact_id
  88. )
  89. upload_url = response["uploadUrl"]
  90. upload_headers = response["uploadHeaders"]
  91. else:
  92. # The classic file upload flow. We get a signed url and upload the file
  93. # then the backend handles the cloud storage metadata callback to create the
  94. # file entry. This flow has aged like a fine wine.
  95. project = self._api.get_project()
  96. _, upload_headers, result = self._api.upload_urls(project, [self.save_name])
  97. file_info = result[self.save_name]
  98. upload_url = file_info["uploadUrl"]
  99. if upload_url is None:
  100. logger.info("Skipped uploading %s", self.save_path)
  101. self._stats.set_file_deduped(self.save_name)
  102. else:
  103. extra_headers = self._api._extra_http_headers
  104. for upload_header in upload_headers:
  105. key, val = upload_header.split(":", 1)
  106. extra_headers[key] = val
  107. # Copied from push TODO(artifacts): clean up
  108. # If the upload URL is relative, fill it in with the base URL,
  109. # since its a proxied file store like the on-prem VM.
  110. if upload_url.startswith("/"):
  111. upload_url = f"{self._api.api_url}{upload_url}"
  112. try:
  113. with open(self.save_path, "rb") as f:
  114. self._api.upload_file_retry(
  115. upload_url,
  116. f,
  117. lambda _, t: self.progress(t),
  118. extra_headers=extra_headers,
  119. )
  120. logger.info("Uploaded file %s", self.save_path)
  121. except Exception as e:
  122. self._stats.update_failed_file(self.save_name)
  123. logger.exception("Failed to upload file: %s", self.save_path)
  124. get_sentry().exception(e)
  125. if not self.silent:
  126. wandb.termerror(
  127. f'Error uploading "{self.save_name}": {type(e).__name__}, {e}'
  128. )
  129. raise
  130. def progress(self, total_bytes: int) -> None:
  131. self._stats.update_uploaded_file(self.save_name, total_bytes)