stats.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import threading
  2. from collections.abc import MutableMapping
  3. from typing import NamedTuple
  4. from wandb.sdk.lib import filenames
  5. class FileStats(NamedTuple):
  6. deduped: bool
  7. total: int
  8. uploaded: int
  9. failed: bool
  10. artifact_file: bool
  11. class Summary(NamedTuple):
  12. uploaded_bytes: int
  13. total_bytes: int
  14. deduped_bytes: int
  15. class FileCountsByCategory(NamedTuple):
  16. artifact: int
  17. wandb: int
  18. media: int
  19. other: int
  20. class Stats:
  21. def __init__(self) -> None:
  22. self._stats: MutableMapping[str, FileStats] = {}
  23. self._lock = threading.Lock()
  24. def init_file(
  25. self, save_name: str, size: int, is_artifact_file: bool = False
  26. ) -> None:
  27. with self._lock:
  28. self._stats[save_name] = FileStats(
  29. deduped=False,
  30. total=size,
  31. uploaded=0,
  32. failed=False,
  33. artifact_file=is_artifact_file,
  34. )
  35. def set_file_deduped(self, save_name: str) -> None:
  36. with self._lock:
  37. orig = self._stats[save_name]
  38. self._stats[save_name] = orig._replace(
  39. deduped=True,
  40. uploaded=orig.total,
  41. )
  42. def update_uploaded_file(self, save_name: str, total_uploaded: int) -> None:
  43. with self._lock:
  44. self._stats[save_name] = self._stats[save_name]._replace(
  45. uploaded=total_uploaded,
  46. )
  47. def update_failed_file(self, save_name: str) -> None:
  48. with self._lock:
  49. self._stats[save_name] = self._stats[save_name]._replace(
  50. uploaded=0,
  51. failed=True,
  52. )
  53. def summary(self) -> Summary:
  54. # Need to use list to ensure we get a copy, since other threads may
  55. # modify this while we iterate
  56. with self._lock:
  57. stats = list(self._stats.values())
  58. return Summary(
  59. uploaded_bytes=sum(f.uploaded for f in stats),
  60. total_bytes=sum(f.total for f in stats),
  61. deduped_bytes=sum(f.total for f in stats if f.deduped),
  62. )
  63. def file_counts_by_category(self) -> FileCountsByCategory:
  64. artifact_files = 0
  65. wandb_files = 0
  66. media_files = 0
  67. other_files = 0
  68. # Need to use list to ensure we get a copy, since other threads may
  69. # modify this while we iterate
  70. with self._lock:
  71. file_stats = list(self._stats.items())
  72. for save_name, stats in file_stats:
  73. if stats.artifact_file:
  74. artifact_files += 1
  75. elif filenames.is_wandb_file(save_name):
  76. wandb_files += 1
  77. elif save_name.startswith("media"):
  78. media_files += 1
  79. else:
  80. other_files += 1
  81. return FileCountsByCategory(
  82. artifact=artifact_files,
  83. wandb=wandb_files,
  84. media=media_files,
  85. other=other_files,
  86. )