| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- """progress."""
- from __future__ import annotations
- import os
- from typing import IO, TYPE_CHECKING
- from wandb.errors import CommError
if TYPE_CHECKING:
    # Only evaluated by static type checkers; nothing below exists at runtime.
    from typing import Protocol

    class ProgressFn(Protocol):
        """Structural type of a progress callback."""

        def __call__(self, new_bytes: int, total_bytes: int) -> None:
            """Receive the byte delta and the running byte total."""
            ...
class Progress:
    """A helper class for displaying progress.

    Wraps a binary file object, counts the bytes read through it and reports
    every read to a callback. Attributes not defined on this class are
    forwarded to the wrapped file object.

    Attributes:
        file: The wrapped file object.
        callback: Called as ``callback(new_bytes, total_bytes)`` on each read
            and on rewind.
        bytes_read: Number of bytes read since construction or last rewind.
        len: Size of the file in bytes, captured once at construction.
    """

    # Chunk size requested per step when the object is iterated.
    ITER_BYTES = 1024 * 1024

    def __init__(self, file: IO[bytes], callback: ProgressFn | None = None) -> None:
        """Wrap ``file`` and record its current on-disk size.

        Args:
            file: A binary file object backed by a real file descriptor
                (``fileno()`` is passed to ``os.fstat``).
            callback: Optional progress callback; a no-op is used when omitted.
        """
        self.file = file
        if callback is None:

            def callback_(new_bytes: int, total_bytes: int) -> None:
                pass

            callback = callback_

        self.callback: ProgressFn = callback
        self.bytes_read = 0
        self.len = os.fstat(file.fileno()).st_size

    def read(self, size: int = -1) -> bytes:
        """Read bytes and call the callback.

        Args:
            size: Maximum number of bytes to read; passed through to the
                wrapped file's ``read``.

        Returns:
            The bytes read (empty at end of file).

        Raises:
            CommError: If the file yields no more data before the size
                recorded at construction has been read.
        """
        bites = self.file.read(size)
        self.bytes_read += len(bites)
        if not bites and self.bytes_read < self.len:
            # Files shrinking during uploads causes request timeouts. Maybe
            # we could avoid those by updating the self.len in real-time, but
            # files getting truncated while uploading seems like something
            # that shouldn't really be happening anyway.
            raise CommError(
                f"File {self.file.name} size shrank from {self.len} to {self.bytes_read} while it was being uploaded."
            )
        # Growing files are also likely to be bad, but our code didn't break
        # on those in the past, so it's riskier to make that an error now.
        self.callback(len(bites), self.bytes_read)
        return bites

    def rewind(self) -> None:
        """Seek back to the start and report the bytes read so far as undone."""
        # A negative delta tells the callback to subtract what was counted.
        self.callback(-self.bytes_read, 0)
        self.bytes_read = 0
        self.file.seek(0)

    def __getattr__(self, name: str):
        """Fallback to the file object for attrs not defined here.

        Raises:
            AttributeError: If neither this object nor the wrapped file has
                the attribute.
        """
        # __getattr__ only runs when normal lookup fails. If "file" itself is
        # missing (instance created without __init__, e.g. via __new__, copy
        # or pickle), reading self.file below would re-enter this method
        # forever, so fail fast instead.
        if name == "file":
            raise AttributeError(name)
        try:
            return getattr(self.file, name)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __iter__(self) -> Progress:
        """Return self; iteration yields chunks of up to ITER_BYTES bytes."""
        return self

    def __next__(self) -> bytes:
        """Return the next chunk, or stop once the file yields no more data."""
        bites = self.read(self.ITER_BYTES)
        if not bites:
            raise StopIteration
        return bites

    def __len__(self) -> int:
        """Return the file size recorded at construction."""
        return self.len

    # Alias so the object can also be advanced through a ``.next()`` method.
    next = __next__
|