progress.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. """Defines an object for printing run progress at the end of a script."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import time
  6. from collections.abc import Iterator
  7. from typing import NoReturn
  8. from wandb.proto import wandb_internal_pb2 as pb
  9. from wandb.sdk.interface import interface
  10. from wandb.sdk.lib import asyncio_compat
  11. from . import printer as p
# Prefix added per nesting level: contents of labeled groups and subtasks.
_INDENT = " "
# Cap on the number of lines rendered into the dynamic progress text area.
_MAX_LINES_TO_PRINT = 6
# Cap on the number of operation descriptions listed in one static line.
_MAX_OPS_TO_PRINT = 5
  15. async def loop_printing_operation_stats(
  16. progress: ProgressPrinter,
  17. interface: interface.InterfaceBase,
  18. ) -> None:
  19. """Poll and display ongoing tasks in the internal service process.
  20. This never returns and must be cancelled. This is meant to be used with
  21. `mailbox.wait_with_progress()`.
  22. Args:
  23. progress: The printer to update with operation stats.
  24. interface: The interface to use to poll for updates.
  25. Raises:
  26. HandleAbandonedError: If the mailbox associated with the interface
  27. becomes closed.
  28. Exception: Any other problem communicating with the service process.
  29. """
  30. stats: pb.OperationStats | None = None
  31. async def loop_update_screen() -> NoReturn:
  32. while True:
  33. if stats:
  34. progress.update(stats)
  35. await asyncio.sleep(0.1)
  36. async def loop_poll_stats() -> NoReturn:
  37. nonlocal stats
  38. while True:
  39. start_time = time.monotonic()
  40. handle = await interface.deliver_async(
  41. pb.Record(
  42. request=pb.Request(operations=pb.OperationStatsRequest()),
  43. )
  44. )
  45. result = await handle.wait_async(timeout=None)
  46. stats = result.response.operations_response.operation_stats
  47. elapsed_time = time.monotonic() - start_time
  48. if elapsed_time < 0.5:
  49. await asyncio.sleep(0.5 - elapsed_time)
  50. async with asyncio_compat.open_task_group() as task_group:
  51. task_group.start_soon(loop_update_screen())
  52. task_group.start_soon(loop_poll_stats())
  53. @contextlib.contextmanager
  54. def progress_printer(
  55. printer: p.Printer,
  56. default_text: str,
  57. ) -> Iterator[ProgressPrinter]:
  58. """Context manager providing an object for printing run progress.
  59. Args:
  60. printer: The printer to use.
  61. default_text: The text to show if no information is available.
  62. """
  63. with printer.dynamic_text() as text_area:
  64. try:
  65. yield ProgressPrinter(
  66. printer,
  67. text_area,
  68. default_text=default_text,
  69. )
  70. finally:
  71. printer.progress_close()
  72. class ProgressPrinter:
  73. """Displays PollExitResponse results to the user."""
  74. def __init__(
  75. self,
  76. printer: p.Printer,
  77. progress_text_area: p.DynamicText | None,
  78. default_text: str,
  79. ) -> None:
  80. self._printer = printer
  81. self._progress_text_area = progress_text_area
  82. self._default_text = default_text
  83. self._tick = -1
  84. self._last_printed_line = ""
  85. def update(
  86. self,
  87. stats_or_list: pb.OperationStats | list[pb.OperationStats],
  88. ) -> None:
  89. """Update the displayed information.
  90. Args:
  91. stats_or_list: A single group of operations, or zero or more
  92. labeled operation groups.
  93. """
  94. self._tick += 1
  95. if not self._progress_text_area:
  96. line = self._to_static_text(stats_or_list)
  97. if line and line != self._last_printed_line:
  98. self._printer.display(line)
  99. self._last_printed_line = line
  100. return
  101. lines = self._to_dynamic_text(stats_or_list)
  102. if not lines:
  103. loading_symbol = self._printer.loading_symbol(self._tick)
  104. if loading_symbol:
  105. lines = [f"{loading_symbol} {self._default_text}"]
  106. else:
  107. lines = [self._default_text]
  108. self._progress_text_area.set_text("\n".join(lines))
  109. def _to_dynamic_text(
  110. self,
  111. stats_or_list: pb.OperationStats | list[pb.OperationStats],
  112. ) -> list[str]:
  113. """Returns text to show in a dynamic text area."""
  114. loading_symbol = self._printer.loading_symbol(self._tick)
  115. if isinstance(stats_or_list, list):
  116. return _GroupedOperationStatsPrinter(
  117. self._printer,
  118. _MAX_LINES_TO_PRINT,
  119. loading_symbol,
  120. ).render(stats_or_list)
  121. else:
  122. return _OperationStatsPrinter(
  123. self._printer,
  124. _MAX_LINES_TO_PRINT,
  125. loading_symbol,
  126. ).render(stats_or_list)
  127. def _to_static_text(
  128. self,
  129. stats_or_list: pb.OperationStats | list[pb.OperationStats],
  130. ) -> str:
  131. """Returns a single line of text to print out."""
  132. if isinstance(stats_or_list, list):
  133. sorted_prefixed_stats = list(
  134. (f"[{stats.label}] " if stats.label else "", stats)
  135. for stats in sorted(stats_or_list, key=lambda s: s.label)
  136. )
  137. else:
  138. sorted_prefixed_stats = [("", stats_or_list)]
  139. group_strs: list[str] = []
  140. total_operations = 0
  141. total_printed = 0
  142. for prefix, stats in sorted_prefixed_stats:
  143. total_operations += stats.total_operations
  144. if not stats.operations:
  145. continue
  146. group_ops: list[str] = []
  147. i = 0
  148. while total_printed < _MAX_OPS_TO_PRINT and i < len(stats.operations):
  149. group_ops.append(stats.operations[i].desc)
  150. total_printed += 1
  151. i += 1
  152. if group_ops:
  153. group_strs.append(prefix + "; ".join(group_ops))
  154. line = "; ".join(group_strs)
  155. remaining = total_operations - total_printed
  156. if total_printed > 0 and remaining > 0:
  157. line += f" (+ {remaining} more)"
  158. return line
  159. class _GroupedOperationStatsPrinter:
  160. """Renders a list of labeled operation stats groups into lines of text."""
  161. def __init__(
  162. self,
  163. printer: p.Printer,
  164. max_lines: int,
  165. loading_symbol: str,
  166. ) -> None:
  167. self._printer = printer
  168. self._max_lines = max_lines
  169. self._loading_symbol = loading_symbol
  170. def render(self, stats_list: list[pb.OperationStats]) -> list[str]:
  171. """Convert labeled operation stats groups into text to display.
  172. Args:
  173. stats_list: A list of labeled operation stats.
  174. Returns:
  175. The lines of text to print. The lines do not end with the newline
  176. character. Returns an empty list if there are no operations.
  177. """
  178. lines: list[str] = []
  179. for stats in sorted(stats_list, key=lambda s: s.label):
  180. # Don't display empty groups.
  181. if not stats.operations:
  182. continue
  183. if stats.label:
  184. remaining_non_header_lines = self._max_lines - len(lines) - 1
  185. indent = _INDENT
  186. else:
  187. remaining_non_header_lines = self._max_lines - len(lines)
  188. indent = ""
  189. # Ensure enough space left for at least one line of content
  190. # after the header.
  191. if remaining_non_header_lines < 1:
  192. break
  193. # Group header (if not empty).
  194. if stats.label:
  195. lines.append(stats.label)
  196. # Group content.
  197. stats_lines = _OperationStatsPrinter(
  198. printer=self._printer,
  199. max_lines=remaining_non_header_lines,
  200. loading_symbol=self._loading_symbol,
  201. ).render(stats)
  202. for line in stats_lines:
  203. lines.append(f"{indent}{line}")
  204. return lines
  205. class _OperationStatsPrinter:
  206. """Renders operation stats into lines of text."""
  207. def __init__(
  208. self,
  209. printer: p.Printer,
  210. max_lines: int,
  211. loading_symbol: str,
  212. ) -> None:
  213. self._printer = printer
  214. self._max_lines = max_lines
  215. self._loading_symbol = loading_symbol
  216. self._lines: list[str] = []
  217. self._ops_shown = 0
  218. def render(self, stats: pb.OperationStats) -> list[str]:
  219. """Convert the stats into a list of lines to display.
  220. Args:
  221. stats: Collection of operations to display.
  222. Returns:
  223. The lines of text to print. The lines do not end with the newline
  224. character. Returns an empty list if there are no operations.
  225. """
  226. for op in stats.operations:
  227. self._add_operation(op, is_subtask=False, indent="")
  228. if self._ops_shown < stats.total_operations:
  229. if 1 <= self._max_lines <= len(self._lines):
  230. self._ops_shown -= 1
  231. self._lines.pop()
  232. remaining = stats.total_operations - self._ops_shown
  233. self._lines.append(f"+ {remaining} more task(s)")
  234. return self._lines
  235. def _add_operation(self, op: pb.Operation, is_subtask: bool, indent: str) -> None:
  236. """Add the operation to `self._lines`."""
  237. if len(self._lines) >= self._max_lines:
  238. return
  239. if not is_subtask:
  240. self._ops_shown += 1
  241. status_indent_level = 0 # alignment for the status message, if any
  242. parts: list[str] = []
  243. # Subtask indicator.
  244. if is_subtask and self._printer.supports_unicode:
  245. status_indent_level += 2 # +1 for space
  246. parts.append("↳")
  247. # Loading symbol.
  248. if self._loading_symbol:
  249. status_indent_level += 2 # +1 for space
  250. parts.append(self._loading_symbol)
  251. # Task name.
  252. parts.append(op.desc)
  253. # Progress information.
  254. if op.progress:
  255. parts.append(f"{op.progress}")
  256. # Task duration.
  257. parts.append(f"({_time_to_string(seconds=op.runtime_seconds)})")
  258. # Error status.
  259. self._lines.append(indent + " ".join(parts))
  260. if op.error_status:
  261. error_word = self._printer.error("ERROR")
  262. error_desc = self._printer.secondary_text(op.error_status)
  263. status_indent = " " * status_indent_level
  264. self._lines.append(
  265. f"{indent}{status_indent}{error_word} {error_desc}",
  266. )
  267. # Subtasks.
  268. if op.subtasks:
  269. subtask_indent = indent + _INDENT
  270. for task in op.subtasks:
  271. self._add_operation(
  272. task,
  273. is_subtask=True,
  274. indent=subtask_indent,
  275. )
  276. def _time_to_string(seconds: float) -> str:
  277. """Returns a short string representing the duration."""
  278. if seconds < 10:
  279. return f"{seconds:.1f}s"
  280. if seconds < 60:
  281. return f"{seconds:.0f}s"
  282. if seconds < 60 * 60:
  283. minutes = seconds / 60
  284. return f"{minutes:.1f}m"
  285. hours = int(seconds / (60 * 60))
  286. minutes = int((seconds / 60) % 60)
  287. return f"{hours}h{minutes}m"