| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- """AsyncIO support for zmq
- Requires asyncio and Python 3.
- """
- # Copyright (c) PyZMQ Developers.
- # Distributed under the terms of the Modified BSD License.
- from __future__ import annotations
- import asyncio
- import selectors
- import sys
- import warnings
- from asyncio import Future, SelectorEventLoop
- from weakref import WeakKeyDictionary
- import zmq as _zmq
- from zmq import _future
# Registry mapping an asyncio event loop to the selector-capable object
# created for it by _get_selector_windows.
# Keys are held weakly (WeakKeyDictionary).
_selectors: WeakKeyDictionary = WeakKeyDictionary()
class ProactorSelectorThreadWarning(RuntimeWarning):
    """Warning category for the extra selector thread spawned by tornado.

    Proactor event loops are supported automatically through tornado's
    AddThreadSelectorEventLoop, which runs a selector in a separate thread.
    """
def _get_selector_windows(
    asyncio_loop,
) -> asyncio.AbstractEventLoop:
    """Get selector-compatible loop

    Returns an object with ``add_reader`` family of methods,
    either the loop itself or a SelectorThread instance.

    Workaround Windows proactor removal of
    *reader methods, which we need for zmq sockets.

    Parameters
    ----------
    asyncio_loop
        The asyncio event loop that zmq sockets should be registered with.

    Returns
    -------
    The loop itself when it is not a proactor loop, otherwise a tornado
    ``AddThreadSelectorEventLoop`` wrapping it (created once per loop and
    cached in ``_selectors``).

    Raises
    ------
    RuntimeError
        If the loop is a proactor loop and tornado's
        ``AddThreadSelectorEventLoop`` cannot be imported.
    """
    # reuse the selector already registered for this loop, if any
    if asyncio_loop in _selectors:
        return _selectors[asyncio_loop]

    # detect add_reader instead of checking for proactor?
    if hasattr(asyncio, "ProactorEventLoop") and isinstance(
        asyncio_loop,
        asyncio.ProactorEventLoop,  # type: ignore
    ):
        try:
            from tornado.platform.asyncio import AddThreadSelectorEventLoop
        except ImportError as err:
            # chain explicitly so the traceback shows the import failure as the cause
            raise RuntimeError(
                "Proactor event loop does not implement add_reader family of methods required for zmq."
                " zmq will work with proactor if tornado >= 6.1 can be found."
                " Use `asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())`"
                " or install 'tornado>=6.1' to avoid this error."
            ) from err

        warnings.warn(
            "Proactor event loop does not implement add_reader family of methods required for zmq."
            " Registering an additional selector thread for add_reader support via tornado."
            " Use `asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())`"
            " to avoid this warning.",
            # dedicated RuntimeWarning subclass: existing RuntimeWarning
            # filters still match, and this notice can be filtered on its own
            ProactorSelectorThreadWarning,
            # stacklevel 5 matches most likely zmq.asyncio.Context().socket()
            stacklevel=5,
        )

        selector_loop = _selectors[asyncio_loop] = AddThreadSelectorEventLoop(
            asyncio_loop
        )  # type: ignore

        # patch loop.close to also close the selector thread
        loop_close = asyncio_loop.close

        def _close_selector_and_loop():
            # restore original before calling selector.close,
            # which in turn calls eventloop.close!
            asyncio_loop.close = loop_close
            _selectors.pop(asyncio_loop, None)
            selector_loop.close()

        asyncio_loop.close = _close_selector_and_loop  # type: ignore # mypy bug - assign a function to method
        return selector_loop
    else:
        return asyncio_loop
def _get_selector_noop(loop) -> asyncio.AbstractEventLoop:
    """Return ``loop`` unchanged.

    Used on non-Windows platforms, where no selector workaround is applied.
    """
    return loop
