progress.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. """progress."""
  2. from __future__ import annotations
  3. import os
  4. from typing import IO, TYPE_CHECKING
  5. from wandb.errors import CommError
  6. if TYPE_CHECKING:
  7. from typing import Protocol
  8. class ProgressFn(Protocol):
  9. def __call__(self, new_bytes: int, total_bytes: int) -> None:
  10. pass
  11. class Progress:
  12. """A helper class for displaying progress."""
  13. ITER_BYTES = 1024 * 1024
  14. def __init__(self, file: IO[bytes], callback: ProgressFn | None = None) -> None:
  15. self.file = file
  16. if callback is None:
  17. def callback_(new_bytes: int, total_bytes: int) -> None:
  18. pass
  19. callback = callback_
  20. self.callback: ProgressFn = callback
  21. self.bytes_read = 0
  22. self.len = os.fstat(file.fileno()).st_size
  23. def read(self, size=-1):
  24. """Read bytes and call the callback."""
  25. bites = self.file.read(size)
  26. self.bytes_read += len(bites)
  27. if not bites and self.bytes_read < self.len:
  28. # Files shrinking during uploads causes request timeouts. Maybe
  29. # we could avoid those by updating the self.len in real-time, but
  30. # files getting truncated while uploading seems like something
  31. # that shouldn't really be happening anyway.
  32. raise CommError(
  33. f"File {self.file.name} size shrank from {self.len} to {self.bytes_read} while it was being uploaded."
  34. )
  35. # Growing files are also likely to be bad, but our code didn't break
  36. # on those in the past, so it's riskier to make that an error now.
  37. self.callback(len(bites), self.bytes_read)
  38. return bites
  39. def rewind(self) -> None:
  40. self.callback(-self.bytes_read, 0)
  41. self.bytes_read = 0
  42. self.file.seek(0)
  43. def __getattr__(self, name):
  44. """Fallback to the file object for attrs not defined here."""
  45. if hasattr(self.file, name):
  46. return getattr(self.file, name)
  47. else:
  48. raise AttributeError
  49. def __iter__(self):
  50. return self
  51. def __next__(self):
  52. bites = self.read(self.ITER_BYTES)
  53. if len(bites) == 0:
  54. raise StopIteration
  55. return bites
  56. def __len__(self):
  57. return self.len
  58. next = __next__