asyncio_compat.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. """Functions for compatibility with asyncio."""
  2. from __future__ import annotations
  3. import asyncio
  4. import concurrent
  5. import concurrent.futures
  6. import contextlib
  7. import threading
  8. from collections.abc import AsyncIterator, Coroutine
  9. from typing import Any, Callable, TypeVar
  10. _T = TypeVar("_T")
  11. def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
  12. """Run `fn` in an asyncio loop in a new thread.
  13. This must always be used instead of `asyncio.run` which fails if there is
  14. an active `asyncio` event loop in the current thread. Since `wandb` was not
  15. originally designed with `asyncio` in mind, using `asyncio.run` would break
  16. users who were calling `wandb` methods from an `asyncio` loop.
  17. Note that due to starting a new thread, this is slightly slow.
  18. """
  19. with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
  20. runner = CancellableRunner()
  21. future = executor.submit(runner.run, fn)
  22. try:
  23. return future.result()
  24. finally:
  25. runner.cancel()
  26. class RunnerCancelledError(Exception):
  27. """The `CancellableRunner.run()` invocation was cancelled."""
  28. class CancellableRunner:
  29. """Runs an asyncio event loop allowing cancellation.
  30. The `run()` method is like `asyncio.run()`. The `cancel()` method may
  31. be used in a different thread, for instance in a `finally` block, to cancel
  32. all tasks, and it is a no-op if `run()` completed.
  33. Without this, it is impossible to make `asyncio.run()` stop if it runs
  34. in a non-main thread. In particular, a KeyboardInterrupt causes the
  35. ThreadPoolExecutor above to block until the asyncio thread completes,
  36. but there is no way to tell the asyncio thread to cancel its work.
  37. A second KeyboardInterrupt makes ThreadPoolExecutor give up while the
  38. asyncio thread still runs in the background, with terrible effects if it
  39. prints to the user's terminal.
  40. """
  41. def __init__(self) -> None:
  42. self._lock = threading.Lock()
  43. self._is_cancelled = False
  44. self._started = False
  45. self._done = False
  46. self._loop: asyncio.AbstractEventLoop | None = None
  47. self._cancel_event: asyncio.Event | None = None
  48. def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
  49. """Run a coroutine in asyncio, cancelling it on `cancel()`.
  50. Returns:
  51. The result of the coroutine returned by `fn`.
  52. Raises:
  53. RunnerCancelledError: If `cancel()` is called.
  54. """
  55. return asyncio.run(self._run_or_cancel(fn))
  56. async def _run_or_cancel(
  57. self,
  58. fn: Callable[[], Coroutine[Any, Any, _T]],
  59. ) -> _T:
  60. with self._lock:
  61. if self._is_cancelled:
  62. raise RunnerCancelledError()
  63. self._loop = asyncio.get_running_loop()
  64. self._cancel_event = asyncio.Event()
  65. self._started = True
  66. cancellation_task = asyncio.create_task(self._cancel_event.wait())
  67. fn_task = asyncio.create_task(fn())
  68. try:
  69. await asyncio.wait(
  70. [cancellation_task, fn_task],
  71. return_when=asyncio.FIRST_COMPLETED,
  72. )
  73. if fn_task.done():
  74. return fn_task.result()
  75. else:
  76. raise RunnerCancelledError()
  77. finally:
  78. # NOTE: asyncio.run() cancels all tasks after the main task exits,
  79. # but this is not documented, so we cancel them explicitly here
  80. # as well. It also blocks until canceled tasks complete.
  81. cancellation_task.cancel()
  82. fn_task.cancel()
  83. with self._lock:
  84. self._done = True
  85. def cancel(self) -> None:
  86. """Cancel all asyncio work started by `run()`."""
  87. with self._lock:
  88. if self._is_cancelled:
  89. return
  90. self._is_cancelled = True
  91. if self._done or not self._started:
  92. # If the runner already finished, no need to cancel it.
  93. #
  94. # If the runner hasn't started the loop yet, then it will not
  95. # as we already set _is_cancelled.
  96. return
  97. assert self._loop
  98. assert self._cancel_event
  99. self._loop.call_soon_threadsafe(self._cancel_event.set)
  100. class TaskGroup:
  101. """Object that `open_task_group()` yields."""
  102. def __init__(self) -> None:
  103. self._tasks: list[asyncio.Task[None]] = []
  104. def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
  105. """Schedule a task in the group.
  106. Args:
  107. coro: The return value of the `async` function defining the task.
  108. """
  109. self._tasks.append(asyncio.create_task(coro))
  110. async def _wait_all(self, *, race: bool, timeout: float | None) -> None:
  111. """Block until tasks complete.
  112. Args:
  113. race: If true, blocks until the first task completes and then
  114. cancels the rest. Otherwise, waits for all tasks or until
  115. the first exception.
  116. timeout: How long to wait.
  117. Raises:
  118. TimeoutError: If the timeout expires.
  119. Exception: If one or more tasks raises an exception, one of these
  120. is raised arbitrarily.
  121. """
  122. if not self._tasks:
  123. return
  124. if race:
  125. return_when = asyncio.FIRST_COMPLETED
  126. else:
  127. return_when = asyncio.FIRST_EXCEPTION
  128. done, pending = await asyncio.wait(
  129. self._tasks,
  130. timeout=timeout,
  131. return_when=return_when,
  132. )
  133. if not done:
  134. raise TimeoutError(f"Timed out after {timeout} seconds.")
  135. # If any of the finished tasks raised an exception, pick the first one.
  136. for task in done:
  137. if exc := task.exception():
  138. raise exc
  139. # Wait for remaining tasks to clean up, then re-raise any exceptions
  140. # that arise. Note that pending is only non-empty when race=True.
  141. for task in pending:
  142. task.cancel()
  143. await asyncio.gather(*pending, return_exceptions=True)
  144. for task in pending:
  145. if task.cancelled():
  146. continue
  147. if exc := task.exception():
  148. raise exc
  149. async def _cancel_all(self) -> None:
  150. """Cancel all tasks.
  151. Blocks until cancelled tasks complete to allow them to clean up.
  152. Ignores exceptions.
  153. """
  154. for task in self._tasks:
  155. # NOTE: It is safe to cancel tasks that have already completed.
  156. task.cancel()
  157. await asyncio.gather(*self._tasks, return_exceptions=True)
  158. @contextlib.asynccontextmanager
  159. async def open_task_group(
  160. *,
  161. exit_timeout: float | None = None,
  162. race: bool = False,
  163. ) -> AsyncIterator[TaskGroup]:
  164. """Create a task group.
  165. `asyncio` gained task groups in Python 3.11.
  166. This is an async context manager, meant to be used with `async with`.
  167. On exit, it blocks until all subtasks complete. If any subtask fails, or if
  168. the current task is cancelled, it cancels all subtasks in the group and
  169. raises the subtask's exception. If multiple subtasks fail simultaneously,
  170. one of their exceptions is chosen arbitrarily.
  171. NOTE: Subtask exceptions do not propagate until the context manager exits.
  172. This means that the task group cannot cancel code running inside the
  173. `async with` block .
  174. Args:
  175. exit_timeout: An optional timeout in seconds. When exiting the
  176. context manager, if tasks don't complete in this time,
  177. they are cancelled and a TimeoutError is raised.
  178. race: If true, all pending tasks are cancelled once any task
  179. in the group completes. Prefer to use the race() function instead.
  180. Raises:
  181. TimeoutError: if exit_timeout is specified and tasks don't finish
  182. in time.
  183. """
  184. task_group = TaskGroup()
  185. try:
  186. yield task_group
  187. await task_group._wait_all(race=race, timeout=exit_timeout)
  188. finally:
  189. await task_group._cancel_all()
  190. @contextlib.asynccontextmanager
  191. async def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> AsyncIterator[None]:
  192. """Schedule a task, cancelling it when exiting the context manager.
  193. If the context manager exits successfully but the given coroutine raises
  194. an exception, that exception is reraised. The exception is suppressed
  195. if the context manager raises an exception.
  196. """
  197. async def stop_immediately():
  198. pass
  199. async with open_task_group(race=True) as group:
  200. group.start_soon(stop_immediately())
  201. group.start_soon(coro)
  202. yield
  203. async def race(*coros: Coroutine[Any, Any, Any]) -> None:
  204. """Wait until the first completed task.
  205. After any coroutine completes, all others are cancelled.
  206. If the current task is cancelled, all coroutines are cancelled too.
  207. If coroutines complete simultaneously and any one of them raises
  208. an exception, an arbitrary one is propagated. Similarly, if any coroutines
  209. raise exceptions during cancellation, one of them propagates.
  210. Args:
  211. coros: Coroutines to race.
  212. """
  213. async with open_task_group(race=True) as tg:
  214. for coro in coros:
  215. tg.start_soon(coro)