asyncio_manager.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. """Implements an asyncio thread suitable for internal wandb use."""
  2. from __future__ import annotations
  3. import asyncio
  4. import concurrent.futures
  5. import contextlib
  6. import logging
  7. import threading
  8. from collections.abc import Awaitable
  9. from typing import Callable, TypeVar
  10. from . import asyncio_compat
  11. _T = TypeVar("_T")
  12. _logger = logging.getLogger(__name__)
  13. class RunCancelledError(Exception):
  14. """A function passed to AsyncioManager.run() was cancelled."""
  15. class AlreadyJoinedError(Exception):
  16. """AsyncioManager.run() used after join()."""
  17. class AsyncioManager:
  18. """Manages a thread running an asyncio loop.
  19. The thread must be started using start() and should be joined using
  20. join(). The thread is a daemon thread, so if join() is not invoked,
  21. the asyncio work could end abruptly when all non-daemon threads exit.
  22. The run() method allows invoking an async function in the asyncio thread
  23. and waiting until it completes. The run_soon() method allows running
  24. an async function without waiting for it.
  25. Note that although tempting, it is **not** possible to write a safe
  26. run_in_loop() method that chooses whether to use run() or execute a function
  27. directly based on whether it's called from the asyncio thread: Suppose a
  28. function bad() holds a threading.Lock while using run_in_loop() and an
  29. asyncio task calling bad() is scheduled. If bad() is then invoked in a
  30. different thread that reaches run_in_loop(), the aforementioned asyncio task
  31. will deadlock. It is unreasonable to require that run_in_loop() never be
  32. called while holding a lock (which would apply to the callers of its
  33. callers, and so on), so it cannot safely exist.
  34. """
  35. def __init__(self) -> None:
  36. self._runner = asyncio_compat.CancellableRunner()
  37. self._thread = threading.Thread(
  38. target=self._main,
  39. name="wandb-AsyncioManager-main",
  40. daemon=True,
  41. )
  42. self._lock = threading.Lock()
  43. self._ready_event = threading.Event()
  44. """Whether asyncio primitives have been initialized."""
  45. self._joined = False
  46. """Whether join() has been called. Guarded by _lock."""
  47. self._loop: asyncio.AbstractEventLoop
  48. """A handle for interacting with the asyncio event loop."""
  49. self._done_event: asyncio.Event
  50. """Indicates to the asyncio loop that join() was called."""
  51. self._remaining_tasks = 0
  52. """The number of tasks remaining. Guarded by _lock."""
  53. self._task_finished_cond: asyncio.Condition
  54. """Signalled when _remaining_tasks is decremented."""
  55. def start(self) -> None:
  56. """Start the asyncio thread."""
  57. self._thread.start()
  58. def join(self) -> None:
  59. """Stop accepting new asyncio tasks and wait for the remaining ones."""
  60. try:
  61. with self._lock:
  62. # If join() was already called, block until the thread completes
  63. # and then return.
  64. if self._joined:
  65. self._thread.join()
  66. return
  67. self._joined = True
  68. # Wait until _loop and _done_event are initialized.
  69. self._ready_event.wait()
  70. # Set the done event. The main function will exit once all
  71. # tasks complete.
  72. self._loop.call_soon_threadsafe(self._done_event.set)
  73. self._thread.join()
  74. finally:
  75. # Any of the above may get interrupted by Ctrl+C, in which case we
  76. # should cancel all tasks, since join() can only be called once.
  77. # This only matters if the KeyboardInterrupt is suppressed.
  78. self._runner.cancel()
  79. def run(self, fn: Callable[[], Awaitable[_T]]) -> _T:
  80. """Run an async function to completion.
  81. The function is called in the asyncio thread. Blocks until start()
  82. is called. This raises an error if called inside an async function,
  83. and as a consequence, the caller may also not be called inside an
  84. async function.
  85. Args:
  86. fn: The function to run.
  87. Returns:
  88. The return value of fn.
  89. Raises:
  90. Exception: Any exception raised by fn.
  91. RunCancelledError: If fn is cancelled, particularly when join()
  92. is interrupted by Ctrl+C or if it otherwise cancels itself.
  93. AlreadyJoinedError: If join() was already called.
  94. ValueError: If called inside an async function.
  95. """
  96. self._ready_event.wait()
  97. if threading.current_thread().ident == self._thread.ident:
  98. raise ValueError("Cannot use run() inside async loop.")
  99. future = self._schedule(fn, daemon=False)
  100. try:
  101. return future.result()
  102. except concurrent.futures.CancelledError:
  103. raise RunCancelledError from None
  104. except KeyboardInterrupt:
  105. # If we're interrupted here, we only cancel this task rather than
  106. # cancelling all tasks like in join(). This only matters if the
  107. # interrupt is then suppressed (or delayed) in which case we
  108. # should let other tasks progress.
  109. future.cancel()
  110. raise
  111. def run_soon(
  112. self,
  113. fn: Callable[[], Awaitable[None]],
  114. *,
  115. daemon: bool = False,
  116. name: str | None = None,
  117. ) -> None:
  118. """Run an async function without waiting for it to complete.
  119. The function is called in the asyncio thread. Note that since that's
  120. a daemon thread, it will not get joined when the main thread exits,
  121. so fn can stop abruptly.
  122. Unlike run(), it is OK to call this inside an async function.
  123. Blocks until start() is called.
  124. Args:
  125. fn: The function to run.
  126. daemon: If true, join() will cancel fn after all non-daemon
  127. tasks complete. By default, join() blocks until fn
  128. completes.
  129. name: An optional name to give to long-running tasks which can
  130. appear in error traces and be useful to debugging.
  131. Raises:
  132. AlreadyJoinedError: If join() was already called.
  133. """
  134. # Wrap exceptions so that they're not printed to console.
  135. async def fn_wrap_exceptions() -> None:
  136. try:
  137. await fn()
  138. except Exception:
  139. _logger.exception("Uncaught exception in run_soon callback.")
  140. _ = self._schedule(fn_wrap_exceptions, daemon=daemon, name=name)
  141. def _schedule(
  142. self,
  143. fn: Callable[[], Awaitable[_T]],
  144. daemon: bool,
  145. name: str | None = None,
  146. ) -> concurrent.futures.Future[_T]:
  147. # Wait for _loop to be initialized.
  148. self._ready_event.wait()
  149. with self._lock:
  150. if self._joined:
  151. raise AlreadyJoinedError(
  152. "Cannot schedule tasks after join()." #
  153. + " Did you call wandb.teardown()?"
  154. )
  155. if not daemon:
  156. self._remaining_tasks += 1
  157. return asyncio.run_coroutine_threadsafe(
  158. self._wrap(fn, daemon=daemon, name=name),
  159. self._loop,
  160. )
  161. async def _wrap(
  162. self,
  163. fn: Callable[[], Awaitable[_T]],
  164. daemon: bool,
  165. name: str | None,
  166. ) -> _T:
  167. """Run fn to completion and possibly decrement _remaining tasks."""
  168. try:
  169. if name and (task := asyncio.current_task()):
  170. task.set_name(name)
  171. return await fn()
  172. finally:
  173. if not daemon:
  174. async with self._task_finished_cond:
  175. with self._lock:
  176. self._remaining_tasks -= 1
  177. self._task_finished_cond.notify_all()
  178. def _main(self) -> None:
  179. """Run the asyncio loop until join() is called and all tasks finish."""
  180. # A cancellation error is expected if join() is interrupted.
  181. #
  182. # Were it not suppressed, its stacktrace would get printed.
  183. with contextlib.suppress(asyncio_compat.RunnerCancelledError):
  184. self._runner.run(self._main_async)
  185. async def _main_async(self) -> None:
  186. """Wait until join() is called and all tasks finish."""
  187. self._loop = asyncio.get_running_loop()
  188. self._done_event = asyncio.Event()
  189. self._task_finished_cond = asyncio.Condition()
  190. self._ready_event.set()
  191. # Wait until done.
  192. await self._done_event.wait()
  193. # Wait for all tasks to complete.
  194. #
  195. # Once we exit, asyncio will cancel any leftover tasks.
  196. async with self._task_finished_cond:
  197. await self._task_finished_cond.wait_for(
  198. lambda: self._remaining_tasks <= 0,
  199. )