trio_runner.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. """A trio loop runner."""
  2. import builtins
  3. import logging
  4. import signal
  5. import threading
  6. import traceback
  7. import warnings
  8. import trio
  9. class TrioRunner:
  10. """A trio loop runner."""
  11. def __init__(self):
  12. """Initialize the runner."""
  13. self._cell_cancel_scope = None
  14. self._trio_token = None
  15. def initialize(self, kernel, io_loop):
  16. """Initialize the runner."""
  17. kernel.shell.set_trio_runner(self)
  18. kernel.shell.run_line_magic("autoawait", "trio")
  19. kernel.shell.magics_manager.magics["line"]["autoawait"] = lambda _: warnings.warn(
  20. "Autoawait isn't allowed in Trio background loop mode.", stacklevel=2
  21. )
  22. self._interrupted = False
  23. bg_thread = threading.Thread(target=io_loop.start, daemon=True, name="TornadoBackground")
  24. bg_thread.start()
  25. def interrupt(self, signum, frame):
  26. """Interrupt the runner."""
  27. if self._cell_cancel_scope:
  28. self._cell_cancel_scope.cancel()
  29. else:
  30. msg = "Kernel interrupted but no cell is running"
  31. raise Exception(msg)
  32. def run(self):
  33. """Run the loop."""
  34. old_sig = signal.signal(signal.SIGINT, self.interrupt)
  35. def log_nursery_exc(exc):
  36. exc = "\n".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
  37. logging.error("An exception occurred in a global nursery task.\n%s", exc)
  38. async def trio_main():
  39. """Run the main loop."""
  40. self._trio_token = trio.lowlevel.current_trio_token()
  41. async with trio.open_nursery() as nursery:
  42. # TODO This hack prevents the nursery from cancelling all child
  43. # tasks when an uncaught exception occurs, but it's ugly.
  44. nursery._add_exc = log_nursery_exc
  45. builtins.GLOBAL_NURSERY = nursery # type:ignore[attr-defined]
  46. await trio.sleep_forever()
  47. trio.run(trio_main)
  48. signal.signal(signal.SIGINT, old_sig)
  49. def __call__(self, async_fn):
  50. """Handle a function call."""
  51. async def loc(coro):
  52. """A thread runner context."""
  53. self._cell_cancel_scope = trio.CancelScope()
  54. with self._cell_cancel_scope:
  55. return await coro
  56. self._cell_cancel_scope = None # type:ignore[unreachable]
  57. return None
  58. return trio.from_thread.run(loc, async_fn, trio_token=self._trio_token)