```python
"""Functions for compatibility with asyncio."""

from __future__ import annotations

import asyncio
import concurrent.futures
import contextlib
import threading
from collections.abc import AsyncIterator, Coroutine
from typing import Any, Callable, TypeVar

_T = TypeVar("_T")


def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
    """Run `fn` in an asyncio loop in a new thread.

    This must always be used instead of `asyncio.run`, which fails if there is
    an active `asyncio` event loop in the current thread. Since `wandb` was not
    originally designed with `asyncio` in mind, using `asyncio.run` would break
    users who call `wandb` methods from inside an `asyncio` loop.

    Because it starts a new thread, this is slightly slower than calling
    `asyncio.run` directly.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        runner = CancellableRunner()
        future = executor.submit(runner.run, fn)

        try:
            return future.result()
        finally:
            # A no-op after a normal return. If waiting was interrupted
            # (e.g. by KeyboardInterrupt), this stops the asyncio thread so
            # that shutting down the executor does not block on it.
            runner.cancel()


class RunnerCancelledError(Exception):
    """The `CancellableRunner.run()` invocation was cancelled."""


class CancellableRunner:
    """Runs an asyncio event loop, allowing cancellation.

    The `run()` method is like `asyncio.run()`. The `cancel()` method may
    be called from a different thread, for instance in a `finally` block, to
    cancel all tasks; it is a no-op if `run()` has already completed.

    Without this, it is impossible to make `asyncio.run()` stop if it runs
    in a non-main thread. In particular, a KeyboardInterrupt causes the
    ThreadPoolExecutor in `run()` to block until the asyncio thread completes,
    but there is no way to tell the asyncio thread to cancel its work.
    A second KeyboardInterrupt makes the ThreadPoolExecutor give up while the
    asyncio thread keeps running in the background, with terrible effects if
    it prints to the user's terminal.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()

        self._is_cancelled = False
        self._started = False
        self._done = False

        self._loop: asyncio.AbstractEventLoop | None = None
        self._cancel_event: asyncio.Event | None = None

    def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
        """Run a coroutine in asyncio, cancelling it on `cancel()`.

        Returns:
            The result of the coroutine returned by `fn`.

        Raises:
            RunnerCancelledError: If `cancel()` is called.
        """
        return asyncio.run(self._run_or_cancel(fn))

    async def _run_or_cancel(
        self,
        fn: Callable[[], Coroutine[Any, Any, _T]],
    ) -> _T:
        with self._lock:
            if self._is_cancelled:
                raise RunnerCancelledError()

            self._loop = asyncio.get_running_loop()
            self._cancel_event = asyncio.Event()
            self._started = True

        cancellation_task = asyncio.create_task(self._cancel_event.wait())
        fn_task = asyncio.create_task(fn())

        try:
            await asyncio.wait(
                [cancellation_task, fn_task],
                return_when=asyncio.FIRST_COMPLETED,
            )

            if fn_task.done():
                return fn_task.result()
            else:
                raise RunnerCancelledError()

        finally:
            # NOTE: asyncio.run() cancels all tasks after the main task exits,
            # but this is not documented, so we cancel them explicitly here
            # as well. It also blocks until cancelled tasks complete.
            cancellation_task.cancel()
            fn_task.cancel()

            with self._lock:
                self._done = True

    def cancel(self) -> None:
        """Cancel all asyncio work started by `run()`."""
        with self._lock:
            if self._is_cancelled:
                return
            self._is_cancelled = True

            if self._done or not self._started:
                # If the runner already finished, no need to cancel it.
                #
                # If the runner hasn't started yet, `_run_or_cancel()` will
                # see `_is_cancelled` under the lock and raise
                # RunnerCancelledError without calling `fn`.
                return

            assert self._loop
            assert self._cancel_event
            self._loop.call_soon_threadsafe(self._cancel_event.set)


class TaskGroup:
    """Object that `open_task_group()` yields."""

    def __init__(self) -> None:
        self._tasks: list[asyncio.Task[None]] = []

    def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Schedule a task in the group.

        Args:
            coro: The return value of the `async` function defining the task.
        """
        self._tasks.append(asyncio.create_task(coro))

    async def _wait_all(self, *, race: bool, timeout: float | None) -> None:
        """Block until tasks complete.

        Args:
            race: If true, blocks until the first task completes and then
                cancels the rest. Otherwise, waits for all tasks or until
                the first exception.
            timeout: How long to wait, in seconds. If None, waits indefinitely.

        Raises:
            TimeoutError: If the timeout expires.
            Exception: If one or more tasks raise an exception, one of these
                is raised arbitrarily.
        """
        if not self._tasks:
            return

        if race:
            return_when = asyncio.FIRST_COMPLETED
        else:
            return_when = asyncio.FIRST_EXCEPTION

        done, pending = await asyncio.wait(
            self._tasks,
            timeout=timeout,
            return_when=return_when,
        )

        if not done:
            raise TimeoutError(f"Timed out after {timeout} seconds.")

        # If any of the finished tasks raised an exception, raise one of them
        # (`done` is a set, so which one is arbitrary).
        for task in done:
            if exc := task.exception():
                raise exc

        # Without race=True, asyncio.wait() only returns with pending tasks
        # if a task raised (handled above) or the timeout expired.
        if pending and not race:
            raise TimeoutError(f"Timed out after {timeout} seconds.")

        # Wait for remaining tasks to clean up, then re-raise any exceptions
        # that arise. At this point, pending is only non-empty when race=True.
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

        for task in pending:
            if task.cancelled():
                continue
            if exc := task.exception():
                raise exc

    async def _cancel_all(self) -> None:
        """Cancel all tasks.

        Blocks until cancelled tasks complete to allow them to clean up.
        Ignores exceptions.
        """
        for task in self._tasks:
            # NOTE: It is safe to cancel tasks that have already completed.
            task.cancel()
        await asyncio.gather(*self._tasks, return_exceptions=True)


@contextlib.asynccontextmanager
async def open_task_group(
    *,
    exit_timeout: float | None = None,
    race: bool = False,
) -> AsyncIterator[TaskGroup]:
    """Create a task group.

    `asyncio` only gained task groups (`asyncio.TaskGroup`) in Python 3.11.

    This is an async context manager, meant to be used with `async with`.
    On exit, it blocks until all subtasks complete. If any subtask fails, or if
    the current task is cancelled, it cancels all subtasks in the group and
    raises the subtask's exception. If multiple subtasks fail simultaneously,
    one of their exceptions is chosen arbitrarily.

    NOTE: Subtask exceptions do not propagate until the context manager exits.
    This means that the task group cannot cancel code running inside the
    `async with` block.

    Args:
        exit_timeout: An optional timeout in seconds. When exiting the
            context manager, if tasks don't complete in this time,
            they are cancelled and a TimeoutError is raised.
        race: If true, all pending tasks are cancelled once any task
            in the group completes. Prefer to use the race() function instead.

    Raises:
        TimeoutError: If exit_timeout is specified and tasks don't finish
            in time.
    """
    task_group = TaskGroup()

    try:
        yield task_group
        await task_group._wait_all(race=race, timeout=exit_timeout)
    finally:
        await task_group._cancel_all()


@contextlib.asynccontextmanager
async def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> AsyncIterator[None]:
    """Schedule a task, cancelling it when exiting the context manager.

    If the body of the `async with` block finishes successfully but the given
    coroutine has raised an exception, that exception is re-raised. If the
    body raises an exception, the coroutine's exception is suppressed.
    """

    async def stop_immediately() -> None:
        pass

    # In race mode the group stops at the first completed task. Because
    # stop_immediately() finishes right away, exiting the block cancels
    # `coro` unless it has already finished (or failed).
    async with open_task_group(race=True) as group:
        group.start_soon(stop_immediately())
        group.start_soon(coro)
        yield


async def race(*coros: Coroutine[Any, Any, Any]) -> None:
    """Wait until the first of the given coroutines completes.

    As soon as any coroutine completes, all others are cancelled.
    If the current task is cancelled, all coroutines are cancelled too.

    If coroutines complete simultaneously and any of them raises
    an exception, an arbitrary one is propagated. Similarly, if any coroutines
    raise exceptions during cancellation, one of them propagates.

    Args:
        coros: Coroutines to race.
    """
    async with open_task_group(race=True) as tg:
        for coro in coros:
            tg.start_soon(coro)
```
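You can check the motivation behind `run()` with the standard library alone. This is a minimal sketch, not the module's implementation. `run_in_fresh_thread` is a hypothetical helper that skips `CancellableRunner` and simply calls `asyncio.run` on a worker thread, which has no running loop of its own.

```python
import asyncio
import concurrent.futures


def run_in_fresh_thread(fn):
    """Hypothetical helper: run fn() with asyncio.run on a new worker thread.

    Unlike the module's run(), this has no cancellation support.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(lambda: asyncio.run(fn())).result()


async def answer():
    await asyncio.sleep(0)
    return 42


async def caller():
    # Calling asyncio.run() while this thread's loop is running raises.
    coro = answer()
    try:
        asyncio.run(coro)
        direct_call_failed = False
    except RuntimeError:
        direct_call_failed = True
    finally:
        coro.close()  # avoid a "never awaited" warning

    # A fresh thread has no running loop, so asyncio.run() works there.
    # This blocks caller()'s own loop until the result is ready.
    return direct_call_failed, run_in_fresh_thread(answer)


direct_call_failed, result = asyncio.run(caller())
print(direct_call_failed, result)  # True 42
```

Blocking the caller's loop is deliberate here: it mirrors a synchronous `wandb` call made from async user code.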
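The core of `CancellableRunner` is a cross-thread signal. Another thread cannot safely touch an `asyncio.Event` directly, so it schedules `set()` on the loop with `loop.call_soon_threadsafe()`. Inside the loop, `asyncio.wait(..., return_when=FIRST_COMPLETED)` races that event against the real work. The sketch below strips the class down to this mechanism. `MiniRunner` and `Cancelled` are hypothetical names. Unlike the real class, the sketch has no lock and does not handle `cancel()` being called before the loop starts or after it finishes.

```python
import asyncio
import threading


class Cancelled(Exception):
    """Hypothetical stand-in for RunnerCancelledError."""


class MiniRunner:
    """Hypothetical, stripped-down analogue of CancellableRunner.

    No lock: cancel() must only be called after `started` is set and
    before run() returns.
    """

    def __init__(self):
        self.started = threading.Event()
        self.loop = None
        self.cancel_event = None

    def run(self, fn):
        return asyncio.run(self._main(fn))

    async def _main(self, fn):
        self.loop = asyncio.get_running_loop()
        self.cancel_event = asyncio.Event()
        self.started.set()

        waiter = asyncio.create_task(self.cancel_event.wait())
        work = asyncio.create_task(fn())
        try:
            # Whichever finishes first wins: the real work or the cancel signal.
            await asyncio.wait([waiter, work], return_when=asyncio.FIRST_COMPLETED)
            if work.done():
                return work.result()
            raise Cancelled()
        finally:
            waiter.cancel()
            work.cancel()

    def cancel(self):
        # asyncio.Event is not thread-safe; schedule set() on the loop's thread.
        self.loop.call_soon_threadsafe(self.cancel_event.set)


async def slow():
    await asyncio.sleep(60)
    return "finished"


runner = MiniRunner()
outcome = []


def worker():
    try:
        outcome.append(runner.run(slow))
    except Cancelled:
        outcome.append("cancelled")


thread = threading.Thread(target=worker, daemon=True)
thread.start()
runner.started.wait()
runner.cancel()  # called from the main thread
thread.join(timeout=5)
print(outcome)  # ['cancelled']
```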
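`TaskGroup._wait_all()` depends on how `asyncio.wait()` reports its results. With `FIRST_EXCEPTION` and no failures, it behaves like `ALL_COMPLETED`. So if it returns a non-empty `pending` set and no exception in `done`, the timeout must have expired. The standalone sketch below reproduces that case; `sleeper` is a hypothetical helper.

```python
import asyncio


async def sleeper(seconds, value):
    await asyncio.sleep(seconds)
    return value


async def demo():
    fast = asyncio.create_task(sleeper(0, "fast"))
    slow = asyncio.create_task(sleeper(10, "slow"))

    done, pending = await asyncio.wait(
        [fast, slow], timeout=0.1, return_when=asyncio.FIRST_EXCEPTION
    )
    # No task failed, yet one is still pending: only the timeout explains that.
    any_failed = any(not t.cancelled() and t.exception() is not None for t in done)
    timed_out = bool(pending) and not any_failed

    # Cancel the straggler and wait for it, as _wait_all() does for pending tasks.
    for task in pending:
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)

    return sorted(t.result() for t in done), len(pending), timed_out


done_values, pending_count, timed_out = asyncio.run(demo())
print(done_values, pending_count, timed_out)  # ['fast'] 1 True
```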
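The NOTE in `open_task_group()` says subtask exceptions only surface when the `async with` block exits. A simplified group makes this easy to see. `task_group` and `Group` below are hypothetical minimal analogues that keep only the wait-on-exit and cancel-in-`finally` structure, with no timeout or race mode.

```python
import asyncio
import contextlib


class Group:
    """Hypothetical minimal analogue of TaskGroup."""

    def __init__(self):
        self.tasks = []

    def start_soon(self, coro):
        self.tasks.append(asyncio.create_task(coro))


@contextlib.asynccontextmanager
async def task_group():
    group = Group()
    try:
        yield group
        # Only reached if the body finished without raising.
        if group.tasks:
            done, _ = await asyncio.wait(
                group.tasks, return_when=asyncio.FIRST_EXCEPTION
            )
            for task in done:
                if not task.cancelled() and task.exception() is not None:
                    raise task.exception()
    finally:
        # Cancel stragglers and let them run their cleanup code.
        for task in group.tasks:
            task.cancel()
        await asyncio.gather(*group.tasks, return_exceptions=True)


events = []


async def fail_fast():
    raise ValueError("boom")


async def slow_worker():
    try:
        await asyncio.sleep(10)
    except asyncio.CancelledError:
        events.append("slow_worker cancelled")
        raise


async def main():
    try:
        async with task_group() as group:
            group.start_soon(fail_fast())
            group.start_soon(slow_worker())
            await asyncio.sleep(0.01)
            # fail_fast() has already raised, but the body is not interrupted.
            events.append("body finished")
    except ValueError as exc:
        events.append(f"caught {exc}")


asyncio.run(main())
print(events)  # ['body finished', 'slow_worker cancelled', 'caught boom']
```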
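`cancel_on_exit()` suits background work that should live exactly as long as an `async with` block, such as a heartbeat. The module implements it by racing the coroutine against `stop_immediately()`. The sketch below swaps in a plainer technique to show the same contract: create the task, cancel it in `finally`, and re-raise its exception only if the block succeeded. `cancel_on_exit_sketch`, `heartbeat`, and `crash` are hypothetical.

```python
import asyncio
import contextlib


@contextlib.asynccontextmanager
async def cancel_on_exit_sketch(coro):
    task = asyncio.create_task(coro)
    try:
        yield
    finally:
        # Cancelling a task that already finished is a no-op.
        task.cancel()
        await asyncio.gather(task, return_exceptions=True)
    # Only reached when the body exited normally: surface the task's failure.
    if not task.cancelled() and task.exception() is not None:
        raise task.exception()


ticks = []


async def heartbeat():
    while True:
        ticks.append("tick")
        await asyncio.sleep(0.01)


async def crash():
    raise RuntimeError("background failure")


async def main():
    async with cancel_on_exit_sketch(heartbeat()):
        await asyncio.sleep(0.05)
    ticks_at_exit = len(ticks)
    await asyncio.sleep(0.05)  # the heartbeat must not tick any more

    try:
        async with cancel_on_exit_sketch(crash()):
            await asyncio.sleep(0.01)
        error = None
    except RuntimeError as exc:
        error = str(exc)
    return ticks_at_exit, error


ticks_at_exit, error = asyncio.run(main())
print(len(ticks) == ticks_at_exit, error)  # True background failure
```

If the body itself raises, control leaves through the `finally` clause and the final check is skipped. This matches the docstring's promise that the coroutine's exception is suppressed in that case.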
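A typical use of `race()` pits a shutdown signal or deadline against the real work. Whichever finishes first wins, and the rest are cancelled with a chance to run their cleanup. `race_sketch` below is a hypothetical standalone analogue built directly on `asyncio.wait(..., return_when=FIRST_COMPLETED)`, not the module's function.

```python
import asyncio


async def race_sketch(*coros):
    tasks = [asyncio.create_task(c) for c in coros]
    if not tasks:
        return
    try:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            if not task.cancelled() and task.exception() is not None:
                raise task.exception()
    finally:
        # Cancel the losers (a no-op for finished tasks) and wait for cleanup.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


log = []


async def download():
    try:
        await asyncio.sleep(10)
        log.append("download finished")
    finally:
        log.append("download cleaned up")


async def shutdown_signal():
    await asyncio.sleep(0.01)
    log.append("shutdown requested")


asyncio.run(race_sketch(download(), shutdown_signal()))
print(log)  # ['shutdown requested', 'download cleaned up']
```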