# Pick the selector lookup once at import time:
# the proactor workaround on Windows, a pass-through everywhere else.
_get_selector = (
    _get_selector_windows if sys.platform == "win32" else _get_selector_noop
)
class _AsyncIO:
    """Mixin supplying the asyncio Future class, event flags and default loop."""

    # event masks taken from the stdlib selectors module
    _READ = selectors.EVENT_READ
    _WRITE = selectors.EVENT_WRITE
    # Future class exposed to the classes that mix this in
    _Future = Future

    def _default_loop(self):
        """Return the running event loop, warning and falling back if none is running."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            warnings.warn(
                "No running event loop. zmq.asyncio should be used from within an asyncio loop.",
                RuntimeWarning,
                stacklevel=4,
            )
            # get_event_loop deprecated in 3.10:
            loop = asyncio.get_event_loop()
        return loop
class Poller(_AsyncIO, _future._AsyncPoller):
    """Poller whose poll results are delivered as asyncio.Future objects."""

    def _watch_raw_socket(self, loop, socket, evt, f):
        """Register ``f`` to be called when a raw socket is ready for ``evt``."""

        def _notify(*args):
            # the selector may pass arguments; f takes none
            return f()

        backend = _get_selector(loop)
        if evt & self._READ:
            backend.add_reader(socket, _notify)
        if evt & self._WRITE:
            backend.add_writer(socket, _notify)

    def _unwatch_raw_sockets(self, loop, *sockets):
        """Remove the reader and writer registrations of each raw socket."""
        backend = _get_selector(loop)
        for raw_socket in sockets:
            backend.remove_reader(raw_socket)
            backend.remove_writer(raw_socket)
class Socket(_AsyncIO, _future._AsyncSocket):
    """Socket whose send/recv/poll methods return asyncio Futures."""

    _poller_class = Poller

    def _get_selector(self, io_loop=None):
        """Return the selector-capable object for ``io_loop`` (default: this socket's loop)."""
        target_loop = self._get_loop() if io_loop is None else io_loop
        return _get_selector(target_loop)

    def _init_io_state(self, io_loop=None):
        """Register the socket's file descriptor for read events on the loop."""
        selector = self._get_selector(io_loop)
        selector.add_reader(self._fd, lambda: self._handle_events(0, 0))

    def _clear_io_state(self):
        """Remove the read handler registered for this socket.

        called once at close
        """
        loop = self._current_loop
        # nothing to unregister without a live loop and a valid fd
        if not loop or loop.is_closed() or self._fd == -1:
            return
        self._get_selector(loop).remove_reader(self._fd)
# Poller is defined before Socket (Socket refers to it as _poller_class),
# so Poller's socket class can only be attached here, once Socket exists.
Poller._socket_class = Socket
class Context(_zmq.Context[Socket]):
    """zmq Context whose sockets are asyncio-compatible ``Socket`` objects."""

    # avoid sharing instance with base Context class
    _instance = None
    _socket_class = Socket

    # overload with no changes to satisfy pyright
    def __init__(
        self: Context,
        io_threads: int | _zmq.Context = 1,
        shadow: _zmq.Context | int = 0,
    ) -> None:
        """Forward ``io_threads`` and ``shadow`` unchanged to the base Context."""
        super().__init__(io_threads, shadow)  # type: ignore
class ZMQEventLoop(SelectorEventLoop):
    """DEPRECATED: AsyncIO eventloop using zmq_poll.

    As of pyzmq 17, pyzmq sockets should work with any asyncio event loop.
    """

    def __init__(self, selector=None):
        """Emit the deprecation warning, then initialize as a SelectorEventLoop."""
        _deprecated()
        super().__init__(selector)
# NOTE(review): nothing in the visible code reads or assigns this name;
# presumably a leftover of the deprecated install() machinery - confirm before removing.
_loop = None
def _deprecated():
    """Emit the ZMQEventLoop/install DeprecationWarning, at most once.

    The ``called`` attribute on this function records whether the warning
    has already been issued.
    """
    already_warned = _deprecated.called  # type: ignore
    if not already_warned:
        _deprecated.called = True  # type: ignore
        warnings.warn(
            "ZMQEventLoop and zmq.asyncio.install are deprecated in pyzmq 17. Special eventloop integration is no longer needed.",
            DeprecationWarning,
            stacklevel=3,
        )


# no warning issued yet
_deprecated.called = False  # type: ignore
def install():
    """DEPRECATED: No longer needed in pyzmq 17.

    Only triggers the deprecation warning.
    """
    _deprecated()
# Names exported by ``from <this module> import *``.
__all__ = [
    "Context",
    "Socket",
    "Poller",
    "ZMQEventLoop",
    "install",
]
|