| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688 |
- # Derived from iostream.py from tornado 1.0, Copyright 2009 Facebook
- # Used under Apache License Version 2.0
- #
- # Modifications are Copyright (C) PyZMQ Developers
- # Distributed under the terms of the Modified BSD License.
- """A utility class for event-based messaging on a zmq socket using tornado.
- .. seealso::
- - :mod:`zmq.asyncio`
- - :mod:`zmq.eventloop.future`
- """
- from __future__ import annotations
- import asyncio
- import pickle
- import warnings
- from queue import Queue
- from typing import Any, Awaitable, Callable, Literal, Sequence, cast, overload
- from tornado.ioloop import IOLoop
- from tornado.log import gen_log
- import zmq
- import zmq._future
- from zmq import POLLIN, POLLOUT
- from zmq.utils import jsonapi
class ZMQStream:
    """A utility class to register callbacks when a zmq socket sends and receives

    For use with tornado IOLoop.

    There are three main methods

    Methods:

    * **on_recv(callback, copy=True):**
        register a callback to be run every time the socket has something to receive
    * **on_send(callback):**
        register a callback to be run every time you call send
    * **send_multipart(self, msg, flags=0, copy=True, callback=None):**
        perform a send that will trigger the callback
        if callback is passed, on_send is also called.

        There are also send(), send_string(), send_json(), send_pyobj()

    Two other methods for deactivating the callbacks:

    * **stop_on_recv():**
        turn off the recv callback
    * **stop_on_send():**
        turn off the send callback

    which simply call ``on_<evt>(None)``.

    The entire socket interface, excluding direct recv methods, is also
    provided, primarily through direct-linking the methods.
    e.g.

    >>> stream.bind is stream.socket.bind
    True

    .. versionadded:: 25

        send/recv callbacks can be coroutines.

    .. versionchanged:: 25

        ZMQStreams only support base zmq.Socket classes (this has always been true, but not enforced).
        If ZMQStreams are created with e.g. async Socket subclasses,
        a RuntimeWarning will be shown,
        and the socket cast back to the default zmq.Socket
        before connecting events.

        Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the
        arguments passed to callback functions.
        Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods
        (the list of message frames).
    """

    # the wrapped socket; set to None once the stream is closed
    socket: zmq.Socket
    # the tornado loop the socket handler is registered with
    io_loop: IOLoop
    # private poller used by flush() to bypass the IOLoop
    poller: zmq.Poller
    # pending (msg, kwargs) pairs waiting for a POLLOUT event
    _send_queue: Queue
    _recv_callback: Callable | None
    _send_callback: Callable | None
    _close_callback: Callable | None
    # bitmask of POLLIN / POLLOUT this stream is currently interested in
    _state: int = 0
    # True while callbacks are suppressed because flush() already ran them
    _flushed: bool = False
    # `copy` argument forwarded to recv_multipart for on_recv callbacks
    _recv_copy: bool = False
    # raw file descriptor, kept so close() can unregister a socket that
    # was closed behind the stream's back
    _fd: int

    def __init__(self, socket: zmq.Socket, io_loop: IOLoop | None = None):
        """Wrap `socket` and register it with `io_loop`.

        Parameters
        ----------
        socket : zmq.Socket
            The socket to wrap. Async sockets trigger a RuntimeWarning and
            are shadowed back to a base zmq.Socket.
        io_loop : IOLoop, optional
            The tornado loop to attach to; defaults to ``IOLoop.current()``.
        """
        if isinstance(socket, zmq._future._AsyncSocket):
            warnings.warn(
                f"""ZMQStream only supports the base zmq.Socket class.

                Use zmq.Socket(shadow=other_socket)
                or `ctx.socket(zmq.{socket._type_name}, socket_class=zmq.Socket)`
                to create a base zmq.Socket object,
                no matter what other kind of socket your Context creates.
                """,
                RuntimeWarning,
                stacklevel=2,
            )
            # shadow back to base zmq.Socket,
            # otherwise callbacks like `on_recv` will get the wrong types.
            socket = zmq.Socket(shadow=socket)
        self.socket = socket

        # IOLoop.current() is deprecated if called outside the event loop;
        # that means a stream should either be created from within a running
        # loop or be given its io_loop explicitly.
        self.io_loop = io_loop or IOLoop.current()
        self.poller = zmq.Poller()
        self._fd = cast(int, self.socket.FD)

        self._send_queue = Queue()
        self._recv_callback = None
        self._send_callback = None
        self._close_callback = None
        self._recv_copy = False
        self._flushed = False

        self._state = 0
        self._init_io_state()

        # shortcircuit some socket methods
        self.bind = self.socket.bind
        self.bind_to_random_port = self.socket.bind_to_random_port
        self.connect = self.socket.connect
        self.setsockopt = self.socket.setsockopt
        self.getsockopt = self.socket.getsockopt
        self.setsockopt_string = self.socket.setsockopt_string
        self.getsockopt_string = self.socket.getsockopt_string
        self.setsockopt_unicode = self.socket.setsockopt_unicode
        self.getsockopt_unicode = self.socket.getsockopt_unicode

    def stop_on_recv(self):
        """Disable callback and automatic receiving."""
        return self.on_recv(None)

    def stop_on_send(self):
        """Disable callback on sending."""
        return self.on_send(None)

    def stop_on_err(self):
        """DEPRECATED, does nothing"""
        # Logger.warn is a deprecated alias; use Logger.warning
        gen_log.warning("on_err does nothing, and will be removed")

    def on_err(self, callback: Callable):
        """DEPRECATED, does nothing"""
        # Logger.warn is a deprecated alias; use Logger.warning
        gen_log.warning("on_err does nothing, and will be removed")

    @overload
    def on_recv(
        self,
        callback: Callable[[list[bytes]], Any],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[bytes]], Any],
        copy: Literal[True],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: (
            Callable[[list[zmq.Frame]], Any] | Callable[[list[bytes]], Any] | None
        ),
        copy: bool = ...,
    ) -> None: ...

    def on_recv(
        self,
        callback: (
            Callable[[list[zmq.Frame]], Any] | Callable[[list[bytes]], Any] | None
        ),
        copy: bool = True,
    ) -> None:
        """Register a callback for when a message is ready to recv.

        There can be only one callback registered at a time, so each
        call to `on_recv` replaces previously registered callbacks.

        on_recv(None) disables recv event polling.

        Use on_recv_stream(callback) instead, to register a callback that will receive
        both this ZMQStream and the message, instead of just the message.

        Parameters
        ----------
        callback : callable
            callback must take exactly one argument, which will be a
            list, as returned by socket.recv_multipart()
            if callback is None, recv callbacks are disabled.
        copy : bool
            copy is passed directly to recv, so if copy is False,
            callback will receive Frame objects. If copy is True,
            then callback will receive bytes objects.

        Returns : None

        Raises
        ------
        OSError
            If the stream has already been closed.
        """
        self._check_closed()
        assert callback is None or callable(callback)
        self._recv_callback = callback
        self._recv_copy = copy
        if callback is None:
            self._drop_io_state(zmq.POLLIN)
        else:
            self._add_io_state(zmq.POLLIN)

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[bytes]], Any],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[bytes]], Any],
        copy: Literal[True],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: (
            Callable[[ZMQStream, list[zmq.Frame]], Any]
            | Callable[[ZMQStream, list[bytes]], Any]
            | None
        ),
        copy: bool = ...,
    ) -> None: ...

    def on_recv_stream(
        self,
        callback: (
            Callable[[ZMQStream, list[zmq.Frame]], Any]
            | Callable[[ZMQStream, list[bytes]], Any]
            | None
        ),
        copy: bool = True,
    ):
        """Same as on_recv, but callback will get this stream as first argument

        callback must take exactly two arguments, as it will be called as::

            callback(stream, msg)

        Useful when a single callback should be used with multiple streams.

        Passing None is equivalent to calling stop_on_recv().
        """
        if callback is None:
            self.stop_on_recv()
        else:

            def stream_callback(msg):
                return callback(self, msg)

            self.on_recv(stream_callback, copy=copy)

    def on_send(
        self,
        callback: Callable[[Sequence[Any], zmq.MessageTracker | None], Any] | None,
    ):
        """Register a callback to be called on each send

        There will be two arguments::

            callback(msg, status)

        * `msg` will be the list of sendable objects that was just sent
        * `status` will be the return result of socket.send_multipart(msg) -
          MessageTracker or None.

        Non-copying sends return a MessageTracker object whose
        `done` attribute will be True when the send is complete.
        This allows users to track when an object is safe to write to
        again.

        The second argument will always be None if copy=True
        on the send.

        Use on_send_stream(callback) to register a callback that will be passed
        this ZMQStream as the first argument, in addition to the other two.

        on_send(None) disables the send callback.

        Parameters
        ----------
        callback : callable
            callback must take exactly two arguments, which will be
            the message being sent (always a list),
            and the return result of socket.send_multipart(msg) -
            MessageTracker or None.

            if callback is None, send callbacks are disabled.

        Raises
        ------
        OSError
            If the stream has already been closed.
        """
        self._check_closed()
        assert callback is None or callable(callback)
        self._send_callback = callback

    def on_send_stream(
        self,
        callback: (
            Callable[[ZMQStream, Sequence[Any], zmq.MessageTracker | None], Any] | None
        ),
    ):
        """Same as on_send, but callback will get this stream as first argument

        Callback will be passed three arguments::

            callback(stream, msg, status)

        Useful when a single callback should be used with multiple streams.

        Passing None is equivalent to calling stop_on_send().
        """
        if callback is None:
            self.stop_on_send()
        else:
            self.on_send(lambda msg, status: callback(self, msg, status))

    def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
        """Send a message, optionally also register a new callback for sends.

        The message is wrapped in a single-element list and queued via
        send_multipart. See zmq.socket.send for details.
        """
        return self.send_multipart(
            [msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs
        )

    def send_multipart(
        self,
        msg: Sequence[Any],
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        callback: Callable | None = None,
        **kwargs: Any,
    ) -> None:
        """Send a multipart message, optionally also register a new callback for sends.

        The message is only queued here; the actual socket send happens when
        the stream handles a POLLOUT event (or during flush()).

        See zmq.socket.send_multipart for details.
        """
        kwargs.update(dict(flags=flags, copy=copy, track=track))
        self._send_queue.put((msg, kwargs))
        callback = callback or self._send_callback
        if callback is not None:
            self.on_send(callback)
        else:
            # noop callback
            self.on_send(lambda *args: None)
        self._add_io_state(zmq.POLLOUT)

    def send_string(
        self,
        u: str,
        flags: int = 0,
        encoding: str = 'utf-8',
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send a unicode message with an encoding.

        See zmq.socket.send_string for details.

        Raises
        ------
        TypeError
            If `u` is not a str.
        """
        if not isinstance(u, str):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs)

    send_unicode = send_string

    def send_json(
        self,
        obj: Any,
        flags: int = 0,
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send json-serialized version of an object.

        See zmq.socket.send_json for details.
        """
        msg = jsonapi.dumps(obj)
        return self.send(msg, flags=flags, callback=callback, **kwargs)

    def send_pyobj(
        self,
        obj: Any,
        flags: int = 0,
        protocol: int = -1,
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send a Python object as a message using pickle to serialize.

        See zmq.socket.send_pyobj for details.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags, callback=callback, **kwargs)

    def _finish_flush(self):
        """callback for unsetting _flushed flag."""
        self._flushed = False

    def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: int | None = None):
        """Flush pending messages.

        This method safely handles all pending incoming and/or outgoing messages,
        bypassing the inner loop, passing them to the registered callbacks.

        A limit can be specified, to prevent blocking under high load.

        flush will return the first time ANY of these conditions are met:
            * No more events matching the flag are pending.
            * the total number of events handled reaches the limit.

        Note that if ``flag & POLLIN != 0``, recv events will be flushed even if no callback
        is registered, unlike normal IOLoop operation. This allows flush to be
        used to remove *and ignore* incoming messages.

        Parameters
        ----------
        flag : int
            default=POLLIN|POLLOUT
            0MQ poll flags.
            If flag & POLLIN, recv events will be flushed.
            If flag & POLLOUT, send events will be flushed.
            Both flags can be set at once, which is the default.
        limit : None or int, optional
            The maximum number of messages to send or receive.
            Both send and recv count against this limit.

        Returns
        -------
        int :
            count of events handled (both send and recv)

        Raises
        ------
        OSError
            If the stream has already been closed.
        """
        self._check_closed()
        # unset self._flushed, so callbacks will execute, in case flush has
        # already been called this iteration
        already_flushed = self._flushed
        self._flushed = False
        # initialize counters
        count = 0

        def update_flag():
            """Update the poll flag, to prevent registering POLLOUT events
            if we don't have pending sends."""
            return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT)

        flag = update_flag()
        if not flag:
            # nothing to do; put the flushed marker back so a no-op flush
            # does not re-enable callbacks an earlier flush suppressed
            # for this loop iteration
            self._flushed = already_flushed
            return 0
        self.poller.register(self.socket, flag)
        events = self.poller.poll(0)
        while events and (not limit or count < limit):
            _, event = events[0]
            if event & POLLIN:  # receiving
                self._handle_recv()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break
            if event & POLLOUT and self.sending():
                self._handle_send()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break

            flag = update_flag()
            if flag:
                self.poller.register(self.socket, flag)
                events = self.poller.poll(0)
            else:
                events = []
        if count:  # only bypass loop if we actually flushed something
            # skip send/recv callbacks this iteration
            self._flushed = True
            # reregister them at the end of the loop
            if not already_flushed:  # don't need to do it again
                self.io_loop.add_callback(self._finish_flush)
        elif already_flushed:
            self._flushed = True

        # update ioloop poll state, which may have changed
        self._rebuild_io_state()
        return count

    def set_close_callback(self, callback: Callable | None):
        """Call the given callback when the stream is closed."""
        self._close_callback = callback

    def close(self, linger: int | None = None) -> None:
        """Close this stream.

        Unregisters the IOLoop handler, closes the socket (if it is still
        open), and runs the close callback if one is set. Calling close()
        on an already-closed stream does nothing.
        """
        if self.socket is not None:
            if self.socket.closed:
                # fallback on raw fd for closed sockets
                # hopefully this happened promptly after close,
                # otherwise somebody else may have the FD
                warnings.warn(
                    f"Unregistering FD {self._fd} after closing socket. "
                    "This could result in unregistering handlers for the wrong socket. "
                    "Please use stream.close() instead of closing the socket directly.",
                    stacklevel=2,
                )
                self.io_loop.remove_handler(self._fd)
            else:
                self.io_loop.remove_handler(self.socket)
                self.socket.close(linger)
            self.socket = None  # type: ignore
            if self._close_callback:
                self._run_callback(self._close_callback)

    def receiving(self) -> bool:
        """Returns True if we are currently receiving from the stream."""
        return self._recv_callback is not None

    def sending(self) -> bool:
        """Returns True if we are currently sending to the stream."""
        return not self._send_queue.empty()

    def closed(self) -> bool:
        """Return True if the stream is closed.

        If the underlying socket was closed directly rather than through
        this stream, this triggers the stream's own close() cleanup.
        """
        if self.socket is None:
            return True
        if self.socket.closed:
            # underlying socket has been closed, but not by us!
            # trigger our cleanup
            self.close()
            return True
        return False

    def _run_callback(self, callback, *args, **kwargs):
        """Wrap running callbacks in try/except to allow us to
        close our socket.

        If the callback returns an awaitable, it is scheduled with
        asyncio.ensure_future and any exception it raises is logged.
        """
        try:
            f = callback(*args, **kwargs)
            if isinstance(f, Awaitable):
                f = asyncio.ensure_future(f)
            else:
                f = None
        except Exception:
            gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
            # Re-raise the exception so that IOLoop.handle_callback_exception
            # can see it and log the error
            raise

        if f is not None:
            # handle async callbacks
            def _log_error(f):
                try:
                    f.result()
                except Exception:
                    gen_log.error(
                        "Uncaught exception in ZMQStream callback", exc_info=True
                    )

            f.add_done_callback(_log_error)

    def _handle_events(self, fd, events):
        """This method is the actual handler for IOLoop, that gets called whenever
        an event on my socket is posted. It dispatches to _handle_recv, etc.

        The `events` argument is not consulted; the pending events are read
        from the socket's EVENTS option instead.
        """
        if not self.socket:
            gen_log.warning("Got events for closed stream %s", self)
            return
        try:
            zmq_events = self.socket.EVENTS
        except zmq.ContextTerminated:
            gen_log.warning("Got events for stream %s after terminating context", self)
            # trigger close check, this will unregister callbacks
            self.closed()
            return
        except zmq.ZMQError as e:
            # run close check
            # shadow sockets may have been closed elsewhere,
            # which should show up as ENOTSOCK here
            if self.closed():
                gen_log.warning(
                    "Got events for stream %s attached to closed socket: %s", self, e
                )
            else:
                gen_log.error("Error getting events for %s: %s", self, e)
            return
        try:
            # dispatch events:
            if zmq_events & zmq.POLLIN and self.receiving():
                self._handle_recv()
                if not self.socket:
                    # stream was closed by the recv callback
                    return
            if zmq_events & zmq.POLLOUT and self.sending():
                self._handle_send()
                if not self.socket:
                    # stream was closed by the send callback
                    return

            # rebuild the poll state
            self._rebuild_io_state()
        except Exception:
            gen_log.error("Uncaught exception in zmqstream callback", exc_info=True)
            raise

    def _handle_recv(self):
        """Handle a recv event.

        Does nothing while callbacks are suppressed after a flush().
        """
        if self._flushed:
            return
        try:
            msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # state changed since poll event
                pass
            else:
                raise
        else:
            if self._recv_callback:
                callback = self._recv_callback
                self._run_callback(callback, msg)

    def _handle_send(self):
        """Handle a send event.

        Pops one queued message and sends it. If the send raises ZMQError,
        the error is logged and passed to the send callback as `status`.
        Does nothing while callbacks are suppressed after a flush().
        """
        if self._flushed:
            return
        if not self.sending():
            gen_log.error("Shouldn't have handled a send event")
            return

        msg, kwargs = self._send_queue.get()
        try:
            status = self.socket.send_multipart(msg, **kwargs)
        except zmq.ZMQError as e:
            gen_log.error("SEND Error: %s", e)
            status = e
        if self._send_callback:
            callback = self._send_callback
            self._run_callback(callback, msg, status)

    def _check_closed(self):
        """Raise OSError if the stream has been closed."""
        if not self.socket:
            raise OSError("Stream is closed")

    def _rebuild_io_state(self):
        """rebuild io state based on self.sending() and receiving()"""
        if self.socket is None:
            return
        state = 0
        if self.receiving():
            state |= zmq.POLLIN
        if self.sending():
            state |= zmq.POLLOUT

        self._state = state
        self._update_handler(state)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        self._state = self._state | state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        self._state = self._state & (~state)
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state."""
        if self.socket is None:
            return

        if state & self.socket.events:
            # events still exist that haven't been processed
            # explicitly schedule handling to avoid missing events due to edge-triggered FDs
            self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0))

    def _init_io_state(self):
        """initialize the ioloop event handler"""
        self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ)
|