wait_with_progress.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from __future__ import annotations
  2. import time
  3. from collections.abc import Coroutine
  4. from typing import Any, Callable, TypeVar, cast
  5. from wandb.sdk.lib import asyncio_compat
  6. from .mailbox_handle import MailboxHandle
  7. _T = TypeVar("_T")
  8. def wait_with_progress(
  9. handle: MailboxHandle[_T],
  10. *,
  11. timeout: float | None,
  12. display_progress: Callable[[], Coroutine[Any, Any, None]],
  13. ) -> _T:
  14. """Wait for a handle, possibly displaying progress to the user.
  15. Equivalent to passing a single handle to `wait_all_with_progress`.
  16. """
  17. return wait_all_with_progress(
  18. [handle],
  19. timeout=timeout,
  20. display_progress=display_progress,
  21. )[0]
  22. def wait_all_with_progress(
  23. handle_list: list[MailboxHandle[_T]],
  24. *,
  25. timeout: float | None,
  26. display_progress: Callable[[], Coroutine[Any, Any, None]],
  27. ) -> list[_T]:
  28. """Wait for multiple handles, possibly displaying progress to the user.
  29. Args:
  30. handle_list: The handles to wait for.
  31. timeout: A number of seconds after which to raise a TimeoutError,
  32. or None if this should never timeout.
  33. display_progress: An asyncio function that displays progress to
  34. the user. This function runs using the handles' AsyncioManager.
  35. Returns:
  36. A list where the Nth item is the Nth handle's result.
  37. Raises:
  38. ValueError: If the handles live in different asyncio threads.
  39. TimeoutError: If the overall timeout expires.
  40. HandleAbandonedError: If any handle becomes abandoned.
  41. Exception: Any exception from the display function is propagated.
  42. """
  43. if not handle_list:
  44. return []
  45. asyncer = handle_list[0].asyncer
  46. for handle in handle_list:
  47. if handle.asyncer is not asyncer:
  48. raise ValueError("Handles have different AsyncioManagers.")
  49. start_time = time.monotonic()
  50. async def progress_loop_with_timeout() -> list[_T]:
  51. async with asyncio_compat.cancel_on_exit(display_progress()):
  52. if timeout is not None:
  53. elapsed_time = time.monotonic() - start_time
  54. remaining_timeout = timeout - elapsed_time
  55. else:
  56. remaining_timeout = None
  57. return await _wait_handles_async(
  58. handle_list,
  59. timeout=remaining_timeout,
  60. )
  61. return asyncer.run(progress_loop_with_timeout)
  62. async def _wait_handles_async(
  63. handle_list: list[MailboxHandle[_T]],
  64. *,
  65. timeout: float | None,
  66. ) -> list[_T]:
  67. """Asynchronously wait for multiple mailbox handles.
  68. Just like _wait_handles.
  69. """
  70. results: list[_T | None] = [None for _ in handle_list]
  71. async def wait_single(index: int) -> None:
  72. handle = handle_list[index]
  73. results[index] = await handle.wait_async(timeout=timeout)
  74. async with asyncio_compat.open_task_group() as task_group:
  75. for index in range(len(handle_list)):
  76. task_group.start_soon(wait_single(index))
  77. # NOTE: `list` is not subscriptable until Python 3.10, so we use List.
  78. return cast(list[_T], results)