| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- from __future__ import annotations
- import time
- from collections.abc import Coroutine
- from typing import Any, Callable, TypeVar, cast
- from wandb.sdk.lib import asyncio_compat
- from .mailbox_handle import MailboxHandle
# Type of the result a MailboxHandle resolves to; links each wait helper's
# return type to the type parameter of the handle(s) it is given.
_T = TypeVar("_T")
def wait_with_progress(
    handle: MailboxHandle[_T],
    *,
    timeout: float | None,
    display_progress: Callable[[], Coroutine[Any, Any, None]],
) -> _T:
    """Block on a single handle, possibly displaying progress to the user.

    This is a convenience wrapper: it delegates to `wait_all_with_progress`
    with a one-element list and returns that call's only result. See
    `wait_all_with_progress` for the meaning of `timeout`,
    `display_progress`, and the exceptions that may be raised.
    """
    single_result_list = wait_all_with_progress(
        [handle],
        timeout=timeout,
        display_progress=display_progress,
    )
    return single_result_list[0]
def wait_all_with_progress(
    handle_list: list[MailboxHandle[_T]],
    *,
    timeout: float | None,
    display_progress: Callable[[], Coroutine[Any, Any, None]],
) -> list[_T]:
    """Block until every handle has a result, possibly displaying progress.

    Args:
        handle_list: Handles to wait on. All of them must share one
            AsyncioManager (their `asyncer` attribute).
        timeout: Seconds allowed for the whole operation, measured from
            this call, or None to wait indefinitely.
        display_progress: Zero-argument factory returning an asyncio
            coroutine that shows progress to the user. It runs on the
            handles' AsyncioManager while the handles are awaited.

    Returns:
        A list whose Nth item is the result of the Nth handle. An empty
        input yields an empty list without touching any AsyncioManager.

    Raises:
        ValueError: If the handles do not all share the same AsyncioManager.
        TimeoutError: If the overall timeout expires.
        HandleAbandonedError: If any handle becomes abandoned.
        Exception: Any exception from the display function is propagated.
    """
    if not handle_list:
        return []

    asyncer = handle_list[0].asyncer
    if any(other.asyncer is not asyncer for other in handle_list):
        raise ValueError("Handles have different AsyncioManagers.")

    # Taken before handing control to the asyncio thread, so any time spent
    # before the coroutine starts running is charged against `timeout`.
    start_time = time.monotonic()

    async def _wait_while_displaying() -> list[_T]:
        # cancel_on_exit wraps the progress coroutine for the duration of
        # the wait (presumably cancelling it on exit -- see asyncio_compat).
        async with asyncio_compat.cancel_on_exit(display_progress()):
            remaining_timeout = (
                None
                if timeout is None
                else timeout - (time.monotonic() - start_time)
            )
            return await _wait_handles_async(
                handle_list,
                timeout=remaining_timeout,
            )

    return asyncer.run(_wait_while_displaying)
async def _wait_handles_async(
    handle_list: list[MailboxHandle[_T]],
    *,
    timeout: float | None,
) -> list[_T]:
    """Asynchronously wait for multiple mailbox handles.

    Each handle is awaited in its own task started on a single task group,
    and every task receives the same `timeout`. The handles are therefore
    waited on concurrently rather than one after another.

    Args:
        handle_list: The handles to wait for.
        timeout: Seconds passed to each handle's `wait_async`, or None to
            wait indefinitely.

    Returns:
        A list where the Nth item is the Nth handle's result.

    Raises:
        Whatever `handle.wait_async` raises, as propagated by the task group
        (per `wait_all_with_progress`'s docs, e.g. TimeoutError or
        HandleAbandonedError -- confirm against MailboxHandle).
    """
    results: list[_T | None] = [None] * len(handle_list)

    async def wait_single(index: int, handle: MailboxHandle[_T]) -> None:
        results[index] = await handle.wait_async(timeout=timeout)

    async with asyncio_compat.open_task_group() as task_group:
        for index, handle in enumerate(handle_list):
            task_group.start_soon(wait_single(index, handle))

    # NOTE: the cast target is a string so `list[_T]` is never evaluated at
    # runtime. `cast`'s first argument is an ordinary expression, not an
    # annotation, so `from __future__ import annotations` does not defer it,
    # and subscripting the builtin `list` raises TypeError before Python 3.9.
    # Assumes the task group waits for all started tasks before exiting, so
    # every slot has been filled here (TODO confirm in asyncio_compat).
    return cast("list[_T]", results)
|