| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- """Patch asyncio to allow nested event loops."""
- import asyncio
- import asyncio.events as events
- import os
- import sys
- import threading
- from contextlib import contextmanager, suppress
- from heapq import heappop
def apply(loop=None):
    """Install every nesting patch, then make one event loop reentrant.

    The asyncio module, the event loop policy and (when already imported)
    tornado are patched first, in that order. Afterwards the given ``loop``
    is patched; when ``loop`` is falsy the loop returned by
    ``asyncio.get_event_loop()`` is patched instead.
    """
    for install_patch in (_patch_asyncio, _patch_policy, _patch_tornado):
        install_patch()
    # The lookup happens after the module patches so that it goes through
    # the patched asyncio/policy machinery.
    _patch_loop(loop or asyncio.get_event_loop())
- def _patch_asyncio():
- """Patch asyncio module to use pure Python tasks and futures."""
- def run(main, *, debug=False):
- loop = asyncio.get_event_loop()
- loop.set_debug(debug)
- task = asyncio.ensure_future(main)
- try:
- return loop.run_until_complete(task)
- finally:
- if not task.done():
- task.cancel()
- with suppress(asyncio.CancelledError):
- loop.run_until_complete(task)
- def _get_event_loop(stacklevel=3):
- loop = events._get_running_loop()
- if loop is None:
- loop = events.get_event_loop_policy().get_event_loop()
- return loop
- # Use module level _current_tasks, all_tasks and patch run method.
- if hasattr(asyncio, '_nest_patched'):
- return
- if sys.version_info >= (3, 6, 0):
- asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = \
- asyncio.tasks._PyTask
- asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = \
- asyncio.futures._PyFuture
- if sys.version_info < (3, 7, 0):
- asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks
- asyncio.all_tasks = asyncio.tasks.Task.all_tasks
- if sys.version_info >= (3, 9, 0):
- events._get_event_loop = events.get_event_loop = \
- asyncio.get_event_loop = _get_event_loop
- asyncio.run = run
- asyncio._nest_patched = True
- def _patch_policy():
- """Patch the policy to always return a patched loop."""
- def get_event_loop(self):
- if self._local._loop is None:
- loop = self.new_event_loop()
- _patch_loop(loop)
- self.set_event_loop(loop)
- return self._local._loop
- policy = events.get_event_loop_policy()
- policy.__class__.get_event_loop = get_event_loop
- def _patch_loop(loop):
- """Patch loop to make it reentrant."""
- def run_forever(self):
- with manage_run(self), manage_asyncgens(self):
- while True:
- self._run_once()
- if self._stopping:
- break
- self._stopping = False
- def run_until_complete(self, future):
- with manage_run(self):
- f = asyncio.ensure_future(future, loop=self)
- if f is not future:
- f._log_destroy_pending = False
- while not f.done():
- self._run_once()
- if self._stopping:
- break
- if not f.done():
- raise RuntimeError(
- 'Event loop stopped before Future completed.')
- return f.result()
- def _run_once(self):
- """
- Simplified re-implementation of asyncio's _run_once that
- runs handles as they become ready.
- """
- ready = self._ready
- scheduled = self._scheduled
- while scheduled and scheduled[0]._cancelled:
- heappop(scheduled)
- timeout = (
- 0 if ready or self._stopping
- else min(max(
- scheduled[0]._when - self.time(), 0), 86400) if scheduled
- else None)
- event_list = self._selector.select(timeout)
- self._process_events(event_list)
- end_time = self.time() + self._clock_resolution
- while scheduled and scheduled[0]._when < end_time:
- handle = heappop(scheduled)
- ready.append(handle)
- for _ in range(len(ready)):
- if not ready:
- break
- handle = ready.popleft()
- if not handle._cancelled:
- # preempt the current task so that that checks in
- # Task.__step do not raise
- curr_task = curr_tasks.pop(self, None)
- try:
- handle._run()
- finally:
- # restore the current task
- if curr_task is not None:
- curr_tasks[self] = curr_task
- handle = None
- @contextmanager
- def manage_run(self):
- """Set up the loop for running."""
- self._check_closed()
- old_thread_id = self._thread_id
- old_running_loop = events._get_running_loop()
- try:
- self._thread_id = threading.get_ident()
- events._set_running_loop(self)
- self._num_runs_pending += 1
- if self._is_proactorloop:
- if self._self_reading_future is None:
- self.call_soon(self._loop_self_reading)
- yield
- finally:
- self._thread_id = old_thread_id
- events._set_running_loop(old_running_loop)
- self._num_runs_pending -= 1
- if self._is_proactorloop:
- if (self._num_runs_pending == 0
- and self._self_reading_future is not None):
- ov = self._self_reading_future._ov
- self._self_reading_future.cancel()
- if ov is not None:
- self._proactor._unregister(ov)
- self._self_reading_future = None
- @contextmanager
- def manage_asyncgens(self):
- if not hasattr(sys, 'get_asyncgen_hooks'):
- # Python version is too old.
- return
- old_agen_hooks = sys.get_asyncgen_hooks()
- try:
- self._set_coroutine_origin_tracking(self._debug)
- if self._asyncgens is not None:
- sys.set_asyncgen_hooks(
- firstiter=self._asyncgen_firstiter_hook,
- finalizer=self._asyncgen_finalizer_hook)
- yield
- finally:
- self._set_coroutine_origin_tracking(False)
- if self._asyncgens is not None:
- sys.set_asyncgen_hooks(*old_agen_hooks)
- def _check_running(self):
- """Do not throw exception if loop is already running."""
- pass
- if hasattr(loop, '_nest_patched'):
- return
- if not isinstance(loop, asyncio.BaseEventLoop):
- raise ValueError('Can\'t patch loop of type %s' % type(loop))
- cls = loop.__class__
- cls.run_forever = run_forever
- cls.run_until_complete = run_until_complete
- cls._run_once = _run_once
- cls._check_running = _check_running
- cls._check_runnung = _check_running # typo in Python 3.7 source
- cls._num_runs_pending = 1 if loop.is_running() else 0
- cls._is_proactorloop = (
- os.name == 'nt' and issubclass(cls, asyncio.ProactorEventLoop))
- if sys.version_info < (3, 7, 0):
- cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper
- curr_tasks = asyncio.tasks._current_tasks \
- if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks
- cls._nest_patched = True
- def _patch_tornado():
- """
- If tornado is imported before nest_asyncio, make tornado aware of
- the pure-Python asyncio Future.
- """
- if 'tornado' in sys.modules:
- import tornado.concurrent as tc # type: ignore
- tc.Future = asyncio.Future
- if asyncio.Future not in tc.FUTURES:
- tc.FUTURES += (asyncio.Future,)
|